gaoqiong / composable_kernel_ROCM

Commit 465ba138 · authored Feb 01, 2025 by Andriy Roshchenko

WIP: Implementing MX MFMA test.

parent 94079d60
Showing 4 changed files with 501 additions and 5 deletions:

    CMakeLists.txt                    +1    -1
    CMakePresets.json                 +189  -0
    test/mx_mfma_op/mx_mfma_op.cpp    +63   -3
    test/mx_mfma_op/mx_mfma_op.hpp    +248  -1
CMakeLists.txt:
@@ -530,7 +530,7 @@ endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
-    add_compile_options(-fcolor-diagnostics)
+    # add_compile_options(-fcolor-diagnostics)
 endif()
 if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
     add_compile_options(-fdiagnostics-color=always)
CMakePresets.json (new file, mode 0 → 100644):
{
    "version": 3,
    "configurePresets": [
        {
            "name": "linux-debug",
            "displayName": "Linux Debug",
            "hidden": true,
            "generator": "Unix Makefiles",
            "binaryDir": "${sourceDir}/build/${presetName}",
            "installDir": "${sourceDir}/build/install/${presetName}",
            "environment": {
                "MY_ENVIRONMENT_VARIABLE": "NONE",
                "PATH": "/usr/local/.cargo/bin:$penv{PATH}",
                "SCCACHE_IDLE_TIMEOUT": "11000"
            },
            "cacheVariables": {
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
                "BUILD_DEV": "ON",
                "CMAKE_CXX_COMPILER": "/opt/rocm/bin/hipcc",
                "CMAKE_PREFIX_PATH": "/opt/rocm",
                "CMAKE_CXX_COMPILER_LAUNCHER": "sccache",
                "CMAKE_C_COMPILER_LAUNCHER": "sccache"
            },
            "condition": {
                "type": "equals",
                "lhs": "${hostSystemName}",
                "rhs": "Linux"
            }
        },
        {
            "name": "MI355-debug",
            "displayName": "MI355 Debug",
            "inherits": "linux-debug",
            "description": "Development Environment for MI355.",
            "cacheVariables": {
                "GPU_TARGETS": "gfx950",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb"
            }
        },
        {
            "name": "MI355-release",
            "displayName": "MI355 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx950",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "MI300X-release",
            "displayName": "MI300X Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx942",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "MI250-release",
            "displayName": "MI250 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx90a",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3",
                "CK_USE_FP8_ON_UNSUPPORTED_ARCH": "ON"
            }
        },
        {
            "name": "MI250-debug",
            "displayName": "MI250 Debug",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx90a",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb",
                "CK_USE_FP8_ON_UNSUPPORTED_ARCH": "ON"
            }
        },
        {
            "name": "RX7800-release",
            "displayName": "RX7800 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx1101",
                "DL_KERNELS": "ON",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "RX7800-debug",
            "displayName": "RX7800 Debug",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx1101",
                "DL_KERNELS": "ON",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb"
            }
        }
    ],
    "buildPresets": [
        {
            "name": "Debug",
            "hidden": true,
            "configuration": "Debug"
        },
        {
            "name": "Release",
            "hidden": true,
            "configuration": "Release"
        },
        {
            "name": "MI355-debug",
            "displayName": "MI355",
            "configurePreset": "MI355-debug",
            "description": "Build Environment for MI355 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        },
        {
            "name": "MI355-release",
            "displayName": "MI355",
            "configurePreset": "MI355-release",
            "description": "Build Environment for MI355 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI300X-release",
            "displayName": "MI300X",
            "configurePreset": "MI300X-release",
            "description": "Build Environment for MI300X Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI250-release",
            "displayName": "MI250",
            "configurePreset": "MI250-release",
            "description": "Build Environment for MI250 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI250-debug",
            "displayName": "MI250",
            "configurePreset": "MI250-debug",
            "description": "Build Environment for MI250 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        },
        {
            "name": "RX7800-release",
            "displayName": "RX7800",
            "configurePreset": "RX7800-release",
            "description": "Build Environment for RX7800 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "RX7800-debug",
            "displayName": "RX7800",
            "configurePreset": "RX7800-debug",
            "description": "Build Environment for RX7800 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        }
    ]
}
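A usage note on the presets (not part of the diff): the "version": 3 schema requires CMake 3.21 or newer, and a typical cycle pairs a configure preset with its matching build preset, e.g. cmake --preset MI355-debug followed by cmake --build --preset MI355-debug.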
test/mx_mfma_op/mx_mfma_op.cpp:
@@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init)
     constexpr auto BLOCK_N = mfma_instr.n_per_blk;
     constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;

-    const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
+    const auto mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;

     bool pass = true;

-    pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
+    pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
                                    AType,
                                    BType,
                                    CType,

@@ -45,7 +45,7 @@ bool run_mfma_test(ck::index_t init)
                                    CLayout,
                                    BLOCK_M,
                                    BLOCK_N,
-                                   BLOCK_K>{}(mx_mfma_kernel, init);
+                                   BLOCK_K>{}(mfma_kernel, init);

     return pass;
 }

@@ -63,3 +63,63 @@ TEST(MFMA, FP8MFMA32x32x64)
     auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
     EXPECT_TRUE(pass);
 }
+
+/**
+ * @brief Run the test for the given MX MFMA instruction
+ *
+ * @param init - selects initialization algorithm for A and B tensors
+ */
+template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
+bool run_mxmfma_test(ck::index_t init)
+{
+    static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
+                      mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                  "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
+
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    using AccType    = float; // only MFMA_F32 instructions supported
+    using CPUAccType = AccType;
+    using ScaleType  = ck::e8m0_bexp_t; // biased exponent type
+
+    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
+
+    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
+    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
+    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
+    constexpr auto BLOCK_X = 32; // scaling vector size
+
+    const auto mx_mfma_kernel =
+        ck::matmul<AType, BType, ScaleType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_X>;
+
+    bool pass = true;
+
+    pass = ck::mxmfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
+                                       AType,
+                                       BType,
+                                       ScaleType,
+                                       CType,
+                                       ALayout,
+                                       BLayout,
+                                       CLayout,
+                                       BLOCK_M,
+                                       BLOCK_N,
+                                       BLOCK_K,
+                                       BLOCK_X>{}(mx_mfma_kernel, init);
+
+    return pass;
+}
+
+TEST(MXMFMA, MXFP8MFMA16x16x128)
+{
+    auto AB_init = 1;
+    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXFP8MFMA32x32x64)
+{
+    auto AB_init = 1;
+    auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
+    EXPECT_TRUE(pass);
+}
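A note on the scale encoding used above: ScaleType is ck::e8m0_bexp_t, a biased-exponent type. Assuming the OCP microscaling (MX) convention, a stored byte e encodes the power-of-two scale 2^(e - 127), with no sign or mantissa bits. The stand-alone sketch below (decode_e8m0 is a hypothetical helper for illustration, not a CK API) reproduces the scale values the test comments cite:

// Hypothetical sketch: decode an e8m0 biased exponent to its float scale,
// assuming scale = 2^(e - 127), with 0xFF reserved as NaN (OCP MX convention).
#include <cmath>
#include <cstdint>
#include <cstdio>

static float decode_e8m0(uint8_t e)
{
    if(e == 0xFF)
        return NAN; // reserved encoding
    return std::ldexp(1.0f, static_cast<int>(e) - 127); // 2^(e - 127)
}

int main()
{
    // Biased exponents drawn by GeneratorTensor_2<ScaleType>{126, 129} in the
    // test decode to the scales {0.5, 1, 2} noted in its comments; 121 decodes
    // to 0.015625 (2^-6), matching ScaleType{0.015625f} in init case 0.
    for(int e : {121, 126, 127, 128})
        std::printf("e8m0 %3d -> %g\n", e, decode_e8m0(static_cast<uint8_t>(e)));
    return 0;
}

On this reading, each group of BLOCK_X = 32 consecutive elements along K shares one power-of-two scale, which is exactly what the a_scales and b_scales tensors in the test header below carry.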
test/mx_mfma_op/mx_mfma_op.hpp:
@@ -18,7 +18,13 @@ enum class MFMA_F8F6F4
     F32_16x16x128 = static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
-    F32_32x32x64 = static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
+    F32_32x32x64 = static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4
+    SCALE_F32_16x16x128 =
+        static_cast<int>(MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4
+    SCALE_F32_32x32x64 =
+        static_cast<int>(MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4
 };

 template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>

@@ -352,6 +358,24 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
     auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
     storeC(c, fragC);
 }

+template <typename AType,
+          typename BType,
+          typename ScaleType,
+          typename CType,
+          typename AccType,
+          int32_t BLOCK_M,
+          int32_t BLOCK_N,
+          int32_t BLOCK_K,
+          int32_t BLOCK_X>
+__global__ void matmul(const AType* a, const BType* b, const ScaleType* x, CType* c)
+{
+    ignore = a;
+    ignore = b;
+    ignore = x;
+    ignore = c;
+}
+
 /**
  * @brief Structure to hold dimension parameters for GEMM tensors.
  *

@@ -373,6 +397,229 @@ struct GemmParams
     ck::index_t StrideC = -1;
 };

+namespace mxmfma_test {
+
+template <typename ADataType, typename BDataType, typename ScaleType, typename CDataType>
+void RunHostGEMM(const Tensor<ADataType>& A,
+                 const Tensor<ScaleType>& a_scales,
+                 const Tensor<BDataType>& B,
+                 const Tensor<ScaleType>& b_scales,
+                 Tensor<CDataType>& C)
+{
+    using PassThrough  = ck::tensor_operation::element_wise::PassThrough;
+    using GemmInstance = ck::tensor_operation::host::
+        ReferenceGemm<float, float, CDataType, float, PassThrough, PassThrough, PassThrough>;
+
+    Tensor<float> a_m_k(A.mDesc);
+    Tensor<float> b_k_n(B.mDesc);
+
+    const auto M = A.mDesc.GetLengths()[0];
+    const auto N = B.mDesc.GetLengths()[1];
+    const auto K = A.mDesc.GetLengths()[1];
+
+    const auto BLOCK_X = K / a_scales.mDesc.GetLengths()[1];
+
+    for(size_t m = 0; m < M; m++)
+    {
+        for(size_t k = 0; k < K; k++)
+        {
+            a_m_k(m, k) =
+                type_convert<float>(A(m, k)) * type_convert<float>(a_scales(m, k / BLOCK_X));
+        }
+    }
+
+    for(size_t n = 0; n < N; n++)
+    {
+        for(size_t k = 0; k < K; k++)
+        {
+            b_k_n(k, n) =
+                type_convert<float>(B(k, n)) * type_convert<float>(b_scales(k / BLOCK_X, n));
+        }
+    }
+
+    auto ref_gemm    = GemmInstance{};
+    auto ref_invoker = ref_gemm.MakeInvoker();
+    auto ref_argument =
+        ref_gemm.MakeArgument(a_m_k, b_k_n, C, PassThrough{}, PassThrough{}, PassThrough{});
+
+    ref_invoker.Run(ref_argument);
+}
+
+template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
+bool RunDeviceGEMM(KernelType kernel,
+                   const Tensor<ADataType>& A,
+                   const Tensor<BDataType>& B,
+                   Tensor<CDataType>& C)
+{
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
+    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
+
+    a_m_k_device_buf.ToDevice(A.mData.data());
+    b_n_k_device_buf.ToDevice(B.mData.data());
+
+    kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
+                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
+
+    c_m_n_device_buf.FromDevice(C.mData.data());
+
+    return true;
+}
+
+template <typename DeviceMFMA,
+          typename ADataType,
+          typename BDataType,
+          typename ScaleType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          index_t BLOCK_M,
+          index_t BLOCK_N,
+          index_t BLOCK_K,
+          index_t BLOCK_X>
+struct TestMXMFMA
+{
+    auto PrepareGemmTensors(const GemmParams& params, index_t init)
+    {
+        auto f_host_tensor_descriptor =
+            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({stride, 1}));
+                }
+                else
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({1, stride}));
+                }
+            };
+
+        Tensor<ADataType> a_m_k(
+            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
+        Tensor<ScaleType> a_scales(
+            f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
+        Tensor<BDataType> b_n_k(
+            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
+        Tensor<ScaleType> b_scales(
+            f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
+
+        Tensor<CDataType> c_m_n_host_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+        Tensor<CDataType> c_m_n_device_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+
+        switch(init)
+        {
+        case 0:
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
+            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}});
+            // NOTE: not all numbers are representable in FP8, BF8, etc.
+            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
+            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
+            break;
+        case 1:
+            // results in C = {K}
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
+            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
+            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
+            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
+            break;
+        case 2:
+            // expect small round off errors
+            a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
+            a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
+            b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
+            b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
+            break;
+        case 3:
+            // expect small round off errors
+            a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
+            a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
+            b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
+            b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
+            break;
+        default:
+            // all initial values are representable in FP8, BF8
+            a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
+            a_scales.GenerateTensorValue(GeneratorTensor_3<ScaleType>{1.0f / 32.0f, 1.0f});
+            b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
+            b_scales.GenerateTensorValue(GeneratorTensor_3<ScaleType>{1.0f / 32.0f, 1.0f});
+            break;
+        }
+
+        return std::make_tuple(
+            a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result);
+    }
+
+    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
+    {
+        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
+                  << ", CLayout = " << CLayout{}.name << std::endl;
+
+        // Arrange
+        GemmParams params;
+        params.M = BLOCK_M;
+        params.N = BLOCK_N;
+        params.K = BLOCK_K;
+
+        auto f_get_default_stride =
+            [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
+                if(stride == -1)
+                {
+                    // give a chance if stride is -1, return a default packed stride
+                    if constexpr(std::is_same_v<decltype(layout),
+                                                ck::tensor_layout::gemm::RowMajor>)
+                    {
+                        return static_cast<std::size_t>(col);
+                    }
+                    else
+                    {
+                        return static_cast<std::size_t>(row);
+                    }
+                }
+                else
+                    return static_cast<std::size_t>(stride);
+            };
+
+        params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
+        params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
+        params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
+
+        auto host_tensors = PrepareGemmTensors(params, init);
+
+        const Tensor<ADataType>& a        = std::get<0>(host_tensors);
+        const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
+        const Tensor<BDataType>& b        = std::get<2>(host_tensors);
+        const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
+        Tensor<CDataType>& c_host         = std::get<4>(host_tensors);
+        Tensor<CDataType>& c_device       = std::get<5>(host_tensors);
+
+        RunHostGEMM(a, a_scales, b, b_scales, c_host);
+        RunDeviceGEMM(mfma_kernel, a, b, c_device);
+
+        bool res = false;
+        if constexpr(std::is_same<CDataType, float>::value ||
+                     std::is_same<CDataType, half_t>::value)
+        {
+            res = ck::utils::check_err(c_device.mData, c_host.mData);
+            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+        }
+        else
+        {
+            std::cout << "UNSUPPORTED CDataType" << std::endl;
+        }
+
+        return res;
+    }
+};
+
+} // namespace mxmfma_test

 namespace mfma_test {

 template <typename GemmInstance,
           typename ADataType,
...
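For orientation: with scaling-vector size X = BLOCK_X, and writing xA and xB for the float-converted e8m0 scales, the host reference RunHostGEMM above computes the block-scaled GEMM

C_{mn} = \sum_{k=0}^{K-1} \bigl(A_{mk} \, xA_{m,\lfloor k/X \rfloor}\bigr)\,\bigl(B_{kn} \, xB_{\lfloor k/X \rfloor,\, n}\bigr),

with A and B likewise converted to float before accumulation. Note that the new device-side matmul overload added in this commit is still a stub (it assigns all four pointers to ignore), so the host/device comparison in TestMXMFMA is not yet meaningful, consistent with the WIP commit message.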