gaoqiong / composable_kernel_ROCM · Commits · 262b4a5c

Commit 262b4a5c, authored Jan 30, 2025 by Andriy Roshchenko

Allow selection of initialization algorithm
parent f625455c

Changes: 2 changed files with 17 additions and 8 deletions (+17 -8)
- test/mx_mfma_op/mx_mfma_op.cpp (+11 -4)
- test/mx_mfma_op/mx_mfma_op.hpp (+6 -4)
test/mx_mfma_op/mx_mfma_op.cpp @ 262b4a5c
```diff
@@ -10,8 +10,13 @@ using ck::f8_t;
 using ck::half_t;
 using ck::type_convert;
 
+/**
+ * @brief Run the test for the given MFMA instruction
+ *
+ * @param init - selects initialization algorithm for A and B tensors
+ */
 template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
-bool run_test()
+bool run_test(ck::index_t init)
 {
     using ALayout = ck::tensor_layout::gemm::ColumnMajor;
     using BLayout = ck::tensor_layout::gemm::ColumnMajor;
@@ -41,20 +46,22 @@ bool run_test()
                                        CLayout,
                                        BLOCK_M,
                                        BLOCK_N,
-                                       BLOCK_K>{}(mx_mfma_kernel);
+                                       BLOCK_K>{}(mx_mfma_kernel, init);
 
     return pass;
 }
 
 TEST(MFMA, FP8MFMA16x16x128)
 {
-    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>();
+    auto AB_init = 0;
+    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>(AB_init);
     EXPECT_TRUE(pass);
 }
 
 TEST(MFMA, FP8MFMA32x32x64)
 {
-    auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>();
+    auto AB_init = 0;
+    auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>(AB_init);
     EXPECT_TRUE(pass);
 }
```
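Both tests pass `AB_init = 0`, which reproduces the pre-commit behavior of the hard-coded `switch(0)` in `mx_mfma_op.hpp` below. As a side note on why init mode 1 exists: with A and B filled with 1.0, every element of C = A * B equals K, which makes kernel failures easy to spot. The following standalone sketch (hypothetical, not part of the commit) checks that property on the host with a naive GEMM:

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side check (not from the commit): when A and B are all
// ones, each element of C = A * B is the sum of K ones, i.e. exactly K.
int main()
{
    const int M = 4, N = 4, K = 16;
    std::vector<float> a(M * K, 1.0f); // init mode 1: A = {1.0}
    std::vector<float> b(K * N, 1.0f); // init mode 1: B = {1.0}
    std::vector<float> c(M * N, 0.0f);

    // naive row-major GEMM: C[m][n] = sum_k A[m][k] * B[k][n]
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m * K + k] * b[k * N + n];

    for(float v : c)
        assert(v == static_cast<float>(K)); // "results in C = {K}"

    return 0;
}
```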
test/mx_mfma_op/mx_mfma_op.hpp @ 262b4a5c
```diff
@@ -433,7 +433,7 @@ template <typename DeviceMFMA,
           index_t BLOCK_K>
 struct TestMFMA
 {
-    auto PrepareGemmTensors(const GemmParams& params)
+    auto PrepareGemmTensors(const GemmParams& params, index_t init)
     {
         auto f_host_tensor_descriptor =
             [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -458,7 +458,7 @@ struct TestMFMA
         Tensor<CDataType> c_m_n_device_result(
             f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
 
-        switch(0)
+        switch(init)
         {
         case 0:
             a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
@@ -466,6 +466,7 @@ struct TestMFMA
             b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
             break;
         case 1:
+            // results in C = {K}
             a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
             b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
             break;
@@ -480,6 +481,7 @@ struct TestMFMA
             b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
             break;
         default:
+            // all initial values are representable in FP8, BF8
             a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
             b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
@@ -489,7 +491,7 @@ struct TestMFMA
         return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
     }
 
-    auto operator()(const DeviceMFMA& mfma_kernel)
+    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
     {
         std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                   << ", CLayout = " << CLayout{}.name << std::endl;
@@ -524,7 +526,7 @@ struct TestMFMA
         params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
         params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
 
-        auto host_tensors = PrepareGemmTensors(params);
+        auto host_tensors = PrepareGemmTensors(params, init);
 
         const Tensor<ADataType>& a = std::get<0>(host_tensors);
         const Tensor<BDataType>& b = std::get<1>(host_tensors);
```
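Condensed, the pattern this commit introduces is an integer selector that picks the fill strategy for the A and B host buffers before the kernel under test runs. The sketch below is a simplified, self-contained analogue with hypothetical names and plain `std::vector` buffers in place of CK's `Tensor` and `GeneratorTensor_*` helpers; the constants 0.015625f and 1.0f and the random range come from the diff, everything else is an assumption:

```cpp
#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

// Simplified analogue (hypothetical, no composable_kernel dependency) of the
// switch(init) in PrepareGemmTensors: an integer index selects how the A and
// B host buffers are filled.
void init_ab(int init, std::vector<float>& a, std::vector<float>& b)
{
    switch(init)
    {
    case 0: // constant A, sequential B (loosely mirrors GeneratorTensor_1 / _Sequential)
        std::fill(a.begin(), a.end(), 0.015625f);
        for(std::size_t i = 0; i < b.size(); ++i)
            b[i] = static_cast<float>(i);
        break;
    case 1: // A = B = {1.0}, so every element of C equals K
        std::fill(a.begin(), a.end(), 1.0f);
        std::fill(b.begin(), b.end(), 1.0f);
        break;
    default: // small random integers; such values are exactly representable in FP8/BF8
    {
        std::mt19937 gen{42};
        std::uniform_int_distribution<int> dist{-5, 5}; // loosely mirrors GeneratorTensor_2{-5, 6}
        for(auto& x : a)
            x = static_cast<float>(dist(gen));
        for(auto& x : b)
            x = static_cast<float>(dist(gen));
        break;
    }
    }
}

int main()
{
    std::vector<float> a(16 * 64), b(64 * 16);
    init_ab(/*init=*/0, a, b); // 0 reproduces the pre-commit switch(0) behavior
    return 0;
}
```

Keeping the selector a plain integer (the commit's `ck::index_t init`) lets call sites pass 0 to retain the old default initialization, as both updated tests do.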