Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
06af86fb
Commit
06af86fb
authored
Feb 07, 2025
by
Andriy Roshchenko
Browse files
WIP: Enabling gfx90a build
parent
9da21f99
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
23 deletions
+27
-23
example/67_gemm_microscaling/gemm_mx_common.hpp
example/67_gemm_microscaling/gemm_mx_common.hpp
+26
-22
example/67_gemm_microscaling/gemm_mx_fp8.cpp
example/67_gemm_microscaling/gemm_mx_fp8.cpp
+1
-1
No files found.
example/67_gemm_microscaling/gemm_mx_common.hpp
View file @
06af86fb
...
@@ -27,19 +27,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -27,19 +27,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ck
::
type_convert
;
struct
ExecutionConfig
final
struct
ExecutionConfig
final
{
{
int
do_verification
=
1
;
// (0=no, 1=CPU)
int
do_verification
=
1
;
// (0=no, 1=CPU)
int
init_method
=
2
;
// (0=no init, 1=integer value, 2=decimal value)
int
init_method
=
10
;
// (0=no init, 1=integer value, 2=decimal value)
bool
time_kernel
=
false
;
// (0=no, 1=yes)
bool
time_kernel
=
false
;
// (0=no, 1=yes)
int
verbosity
=
0
;
// (0=no info, 1=verbose info)
int
verbosity
=
1
;
// (0=no info, 1=verbose info)
};
};
struct
ProblemSize
final
struct
ProblemSize
final
{
{
ck
::
index_t
M
=
3840
;
ck
::
index_t
M
=
256
;
ck
::
index_t
N
=
409
6
;
ck
::
index_t
N
=
25
6
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
384
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideA
=
-
1
;
ck
::
index_t
StrideB
=
-
1
;
ck
::
index_t
StrideB
=
-
1
;
...
@@ -139,11 +141,11 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -139,11 +141,11 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
// clang-format off
// clang-format off
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB|
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|
LDSTypeA|
LDSTypeB|
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
XDataType
,
BDataType
,
XDataType
,
DsDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
Scale_Block_M
,
Scale_Block_N
,
Scale_Block_K
,
128
,
128
,
128
,
16
,
16
,
16
,
16
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPSched
,
BlkGemmPVer
,
float
,
float
,
float
,
float
>
;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
XDataType
,
BDataType
,
XDataType
,
DsDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
Scale_Block_M
,
Scale_Block_N
,
Scale_Block_K
,
128
,
128
,
128
,
16
,
16
,
16
,
16
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPSched
,
BlkGemmPVer
,
float
,
float
,
ADataType
,
BDataType
>
;
// clang-format on
// clang-format on
auto
M
=
problem_size
.
M
;
auto
M
=
problem_size
.
M
;
...
@@ -225,19 +227,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -225,19 +227,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"NOTE: No input data initialization."
<<
std
::
endl
;
std
::
cout
<<
"NOTE: No input data initialization."
<<
std
::
endl
;
}
}
break
;
break
;
case
1
:
case
10
:
// Initializations for development and debugging
case
2
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
ck
::
type_convert
<
ADataType
>
(
1.0
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
ADataType
>
{
ck
::
type_convert
<
ADataType
>
(
1.0
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
XDataType
>
{
ck
::
type_convert
<
XDataType
>
(
0.5
f
)}(
a_m_k_scale
);
ck
::
utils
::
FillConstant
<
XDataType
>
{
ck
::
type_convert
<
XDataType
>
(
0.
2
5
f
)}(
a_m_k_scale
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
ck
::
type_convert
<
BDataType
>
(
1.0
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
ck
::
type_convert
<
BDataType
>
(
0.25
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
XDataType
>
{
ck
::
type_convert
<
XDataType
>
(
2.0
f
)}(
b_k_n_scale
);
ck
::
utils
::
FillConstant
<
XDataType
>
{
ck
::
type_convert
<
XDataType
>
(
2.0
f
)}(
b_k_n_scale
);
if
(
config
.
verbosity
>
0
)
if
(
config
.
verbosity
>
0
)
{
{
std
::
cout
<<
"Init A = {1}"
<<
std
::
endl
;
std
::
cout
<<
"Init A = {1}"
<<
std
::
endl
;
std
::
cout
<<
"Init A scale = {0.5}"
<<
std
::
endl
;
std
::
cout
<<
"Init A scale = {0.
2
5}"
<<
std
::
endl
;
std
::
cout
<<
"Init B = {
1
}"
<<
std
::
endl
;
std
::
cout
<<
"Init B = {
0.25
}"
<<
std
::
endl
;
std
::
cout
<<
"Init B scale = {2.0}"
<<
std
::
endl
;
std
::
cout
<<
"Init B scale = {2.0}"
<<
std
::
endl
;
std
::
cout
<<
"Expect C = {K}"
<<
std
::
endl
;
std
::
cout
<<
"Expect C = {K
*(0.25*0.5)
}"
<<
std
::
endl
;
}
}
break
;
break
;
...
@@ -343,12 +344,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -343,12 +344,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"Comparing results..."
<<
std
::
endl
;
std
::
cout
<<
"Comparing results..."
<<
std
::
endl
;
}
}
if
(
config
.
init_method
==
1
)
if
(
config
.
init_method
==
1
0
)
{
{
res_verified
=
auto
expected
=
static_cast
<
float
>
(
K
)
*
(
0.25
f
*
0.5
f
);
res_verified
&&
std
::
abs
(
static_cast
<
float
>
(
K
)
-
c_m_n_device_result
(
0
,
0
))
<=
0.0
f
;
auto
computed
=
type_convert
<
float
>
(
c_m_n_device_result
(
1
,
12
));
std
::
cout
<<
"Expected vs Computed: "
<<
1.0
f
*
K
<<
" vs "
<<
c_m_n_device_result
(
0
,
0
)
<<
((
res_verified
)
?
" (PASSED!)"
:
" (FAILED!)"
)
<<
std
::
endl
;
res_verified
=
res_verified
&&
std
::
abs
(
expected
-
computed
)
<=
0.0
f
;
std
::
cout
<<
"
\n
Expected vs Computed: "
<<
expected
<<
" vs "
<<
computed
<<
((
res_verified
)
?
" (PASSED!)"
:
" (FAILED!)"
)
<<
std
::
endl
<<
std
::
endl
;
}
}
res_verified
=
res_verified
&&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
res_verified
=
res_verified
&&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
...
...
example/67_gemm_microscaling/gemm_mx_fp8.cpp
View file @
06af86fb
...
@@ -13,7 +13,7 @@ using XDataType = ck::e8m0_bexp_t;
...
@@ -13,7 +13,7 @@ using XDataType = ck::e8m0_bexp_t;
#endif
#endif
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CShuffleDataType
=
float
;
#if
1
#if
0
using CDataType = ck::half_t;
using CDataType = ck::half_t;
#else
#else
using
CDataType
=
float
;
using
CDataType
=
float
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment