gaoqiong / composable_kernel_ROCM · Commits

Commit f3af1da6, authored Feb 05, 2025 by Andriy Roshchenko

    Merge remote-tracking branch 'internal/andriy/lwpck-2788' into andriy/lwpck-2788

Parents: 2bef5501, 60b885ae

Showing 8 changed files with 1343 additions and 14 deletions (+1343 -14)
CMakeLists.txt                                          +1    -1
CMakePresets.json                                       +189  -0
include/ck/library/utility/host_tensor_generator.hpp    +15   -0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp    +15   -6
include/ck/utility/amd_xdlops.hpp                       +26   -2
test/mx_mfma_op/mx_mfma_op.cpp                          +98   -3
test/mx_mfma_op/mx_mfma_op.hpp                          +816  -2
test/mx_mfma_op/scale_mfma_repro.cpp                    +183  -0
CMakeLists.txt (+1 -1)

@@ -541,7 +541,7 @@ endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
-    add_compile_options(-fcolor-diagnostics)
+    # add_compile_options(-fcolor-diagnostics)
 endif()
 if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
     add_compile_options(-fdiagnostics-color=always)
CMakePresets.json (new file, +189)

{
    "version": 3,
    "configurePresets": [
        {
            "name": "linux-debug",
            "displayName": "Linux Debug",
            "hidden": true,
            "generator": "Unix Makefiles",
            "binaryDir": "${sourceDir}/build/${presetName}",
            "installDir": "${sourceDir}/build/install/${presetName}",
            "environment": {
                "MY_ENVIRONMENT_VARIABLE": "NONE",
                "PATH": "/usr/local/.cargo/bin:$penv{PATH}",
                "SCCACHE_IDLE_TIMEOUT": "11000"
            },
            "cacheVariables": {
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
                "BUILD_DEV": "ON",
                "CMAKE_CXX_COMPILER": "/opt/rocm/bin/hipcc",
                "CMAKE_PREFIX_PATH": "/opt/rocm",
                "CMAKE_CXX_COMPILER_LAUNCHER": "sccache",
                "CMAKE_C_COMPILER_LAUNCHER": "sccache"
            },
            "condition": {
                "type": "equals",
                "lhs": "${hostSystemName}",
                "rhs": "Linux"
            }
        },
        {
            "name": "MI355-debug",
            "displayName": "MI355 Debug",
            "inherits": "linux-debug",
            "description": "Development Environment for MI355.",
            "cacheVariables": {
                "GPU_TARGETS": "gfx950",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb"
            }
        },
        {
            "name": "MI355-release",
            "displayName": "MI355 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx950",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "MI300X-release",
            "displayName": "MI300X Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx942",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "MI250-release",
            "displayName": "MI250 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx90a",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3",
                "CK_USE_FP8_ON_UNSUPPORTED_ARCH": "ON"
            }
        },
        {
            "name": "MI250-debug",
            "displayName": "MI250 Debug",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx90a",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb",
                "CK_USE_FP8_ON_UNSUPPORTED_ARCH": "ON"
            }
        },
        {
            "name": "RX7800-release",
            "displayName": "RX7800 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx1101",
                "DL_KERNELS": "ON",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "RX7800-debug",
            "displayName": "RX7800 Debug",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx1101",
                "DL_KERNELS": "ON",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb"
            }
        }
    ],
    "buildPresets": [
        { "name": "Debug", "hidden": true, "configuration": "Debug" },
        { "name": "Release", "hidden": true, "configuration": "Release" },
        {
            "name": "MI355-debug",
            "displayName": "MI355",
            "configurePreset": "MI355-debug",
            "description": "Build Environment for MI355 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        },
        {
            "name": "MI355-release",
            "displayName": "MI355",
            "configurePreset": "MI355-release",
            "description": "Build Environment for MI355 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI300X-release",
            "displayName": "MI300X",
            "configurePreset": "MI300X-release",
            "description": "Build Environment for MI300X Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI250-release",
            "displayName": "MI250",
            "configurePreset": "MI250-release",
            "description": "Build Environment for MI250 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI250-debug",
            "displayName": "MI250",
            "configurePreset": "MI250-debug",
            "description": "Build Environment for MI250 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        },
        {
            "name": "RX7800-release",
            "displayName": "RX7800",
            "configurePreset": "RX7800-release",
            "description": "Build Environment for RX7800 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "RX7800-debug",
            "displayName": "RX7800",
            "configurePreset": "RX7800-debug",
            "description": "Build Environment for RX7800 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        }
    ]
}
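These presets follow the standard CMake presets workflow: a configure preset such as MI355-debug is selected with cmake --preset MI355-debug, and the matching build preset drives compilation with cmake --build --preset MI355-debug (preset schema version 3 requires a reasonably recent CMake, roughly 3.21 or newer). The hidden linux-debug entry only serves as a base that the visible per-GPU presets inherit from.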
include/ck/library/utility/host_tensor_generator.hpp (+15 -0)

@@ -359,6 +359,21 @@ struct GeneratorTensor_Sequential
     }
 };
 
+template <ck::index_t Dim>
+struct GeneratorTensor_Sequential<ck::e8m0_bexp_t, Dim>
+{
+    int offset = 0;
+
+    template <typename... Ts>
+    ck::e8m0_bexp_t operator()(Ts... Xs) const
+    {
+        std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
+        int tmp = dims[Dim];
+        return ck::type_convert<ck::e8m0_bexp_t>(powf(2, tmp + offset));
+    }
+};
+
 template <typename T, size_t NumEffectiveDim = 2>
 struct GeneratorTensor_Diagonal
 {
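In other words, the new specialization fills a tensor of e8m0_bexp_t scales with powers of two that step along the chosen dimension. A minimal standalone sketch of the same arithmetic on plain floats (illustration only; it bypasses ck::type_convert and the e8m0_bexp_t type, and uses offset = -9, a value that appears in the new tests):

#include <cmath>
#include <cstdio>

int main()
{
    // Mirrors GeneratorTensor_Sequential<ck::e8m0_bexp_t, 1>:
    // element (m, k) -> 2^(k + offset); only a power-of-two survives in E8M0.
    int offset = -9;
    for(int k = 0; k < 4; ++k)
    {
        std::printf("k = %d -> scale %g\n", k, std::pow(2.0f, k + offset));
    }
    return 0;
}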
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp (+15 -6)

@@ -780,7 +780,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
     }
 };
 
-// TODO: fix mfma...f8f6f4 instructions
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
 {

@@ -847,9 +846,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
     // clang-format on
 
     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    __device__ void run(const FloatA& a,
+                        const int32_t& scale_a,
+                        const FloatB& b,
+                        const int32_t& scale_b,
+                        FloatC& reg_c) const
     {
-        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
+            a, scale_a, b, scale_b, reg_c);
     }
 };

@@ -871,9 +875,14 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
     // clang-format on
 
     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    __device__ void run(const FloatA& a,
+                        const int32_t& scale_a,
+                        const FloatB& b,
+                        const int32_t& scale_b,
+                        FloatC& reg_c) const
     {
-        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
+            a, scale_a, b, scale_b, reg_c);
     }
 };
include/ck/utility/amd_xdlops.hpp (+26 -2)

@@ -519,12 +519,36 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
 {
     template <class FloatC>
     __device__ static void Run(const f8x32_t& reg_a,
-                               const int32_t scale_a,
+                               const int32_t& scale_a,
                                const f8x32_t& reg_b,
-                               const int32_t scale_b,
+                               const int32_t& scale_b,
                                FloatC& reg_c)
     {
 #if defined(__gfx950__)
+        if(threadIdx.x == 0 || threadIdx.x == 32)
+        {
+            printf("thread: %u -- xA: %x\n", threadIdx.x, static_cast<uint32_t>(scale_a));
+            printf("thread: %u -- xB: %x\n", threadIdx.x, static_cast<uint32_t>(scale_b));
+            // printf("intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> thread: %u -- scale_a: %f\n",
+            //        threadIdx.x,
+            //        static_cast<float>(ck::e8m0_bexp_t(scale_a)));
+            // printf("intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> thread: %u -- scale_b: %f\n",
+            //        threadIdx.x,
+            //        static_cast<float>(ck::e8m0_bexp_t(scale_b)));
+            // for(size_t i = 0; i < 32; i++)
+            // {
+            //     printf("thread: %u -- reg_a[%zu]: %f\n",
+            //            threadIdx.x,
+            //            i,
+            //            type_convert<float>(f8_t{static_cast<f8x32_t::data_v>(reg_a)[i]}));
+            //     // printf("thread: %u -- reg_a[%zu]: %f\n",
+            //     //        threadIdx.x,
+            //     //        i,
+            //     //        type_convert<float>(f8_t{static_cast<f8x32_t::data_v>(reg_b)[i]}));
+            // }
+        }
         // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
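As this commit uses them, the scale_a / scale_b operands are plain 32-bit integers whose low byte carries an E8M0 biased exponent: the loaders added in mx_mfma_op.hpp below mask the exponent with 0xFF, and the repro file passes 127 for a scale of 1.0. A small hypothetical helper, not part of the commit, sketching how a power-of-two scale would map onto that byte under that assumption:

#include <cmath>
#include <cstdint>
#include <cstdio>

// E8M0 stores only a biased exponent: value = 2^(exp - 127).
static uint32_t pack_e8m0_scale(float scale) // scale assumed to be a power of two
{
    int exponent    = static_cast<int>(std::log2(scale)); // 1.0f -> 0, 0.5f -> -1
    uint32_t biased = static_cast<uint32_t>(exponent + 127) & 0xFF;
    return biased; // low byte holds the biased exponent, upper bytes stay 0
}

int main()
{
    std::printf("scale 1.0 -> 0x%x\n", pack_e8m0_scale(1.0f)); // 0x7f (127)
    std::printf("scale 0.5 -> 0x%x\n", pack_e8m0_scale(0.5f)); // 0x7e (126)
    std::printf("scale 2.0 -> 0x%x\n", pack_e8m0_scale(2.0f)); // 0x80 (128)
    return 0;
}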
test/mx_mfma_op/mx_mfma_op.cpp (+98 -3)

@@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init)
     constexpr auto BLOCK_N = mfma_instr.n_per_blk;
     constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
 
-    const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
+    const auto mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
 
     bool pass = true;
 
-    pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
+    pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
                                    AType,
                                    BType,
                                    CType,

@@ -45,7 +45,7 @@ bool run_mfma_test(ck::index_t init)
                                    CLayout,
                                    BLOCK_M,
                                    BLOCK_N,
-                                   BLOCK_K>{}(mx_mfma_kernel, init);
+                                   BLOCK_K>{}(mfma_kernel, init);
 
     return pass;
 }

@@ -63,3 +63,98 @@ TEST(MFMA, FP8MFMA32x32x64)
     auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
     EXPECT_TRUE(pass);
 }

New code added by this hunk:

/**
 * @brief Run the test for the given MX MFMA instruction
 *
 * @param init - selects initialization algorithm for A and B tensors
 */
template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
bool run_mxmfma_test(ck::index_t init)
{
    static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
                      mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
                  "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");

    using ALayout = ck::tensor_layout::gemm::RowMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::RowMajor;

    using AccType = float; // only MFMA_F32 instructions supported
    // using CPUAccType = AccType;
    using ScaleType = ck::e8m0_bexp_t; // biased exponent type

    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
    constexpr auto BLOCK_X = 32; // scaling vector size

    const auto mx_mfma_kernel =
        ck::matmul<AType, BType, ScaleType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_X>;

    bool pass = true;

    pass = ck::mxmfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
                                       AType,
                                       BType,
                                       ScaleType,
                                       CType,
                                       ALayout,
                                       BLayout,
                                       CLayout,
                                       BLOCK_M,
                                       BLOCK_N,
                                       BLOCK_K,
                                       BLOCK_X>{}(mx_mfma_kernel, init);

    return pass;
}

TEST(MXMFMA, MXFP8MFMA16x16x128i2)
{
    auto AB_init = 2;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA32x32x64i2)
{
    auto AB_init = 2;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA16x16x128i3)
{
    auto AB_init = 3;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA32x32x64i3)
{
    auto AB_init = 3;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA16x16x128i4)
{
    auto AB_init = 4;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA32x32x64i4)
{
    auto AB_init = 4;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA32x32x64i5)
{
    auto AB_init = 5;
    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
    EXPECT_TRUE(pass);
}
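The i2 / i3 / i4 / i5 suffixes in the new MXMFMA test names simply encode the AB_init value passed to run_mxmfma_test; that value is forwarded as init and selects the corresponding initialization branch in TestMXMFMA::PrepareGemmTensors, defined in mx_mfma_op.hpp below.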
test/mx_mfma_op/mx_mfma_op.hpp (+816 -2)

@@ -18,7 +18,13 @@ enum class MFMA_F8F6F4
     F32_16x16x128 =
         static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
     F32_32x32x64 =
-        static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4
+        static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4
+    SCALE_F32_16x16x128 =
+        static_cast<int>(MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4
+    SCALE_F32_32x32x64 =
+        static_cast<int>(MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4
 };
 
 template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>

@@ -32,6 +38,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
         auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
         op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
     }
+
+    __device__ void operator()(AFragT const& fragA,
+                               const int32_t& scale_a,
+                               BFragT const& fragB,
+                               const int32_t& scale_b,
+                               AccumFragT& fragAcc)
+    {
+        auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
+        op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, scale_a, fragB, scale_b, fragAcc);
+    }
 };
 
 template <typename AFragT, typename BFragT, typename AccumFragT>

@@ -42,6 +59,17 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
         auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
         op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
     }
+
+    __device__ void operator()(AFragT const& fragA,
+                               const int32_t& scale_a,
+                               BFragT const& fragB,
+                               const int32_t& scale_b,
+                               AccumFragT& fragAcc)
+    {
+        auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
+        op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, scale_a, fragB, scale_b, fragAcc);
+    }
 };
 
 template <typename VecT>

@@ -131,11 +159,121 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
     return fragA;
 }

New code added by this hunk:

// Define a load function for input A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_row_major(AType const* input_ptr)
{
    // clang-format off
    // Register Mapping for 16x128:                                             || Register Mapping for 32x64:
    // Size             | BLOCK_M  | BLOCK_M   | BLOCK_M   | BLOCK_M   |        || Size             | BLOCK_M  | BLOCK_M   |
    // M                | 0 ... 15 | 0 ... 15  | 0 ... 15  | 0 ... 15  |        || M                | 0 ... 31 | 0 ... 31  |
    // Thread Id        | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id        | 0 ... 31 | 32 ... 63 | Vector
    // Register Element ------------ ------------- ------------ -------- Element || Register Element ------------ ---------- Element
    // Reg 0 [0:7]      | K0       | K32       | K64       | K96       | v[0]   || Reg 0 [0:7]      | K0       | K32       | v[0]
    // Reg 0 [8:15]     | K1       | K33       | K65       | K97       | v[1]   || Reg 0 [8:15]     | K1       | K33       | v[1]
    // Reg 0 [16:23]    | K2       | K34       | K66       | K98       | v[2]   || Reg 0 [16:23]    | K2       | K34       | v[2]
    // Reg 0 [24:31]    | K3       | K35       | K67       | K99       | v[3]   || Reg 0 [24:31]    | K3       | K35       | v[3]
    // Reg 1 [0:7]      | K4       | K36       | K68       | K100      | v[4]   || Reg 1 [0:7]      | K4       | K36       | v[4]
    // Reg 1 [8:15]     | K5       | K37       | K69       | K101      | v[5]   || Reg 1 [8:15]     | K5       | K37       | v[5]
    // Reg 1 [16:23]    | K6       | K38       | K70       | K102      | v[6]   || Reg 1 [16:23]    | K6       | K38       | v[6]
    // Reg 1 [24:31]    | K7       | K39       | K71       | K103      | v[7]   || Reg 1 [24:31]    | K7       | K39       | v[7]
    // Reg 2 [0:7]      | K8       | K40       | K72       | K104      | v[8]   || Reg 2 [0:7]      | K8       | K40       | v[8]
    // Reg 2 [8:15]     | K9       | K41       | K73       | K105      | v[9]   || Reg 2 [8:15]     | K9       | K41       | v[9]
    // Reg 2 [16:23]    | K10      | K42       | K74       | K106      | v[10]  || Reg 2 [16:23]    | K10      | K42       | v[10]
    // Reg 2 [24:31]    | K11      | K43       | K75       | K107      | v[11]  || Reg 2 [24:31]    | K11      | K43       | v[11]
    // Reg 3 [0:7]      | K12      | K44       | K76       | K108      | v[12]  || Reg 3 [0:7]      | K12      | K44       | v[12]
    // Reg 3 [8:15]     | K13      | K45       | K77       | K109      | v[13]  || Reg 3 [8:15]     | K13      | K45       | v[13]
    // Reg 3 [16:23]    | K14      | K46       | K78       | K110      | v[14]  || Reg 3 [16:23]    | K14      | K46       | v[14]
    // Reg 3 [24:31]    | K15      | K47       | K79       | K111      | v[15]  || Reg 3 [24:31]    | K15      | K47       | v[15]
    // Reg 4 [0:7]      | K16      | K48       | K80       | K112      | v[16]  || Reg 4 [0:7]      | K16      | K48       | v[16]
    // Reg 4 [8:15]     | K17      | K49       | K81       | K113      | v[17]  || Reg 4 [8:15]     | K17      | K49       | v[17]
    // Reg 4 [16:23]    | K18      | K50       | K82       | K114      | v[18]  || Reg 4 [16:23]    | K18      | K50       | v[18]
    // Reg 4 [24:31]    | K19      | K51       | K83       | K115      | v[19]  || Reg 4 [24:31]    | K19      | K51       | v[19]
    // Reg 5 [0:7]      | K20      | K52       | K84       | K116      | v[20]  || Reg 5 [0:7]      | K20      | K52       | v[20]
    // Reg 5 [8:15]     | K21      | K53       | K85       | K117      | v[21]  || Reg 5 [8:15]     | K21      | K53       | v[21]
    // Reg 5 [16:23]    | K22      | K54       | K86       | K118      | v[22]  || Reg 5 [16:23]    | K22      | K54       | v[22]
    // Reg 5 [24:31]    | K23      | K55       | K87       | K119      | v[23]  || Reg 5 [24:31]    | K23      | K55       | v[23]
    // Reg 6 [0:7]      | K24      | K56       | K88       | K120      | v[24]  || Reg 6 [0:7]      | K24      | K56       | v[24]
    // Reg 6 [8:15]     | K25      | K57       | K89       | K121      | v[25]  || Reg 6 [8:15]     | K25      | K57       | v[25]
    // Reg 6 [16:23]    | K26      | K58       | K90       | K122      | v[26]  || Reg 6 [16:23]    | K26      | K58       | v[26]
    // Reg 6 [24:31]    | K27      | K59       | K91       | K123      | v[27]  || Reg 6 [24:31]    | K27      | K59       | v[27]
    // Reg 7 [0:7]      | K28      | K60       | K92       | K124      | v[28]  || Reg 7 [0:7]      | K28      | K60       | v[28]
    // Reg 7 [8:15]     | K29      | K61       | K93       | K125      | v[29]  || Reg 7 [8:15]     | K29      | K61       | v[29]
    // Reg 7 [16:23]    | K30      | K62       | K94       | K126      | v[30]  || Reg 7 [16:23]    | K30      | K62       | v[30]
    // Reg 7 [24:31]    | K31      | K63       | K95       | K127      | v[31]  || Reg 7 [24:31]    | K31      | K63       | v[31]
    // clang-format on

    // Here we want to load a BLOCK_M x BLOCK_K block of data.
    static constexpr uint32_t VW = vectorSize(AFragT{});

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 32 elements.
    // We need to know where they start, and where the next elements are.
    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,         // Row
                                       (threadIdx.x / BLOCK_M) * VW); // Col

    // Flatten to 1D row_major offsets.
    auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };

    // BLOCK_K is a stride in A matrix
    auto startOffset = row_major(startCoord2D, BLOCK_K);

    auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
    return *fragPtr;
}

// Define a load function for scaled A blocks:
// Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// - The scale inputs distributed across 64 lanes.
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType,
          typename AFragT,
          typename ScaleType,
          typename ScaleFragT,
          int32_t BLOCK_M,
          int32_t BLOCK_K,
          int32_t BLOCK_X>
__device__ AFragT load_mx_A_row_major(AType const* input_ptr,
                                      ScaleType const* scale_ptr,
                                      ScaleFragT& fragX)
{
    static constexpr uint32_t VW = vectorSize(AFragT{});
    static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 1 element
    // We need to know where they start
    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,                   // Row
                                       (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col

    // Flatten to 1D row_major offsets.
    auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };

    // BLOCK_K / BLOCK_X is a stride in xA matrix
    auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X);

    // preserve upper bits obtain 8-bit exponent
    fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);

    return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
}

The same hunk also fixes one comment line in the existing B-block loader:

 // Define a load function for input B blocks:
 // Size: (BLOCK_K x BLOCK_N)
 // ASSUMPTION:
 // - We want contiguous BLOCK_N sized row neighbors in register.
-// - Data is in row_major format
+// - Data is in column major format
 // This means:
 // - From B we will load K rows of size BLOCK_N to satisfy our input data
 template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>

@@ -199,6 +337,46 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
     return *fragPtr;
 }

New code added by this hunk:

// Define a load function for scaled B blocks:
// Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in column major format
// - The scale inputs distributed across 64 lanes.
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType,
          typename BFragT,
          typename ScaleType,
          typename ScaleFragT,
          int32_t BLOCK_K,
          int32_t BLOCK_N,
          int32_t BLOCK_X>
__device__ BFragT load_mx_B_col_major(BType const* input_ptr,
                                      ScaleType const* scale_ptr,
                                      ScaleFragT& fragX)
{
    static constexpr uint32_t VW = vectorSize(BFragT{});
    static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");

    // To start the loading process, let's visualize in 2D coords.
    // Each thread will load 1 element
    // We need to know where to start
    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
                                       threadIdx.x % BLOCK_N);                 // Col

    // Flatten to 1D col_major offsets.
    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };

    auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);

    // preserve upper bits obtain 8-bit exponent
    fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);

    return load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(input_ptr);
}

@@ -309,6 +487,129 @@ struct store_C_col_major<CType, CFragT, 32, 32>
     }
 };

New code added by this hunk:

// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in row major format
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_row_major;

// Here we want to store a 16x16 block of data.
//
// Size             | BLOCK_N  | BLOCK_N   | BLOCK_N   | BLOCK_N   |
// N                | 0 ... 15 | 0 ... 15  | 0 ... 15  | 0 ... 15  |
// Thread Id        | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
// Register Element ------------ ------------- ------------ -------- Element
// Reg0             | M0       | M4        | M8        | M12       | v[0]
// Reg1             | M1       | M5        | M9        | M13       | v[1]
// Reg2             | M2       | M6        | M10       | M14       | v[2]
// Reg3             | M3       | M7        | M11       | M15       | v[3]
template <typename CType, typename CFragT>
struct store_C_row_major<CType, CFragT, 16, 16>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
        static constexpr uint32_t Dim = 16;

        // Each thread will load 4 elements.
        // We need to know where they start, and where the next elements are.
        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                           threadIdx.x % Dim);       // Col
        auto stepCoord2D  = std::make_pair(1u, 0u);

        // Flatten to 1D row_major offsets.
        auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };

        auto startOffset = row_major(startCoord2D, 16);
        auto kOffset     = row_major(stepCoord2D, 16);

        auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
        *fragPtr      = cFrag;

        // If you notice carefully, kOffset != 1.
        // This means the following is vector is updated with 4 non-contiguous offsets,
        // which the compiler will separate into 4 different global_store_dword instructions.
        output[startOffset]               = cFrag[0]; // v[0] = Reg 0
        output[startOffset + kOffset]     = cFrag[1]; // v[1] = Reg 1
        output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2
        output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3
    }
};

// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size             | BLOCK_N  | BLOCK_N   |
// N                | 0 ... 31 | 0 ... 31  |
// Thread Id        | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ----------- Element
// Reg0             | M0       | M4        | v[0]
// Reg1             | M1       | M5        | v[1]
// Reg2             | M2       | M6        | v[2]
// Reg3             | M3       | M7        | v[3]
//                  ____________ ___________
// Reg4             | M8       | M12       | v[4]
// Reg5             | M9       | M13       | v[5]
// Reg6             | M10      | M14       | v[6]
// Reg7             | M11      | M15       | v[7]
//                  ____________ ___________
// Reg8             | M16      | M20       | v[8]
// Reg9             | M17      | M21       | v[9]
// Reg10            | M18      | M22       | v[10]
// Reg11            | M19      | M23       | v[11]
//                  ____________ ___________
// Reg12            | M24      | M28       | v[12]
// Reg13            | M25      | M29       | v[13]
// Reg14            | M26      | M30       | v[14]
// Reg15            | M27      | M31       | v[15]
template <typename CType, typename CFragT>
struct store_C_row_major<CType, CFragT, 32, 32>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t WAVE_SIZE      = 64;
        static constexpr uint32_t VW             = 4;  // This VW is per 'chunk'
        static constexpr uint32_t Dim            = 32; // BLOCK_N
        static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8

        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                           threadIdx.x % Dim);       // Col
        // Minor step for each 'chunk'
        auto minorStepCoord2D = std::make_pair(1u, 0u);
        // Major step between 'chunks'
        auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);

        // Flatten to 1D row_major offsets.
        auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };

        auto startOffset  = row_major(startCoord2D, 32);
        auto kMinorOffset = row_major(minorStepCoord2D, 32);
        auto kMajorOffset = row_major(majorStepCoord2D, 32);

        output[startOffset]                                        = cFrag[0];  // v[0]  = Reg 0
        output[startOffset + kMinorOffset]                         = cFrag[1];  // v[1]  = Reg 1
        output[startOffset + 2 * kMinorOffset]                     = cFrag[2];  // v[2]  = Reg 2
        output[startOffset + 3 * kMinorOffset]                     = cFrag[3];  // v[3]  = Reg 3
        output[startOffset + kMajorOffset]                         = cFrag[4];  // v[4]  = Reg 4
        output[startOffset + kMajorOffset + kMinorOffset]          = cFrag[5];  // v[5]  = Reg 5
        output[startOffset + kMajorOffset + 2 * kMinorOffset]      = cFrag[6];  // v[6]  = Reg 6
        output[startOffset + kMajorOffset + 3 * kMinorOffset]      = cFrag[7];  // v[7]  = Reg 7
        output[startOffset + 2 * kMajorOffset]                     = cFrag[8];  // v[8]  = Reg 8
        output[startOffset + 2 * kMajorOffset + kMinorOffset]      = cFrag[9];  // v[9]  = Reg 9
        output[startOffset + 2 * kMajorOffset + 2 * kMinorOffset]  = cFrag[10]; // v[10] = Reg 10
        output[startOffset + 2 * kMajorOffset + 3 * kMinorOffset]  = cFrag[11]; // v[11] = Reg 11
        output[startOffset + 3 * kMajorOffset]                     = cFrag[12]; // v[12] = Reg 12
        output[startOffset + 3 * kMajorOffset + kMinorOffset]      = cFrag[13]; // v[13] = Reg 13
        output[startOffset + 3 * kMajorOffset + 2 * kMinorOffset]  = cFrag[14]; // v[14] = Reg 14
        output[startOffset + 3 * kMajorOffset + 3 * kMinorOffset]  = cFrag[15]; // v[15] = Reg 15
    }
};

@@ -342,7 +643,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
     // Matrix multiply-accumulate using MFMA units
     // Accumulation intermediate = BLOCK_M x BLOCK_N
+    __syncthreads();
     mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
+    __syncthreads();
 
     for(int i = 0; i < vectorSize(fragC); ++i)
     {

@@ -352,6 +655,139 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
     auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
     storeC(c, fragC);
 }

New code added by this hunk:

template <typename AType,
          typename BType,
          typename ScaleType,
          typename CType,
          typename AccType,
          int32_t BLOCK_M,
          int32_t BLOCK_N,
          int32_t BLOCK_K,
          int32_t BLOCK_X>
__global__ void
matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
{
    constexpr int WAVE_SIZE = 64;
    assert(threadIdx.x < WAVE_SIZE);
    assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);

    using AFragT        = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
    using BFragT        = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
    using CFragT        = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
    using AccumFragT    = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
    using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
    using ScaleFragT    = int32_t;

    // Create frags
    auto fragA   = AFragT{};
    auto fragB   = BFragT{};
    auto fragC   = CFragT{};
    auto fragAcc = AccumFragT{0};
    auto fragXa  = ScaleFragT{0};
    auto fragXb  = ScaleFragT{0};

    // Load the inputs.
    // A = col major, BLOCK_M x BLOCK_K
    fragA = load_mx_A_row_major<AType, AFragT, ScaleType, ScaleFragT, BLOCK_M, BLOCK_K, BLOCK_X>(
        a, xa, fragXa);
    // B = col major, BLOCK_K x BLOCK_N
    fragB = load_mx_B_col_major<BType, BFragT, ScaleType, ScaleFragT, BLOCK_K, BLOCK_N, BLOCK_X>(
        b, xb, fragXb);

    // Scaled Matrix multiply-accumulate using MFMA units
    // Accumulation intermediate = BLOCK_M x BLOCK_N
    __syncthreads();

    // printf("thread: %u -- fragXa: %d\n", threadIdx.x, fragXa);
    printf("thread: %u -- fragA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
           "%x %x %x %x %x %x %x %x %x %x\n",
           threadIdx.x,
           fragA.data_.dN[0], fragA.data_.dN[1], fragA.data_.dN[2], fragA.data_.dN[3],
           fragA.data_.dN[4], fragA.data_.dN[5], fragA.data_.dN[6], fragA.data_.dN[7],
           fragA.data_.dN[8], fragA.data_.dN[9], fragA.data_.dN[10], fragA.data_.dN[11],
           fragA.data_.dN[12], fragA.data_.dN[13], fragA.data_.dN[14], fragA.data_.dN[15],
           fragA.data_.dN[16], fragA.data_.dN[17], fragA.data_.dN[18], fragA.data_.dN[19],
           fragA.data_.dN[20], fragA.data_.dN[21], fragA.data_.dN[22], fragA.data_.dN[23],
           fragA.data_.dN[24], fragA.data_.dN[25], fragA.data_.dN[26], fragA.data_.dN[27],
           fragA.data_.dN[28], fragA.data_.dN[29], fragA.data_.dN[30], fragA.data_.dN[31]);
    printf("thread: %u -- fragB: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
           "%x %x %x %x %x %x %x %x %x %x\n",
           threadIdx.x,
           fragB.data_.dN[0], fragB.data_.dN[1], fragB.data_.dN[2], fragB.data_.dN[3],
           fragB.data_.dN[4], fragB.data_.dN[5], fragB.data_.dN[6], fragB.data_.dN[7],
           fragB.data_.dN[8], fragB.data_.dN[9], fragB.data_.dN[10], fragB.data_.dN[11],
           fragB.data_.dN[12], fragB.data_.dN[13], fragB.data_.dN[14], fragB.data_.dN[15],
           fragB.data_.dN[16], fragB.data_.dN[17], fragB.data_.dN[18], fragB.data_.dN[19],
           fragB.data_.dN[20], fragB.data_.dN[21], fragB.data_.dN[22], fragB.data_.dN[23],
           fragB.data_.dN[24], fragB.data_.dN[25], fragB.data_.dN[26], fragB.data_.dN[27],
           fragB.data_.dN[28], fragB.data_.dN[29], fragB.data_.dN[30], fragB.data_.dN[31]);

    //__builtin_amdgcn_mfma_ld_scale_b32(fragXa, 0, 0);
    mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
        fragA, fragXa, fragB, fragXb, fragAcc);
    __syncthreads();

    for(int i = 0; i < vectorSize(fragC); ++i)
    {
        fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
    }

    __syncthreads();

    auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
    storeC(c, fragC);
}

 /**
  * @brief Structure to hold dimension parameters for GEMM tensors.
  *

@@ -373,6 +809,384 @@ struct GemmParams
     ck::index_t StrideC = -1;
 };

New code added by this hunk:

namespace mxmfma_test {

template <typename ADataType, typename BDataType, typename ScaleType, typename CDataType>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<ScaleType>& a_scales,
                 const Tensor<BDataType>& B,
                 const Tensor<ScaleType>& b_scales,
                 Tensor<CDataType>& C)
{
    using PassThrough  = ck::tensor_operation::element_wise::PassThrough;
    using GemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
                                                                   float,
                                                                   CDataType,
                                                                   float,
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   float,
                                                                   float>;

    Tensor<float> a_m_k(A.mDesc);
    Tensor<float> b_k_n(B.mDesc);

    const auto M = A.mDesc.GetLengths()[0];
    const auto N = B.mDesc.GetLengths()[1];
    const auto K = A.mDesc.GetLengths()[1];

    const auto BLOCK_X = K / a_scales.mDesc.GetLengths()[1];

    for(size_t m = 0; m < M; m++)
    {
        for(size_t k = 0; k < K; k++)
        {
            a_m_k(m, k) =
                type_convert<float>(A(m, k)) * type_convert<float>(a_scales(m, k / BLOCK_X));
        }
    }

    for(size_t n = 0; n < N; n++)
    {
        for(size_t k = 0; k < K; k++)
        {
            b_k_n(k, n) =
                type_convert<float>(B(k, n)) * type_convert<float>(b_scales(k / BLOCK_X, n));
        }
    }

    auto ref_gemm     = GemmInstance{};
    auto ref_invoker  = ref_gemm.MakeInvoker();
    auto ref_argument =
        ref_gemm.MakeArgument(a_m_k, b_k_n, C, PassThrough{}, PassThrough{}, PassThrough{});

    ref_invoker.Run(ref_argument);
}

template <typename KernelType,
          typename ADataType,
          typename BDataType,
          typename ScaleType,
          typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
                   const Tensor<ADataType>& A,
                   const Tensor<ScaleType>& a_scales,
                   const Tensor<BDataType>& B,
                   const Tensor<ScaleType>& b_scales,
                   Tensor<CDataType>& C)
{
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
    DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize());
    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
    DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(A.mData.data());
    a_scales_device_buf.ToDevice(a_scales.mData.data());
    b_n_k_device_buf.ToDevice(B.mData.data());
    b_scales_device_buf.ToDevice(b_scales.mData.data());

    kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                      static_cast<const ScaleType*>(a_scales_device_buf.GetDeviceBuffer()),
                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
                      static_cast<const ScaleType*>(b_scales_device_buf.GetDeviceBuffer()),
                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));

    c_m_n_device_buf.FromDevice(C.mData.data());

    return true;
}

template <typename DeviceMFMA,
          typename ADataType,
          typename BDataType,
          typename ScaleType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          index_t BLOCK_M,
          index_t BLOCK_N,
          index_t BLOCK_K,
          index_t BLOCK_X>
struct TestMXMFMA
{
    auto PrepareGemmTensors(const GemmParams& params, index_t init)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        Tensor<ADataType> a_m_k(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<ScaleType> a_scales(
            f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
        Tensor<BDataType> b_n_k(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<ScaleType> b_scales(
            f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));

        Tensor<CDataType> c_m_n_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        switch(init)
        {
        case 0:
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}});
            // NOTE: not all numbers are representable in FP8, BF8, etc.
            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            break;
        case 1:
            // results in C = {K}
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
            break;
        case 2:
            // expect small round off errors
            a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
            a_scales.GenerateTensorValue(
                GeneratorTensor_2<ScaleType>{127, 129}); // 1, 2 // scales: {0.5, 1, 2}
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            break;
        case 3:
            // expect small round off errors
            a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
            a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{128, 129}); // 2
            // a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            break;
        case 4:
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.3});
            a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 128}); // 1, 2
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            break;
        case 5:
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.0});
            for(size_t i = 0; i < 32; i++)
            {
                a_m_k(0, i) = type_convert<ADataType>(1.0f);
            }
            for(size_t i = 32; i < 64; i++)
            {
                a_m_k(0, i) = type_convert<ADataType>(-2.0f);
            }
            // printf("f8 1: %x \n", type_convert<ADataType>(1.0f).data);
            // printf("f8 -2: %x \n", type_convert<ADataType>(-2.0f).data);
            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            a_scales(0, 0) = ScaleType{1.0f};
            a_scales(0, 1) = ScaleType{0.5f};
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{0.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            for(size_t i = 0; i < 64; i++)
            {
                b_n_k(i, 0) = type_convert<BDataType>(1.0f);
            }
            break;
        // case 3:
        //     // expect small round off errors
        //     a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
        //     a_scales.GenerateTensorValue(
        //         GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
        //     b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
        //     b_scales.GenerateTensorValue(
        //         GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
        //     break;
        // case 4:
        //     a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
        //     a_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 0>{-9});
        //     b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
        //     b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
        //     break;
        // case 5:
        //     a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
        //     a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
        //     b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
        //     b_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 1>{-9});
        //     break;
        case 6:
            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.00195312f});
            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 16}});
            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 1>{-9});
            break;
        default:
            // all initial values are representable in FP8, BF8
            a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
            // a_scales.GenerateTensorValue(GeneratorTensor_3<ScaleType>{1.0f / 32.0f, 1.0f});
            a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
            // b_scales.GenerateTensorValue(GeneratorTensor_3<ScaleType>{1.0f / 32.0f, 1.0f});
            b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
            break;
        }

        return std::make_tuple(
            a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result);
    }

    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;

        // Arrange
        GemmParams params;
        params.M = BLOCK_M;
        params.N = BLOCK_N;
        params.K = BLOCK_K;

        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
                if(stride == -1)
                {
                    // give a chance if stride is -1, return a default packed stride
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck::tensor_layout::gemm::RowMajor>)
                    {
                        return static_cast<std::size_t>(col);
                    }
                    else
                    {
                        return static_cast<std::size_t>(row);
                    }
                }
                else
                    return static_cast<std::size_t>(stride);
            };

        params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
        params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
        params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});

        auto host_tensors = PrepareGemmTensors(params, init);

        const Tensor<ADataType>& a        = std::get<0>(host_tensors);
        const Tensor<ScaleType>& a_scales = std::get<1>(host_tensors);
        const Tensor<BDataType>& b        = std::get<2>(host_tensors);
        const Tensor<ScaleType>& b_scales = std::get<3>(host_tensors);
        Tensor<CDataType>& c_host         = std::get<4>(host_tensors);
        Tensor<CDataType>& c_device       = std::get<5>(host_tensors);

        RunHostGEMM(a, a_scales, b, b_scales, c_host);

        RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);

#if 0
#if 1
        std::cout << "a:" << std::endl;
        for(size_t i = 0; i < BLOCK_M; i++)
        {
            for(size_t j = 0; j < BLOCK_K; j++)
            {
                std::cout << type_convert<float>(a(i, j)) << " ";
            }
            std::cout << std::endl;
            break;
        }
        // std::cout << "b:" << std::endl;
        // for(size_t i = 0; i < BLOCK_K; i++)
        // {
        //     for(size_t j = 0; j < BLOCK_N; j++)
        //     {
        //         if(j == 0)
        //             std::cout << type_convert<float>(b(i, j)) << " ";
        //     }
        //     std::cout << std::endl;
        // }
#endif
#if 0
        std::cout << "a_scale:" << std::endl;
        for(size_t i = 0; i < BLOCK_M; i++)
        {
            for(size_t j = 0; j < BLOCK_K / BLOCK_X; j++)
            {
                std::cout << type_convert<float>(a_scales(i, j)) << " ";
            }
            std::cout << std::endl;
        }
        // std::cout << "b_scale:" << std::endl;
        // for(size_t i = 0; i < BLOCK_K / BLOCK_X; i++)
        // {
        //     for(size_t j = 0; j < BLOCK_N; j++)
        //     {
        //         std::cout << type_convert<float>(b_scales(i, j)) << " ";
        //     }
        //     std::cout << std::endl;
        // }
#endif
        std::cout << "c_device:" << std::endl;
        for(size_t i = 0; i < BLOCK_M; i++)
        {
            for(size_t j = 0; j < BLOCK_N; j++)
            {
                std::cout << type_convert<float>(c_device(i, j)) << " ";
            }
            std::cout << std::endl;
            break;
        }
#endif

        bool res = false;
        if constexpr(std::is_same<CDataType, float>::value ||
                     std::is_same<CDataType, half_t>::value)
        {
            res = ck::utils::check_err(c_device.mData, c_host.mData);
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            if(!res)
            {
                std::cout << "c_host:" << std::endl;
                for(size_t i = 0; i < BLOCK_M; i++)
                {
                    for(size_t j = 0; j < BLOCK_N; j++)
                    {
                        std::cout << type_convert<float>(c_host(i, j)) << " ";
                    }
                    std::cout << std::endl;
                    break;
                }
            }
        }
        else
        {
            std::cout << "UNSUPPORTED CDataType" << std::endl;
        }

        return res;
    }
};
} // namespace mxmfma_test

 namespace mfma_test {
 template <typename GemmInstance,
           typename ADataType,
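For reference, the scaling semantics that RunHostGEMM encodes above can be restated compactly: every group of BLOCK_X (here 32) consecutive K elements of a row of A shares one E8M0 scale, and likewise for a column of B. A minimal host-side restatement, offered as an illustration rather than code from the commit:

#include <cstddef>
#include <vector>

// C(m, n) = sum_k [ xA(m, k / BLOCK_X) * A(m, k) ] * [ xB(k / BLOCK_X, n) * B(k, n) ]
float scaled_dot(const std::vector<float>& a_row,  // A(m, 0..K-1), converted to float
                 const std::vector<float>& xa_row, // xA(m, 0..K/BLOCK_X-1) as float scales
                 const std::vector<float>& b_col,  // B(0..K-1, n)
                 const std::vector<float>& xb_col, // xB(0..K/BLOCK_X-1, n)
                 std::size_t BLOCK_X = 32)
{
    float acc = 0.0f;
    for(std::size_t k = 0; k < a_row.size(); ++k)
    {
        acc += (xa_row[k / BLOCK_X] * a_row[k]) * (xb_col[k / BLOCK_X] * b_col[k]);
    }
    return acc; // one element C(m, n)
}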
test/mx_mfma_op/scale_mfma_repro.cpp (new file, +183)

#include <hip/hip_ext.h>
#include <hip/hip_runtime.h>

__global__ void kernel()
{
    using dataAB = uint8_t __attribute__((ext_vector_type(32)));
    using dataC  = float __attribute__((ext_vector_type(16)));
    using dataX  = int32_t __attribute__((ext_vector_type(2)));

    dataAB regA(0x38);
    dataAB regB(0x38);
    dataC regC(1.0f);
    // dataC regCin(1.0f);

#if 1
    // dataX xa{127, 127}; // 1.0
    dataX xa(127 & 0xFF); // 1.0
    dataX xb(127 & 0xFF); // 1.0
#else
    dataX xa(0);
    dataX xb(0);
#endif

#if 0
    if(threadIdx.x == 0)
    {
        // xa = 127; // 1.0
        for(size_t i = 0; i < 32; i++)
        {
            regA[i] = 0x38; // 1.0
        }
        for(size_t i = 0; i < 32; i++)
        {
            regB[i] = 0x38; // 1.0
        }
        printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
        printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
    }
    if(threadIdx.x == 32)
    {
        // xa = 126; // 0.5
        for(size_t i = 0; i < 32; i++)
        {
            regA[i] = 0xC0; // -2.0
        }
        for(size_t i = 0; i < 32; i++)
        {
            regB[i] = 0x38; // 1.0
        }
        printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
        printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
    }
#endif

    __syncthreads();

    printf("thread: %u -- regA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
           "%x %x %x %x %x %x %x %x %x %x\n",
           threadIdx.x,
           regA[0], regA[1], regA[2], regA[3], regA[4], regA[5], regA[6], regA[7],
           regA[8], regA[9], regA[10], regA[11], regA[12], regA[13], regA[14], regA[15],
           regA[16], regA[17], regA[18], regA[19], regA[20], regA[21], regA[22], regA[23],
           regA[24], regA[25], regA[26], regA[27], regA[28], regA[29], regA[30], regA[31]);
    printf("thread: %u -- regB: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
           "%x %x %x %x %x %x %x %x %x %x\n",
           threadIdx.x,
           regB[0], regB[1], regB[2], regB[3], regB[4], regB[5], regB[6], regB[7],
           regB[8], regB[9], regB[10], regB[11], regB[12], regB[13], regB[14], regB[15],
           regB[16], regB[17], regB[18], regB[19], regB[20], regB[21], regB[22], regB[23],
           regB[24], regB[25], regB[26], regB[27], regB[28], regB[29], regB[30], regB[31]);

    //__builtin_amdgcn_mfma_ld_scale_b32(xb[threadIdx.x / 32], 0, 0);
    regC = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(regA,
                                                           regB,
                                                           regC,
                                                           0, // cbsz
                                                           0, // blgp
                                                           0,
                                                           xa[threadIdx.x / 32],
                                                           0,
                                                           xb[threadIdx.x / 32]);
    __syncthreads();

    printf("thread: %u -- regC: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
           threadIdx.x,
           regC[0], regC[1], regC[2], regC[3], regC[4], regC[5], regC[6], regC[7],
           regC[8], regC[9], regC[10], regC[11], regC[12], regC[13], regC[14], regC[15]);
    // printf("thread: %u -- regCin: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
    //        threadIdx.x,
    //        regCin[0], regCin[1], regCin[2], regCin[3], regCin[4], regCin[5], regCin[6],
    //        regCin[7], regCin[8], regCin[9], regCin[10], regCin[11], regCin[12], regCin[13],
    //        regCin[14], regCin[15]);
}

int main()
{
    kernel<<<1, 64>>>();
    return 0;
}
\ No newline at end of file
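When reading the repro's output, a useful sanity check (assuming the operands are FP8 E4M3, where 0x38 encodes 1.0 and 0xC0 encodes -2.0, and the scales are E8M0 biased exponents, where 127 encodes 1.0): with every A and B element set to 1.0, both scales at 1.0, and the accumulator initialized to 1.0, each of the 16 per-lane outputs of the 32x32x64 instruction should come back as 1.0 + 64 * 1.0 * 1.0 = 65.0.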