gaoqiong / composable_kernel_ROCM · Commits

Commit eb06c923, authored Jan 29, 2025 by Andriy Roshchenko
Parent: a619e3f5

    Add tests for MFMA_F8F6F4::F32_16x16x128 and MFMA_F8F6F4::F32_32x32x64 instructions

Showing 4 changed files with 563 additions and 154 deletions (+563 -154).
    CMakeLists.txt                    +1    -1
    CMakePresets.json                 +189  -0
    test/mx_mfma_op/mx_mfma_op.cpp    +42   -39
    test/mx_mfma_op/mx_mfma_op.hpp    +331  -114
CMakeLists.txt

...
@@ -530,7 +530,7 @@ endif()
 message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
-    add_compile_options(-fcolor-diagnostics)
+    # add_compile_options(-fcolor-diagnostics)
 endif()
 if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9)
     add_compile_options(-fdiagnostics-color=always)
...
CMakePresets.json (new file, 0 → 100644)

{
    "version": 3,
    "configurePresets": [
        {
            "name": "linux-debug",
            "displayName": "Linux Debug",
            "hidden": true,
            "generator": "Unix Makefiles",
            "binaryDir": "${sourceDir}/build/${presetName}",
            "installDir": "${sourceDir}/build/install/${presetName}",
            "environment": {
                "MY_ENVIRONMENT_VARIABLE": "NONE",
                "PATH": "/usr/local/.cargo/bin:$penv{PATH}",
                "SCCACHE_IDLE_TIMEOUT": "11000"
            },
            "cacheVariables": {
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_EXPORT_COMPILE_COMMANDS": "ON",
                "BUILD_DEV": "ON",
                "CMAKE_CXX_COMPILER": "/opt/rocm/bin/hipcc",
                "CMAKE_PREFIX_PATH": "/opt/rocm",
                "CMAKE_CXX_COMPILER_LAUNCHER": "sccache",
                "CMAKE_C_COMPILER_LAUNCHER": "sccache"
            },
            "condition": {
                "type": "equals",
                "lhs": "${hostSystemName}",
                "rhs": "Linux"
            }
        },
        {
            "name": "MI355-debug",
            "displayName": "MI355 Debug",
            "inherits": "linux-debug",
            "description": "Development Environment for MI355.",
            "cacheVariables": {
                "GPU_TARGETS": "gfx950",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb"
            }
        },
        {
            "name": "MI355-release",
            "displayName": "MI355 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx950",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "MI300X-release",
            "displayName": "MI300X Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx942",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "MI250-release",
            "displayName": "MI250 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx90a",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3",
                "CK_USE_FP8_ON_UNSUPPORTED_ARCH": "ON"
            }
        },
        {
            "name": "MI250-debug",
            "displayName": "MI250 Debug",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx90a",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb",
                "CK_USE_FP8_ON_UNSUPPORTED_ARCH": "ON"
            }
        },
        {
            "name": "RX7800-release",
            "displayName": "RX7800 Release",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx1101",
                "DL_KERNELS": "ON",
                "CMAKE_BUILD_TYPE": "Release",
                "CMAKE_CXX_FLAGS": "-O3"
            }
        },
        {
            "name": "RX7800-debug",
            "displayName": "RX7800 Debug",
            "inherits": "linux-debug",
            "cacheVariables": {
                "GPU_TARGETS": "gfx1101",
                "DL_KERNELS": "ON",
                "CMAKE_BUILD_TYPE": "Debug",
                "CMAKE_CXX_FLAGS": "-O0 -ggdb"
            }
        }
    ],
    "buildPresets": [
        {
            "name": "Debug",
            "hidden": true,
            "configuration": "Debug"
        },
        {
            "name": "Release",
            "hidden": true,
            "configuration": "Release"
        },
        {
            "name": "MI355-debug",
            "displayName": "MI355",
            "configurePreset": "MI355-debug",
            "description": "Build Environment for MI355 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        },
        {
            "name": "MI355-release",
            "displayName": "MI355",
            "configurePreset": "MI355-release",
            "description": "Build Environment for MI355 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI300X-release",
            "displayName": "MI300X",
            "configurePreset": "MI300X-release",
            "description": "Build Environment for MI300X Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI250-release",
            "displayName": "MI250",
            "configurePreset": "MI250-release",
            "description": "Build Environment for MI250 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "MI250-debug",
            "displayName": "MI250",
            "configurePreset": "MI250-debug",
            "description": "Build Environment for MI250 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        },
        {
            "name": "RX7800-release",
            "displayName": "RX7800",
            "configurePreset": "RX7800-release",
            "description": "Build Environment for RX7800 Release.",
            "inherits": [ "Release" ],
            "jobs": 128
        },
        {
            "name": "RX7800-debug",
            "displayName": "RX7800",
            "configurePreset": "RX7800-debug",
            "description": "Build Environment for RX7800 Debug.",
            "inherits": [ "Debug" ],
            "jobs": 128
        }
    ]
}
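
With CMake 3.21 or newer (required for preset schema version 3), these presets can be consumed directly from the command line; for example, to configure and build for the gfx950 target:

    cmake --preset MI355-release
    cmake --build --preset MI355-release

`cmake --list-presets` prints the available configure presets.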
test/mx_mfma_op/mx_mfma_op.cpp

...
@@ -6,52 +6,55 @@
 #include "mx_mfma_op.hpp"

 using ck::e8m0_bexp_t;
 using ck::f8_ocp_t;
 using ck::f8_t;
 using ck::half_t;
 using ck::type_convert;

-template <typename Src1Type,
-          ck::index_t Src1VecSize,
-          typename Src2Type,
-          ck::index_t Src2VecSize,
-          typename DstType,
-          ck::index_t AccVecSize,
-          typename AccType,
-          typename CPUAccType,
-          ck::index_t M,
-          ck::index_t N,
-          ck::index_t K>
+template <typename AType, typename BType, typename CType, ck::mx_mfma_test::MFMA_F8F6F4 mfma>
 bool run_test()
 {
-    using Row         = ck::tensor_layout::gemm::RowMajor;
-    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-    bool pass = true;
-
-    const auto mx_mfma_kernel = ck::mx_mfma_test::matmul<Src1Type,
-                                                         Src1VecSize,
-                                                         Src2Type,
-                                                         Src2VecSize,
-                                                         AccType,
-                                                         AccVecSize,
-                                                         DstType,
-                                                         M,
-                                                         N,
-                                                         K>;
-
-    pass = ck::mx_mfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
-                                        Src1Type,
-                                        Src2Type,
-                                        DstType,
-                                        AccType,
-                                        CPUAccType,
-                                        decltype(Row{}),
-                                        decltype(Row{}),
-                                        decltype(Row{}),
-                                        PassThrough,
-                                        PassThrough,
-                                        PassThrough,
-                                        AccVecSize,
-                                        M,
-                                        N,
-                                        K>{}(mx_mfma_kernel);
+    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::ColumnMajor;
+
+    using AccType    = float; // only MFMA_F32 instructions supported
+    using CPUAccType = AccType;
+
+    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
+
+    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
+    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
+    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
+
+    const auto mx_mfma_kernel =
+        ck::mx_mfma_test::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
+
+    bool pass = true;
+    pass      = ck::mx_mfma_test::TestMFMA<decltype(mx_mfma_kernel),
+                                      AType,
+                                      BType,
+                                      CType,
+                                      AccType,
+                                      CPUAccType,
+                                      ALayout,
+                                      BLayout,
+                                      CLayout,
+                                      BLOCK_M,
+                                      BLOCK_N,
+                                      BLOCK_K>{}(mx_mfma_kernel);

     return pass;
 }

-TEST(MXMFMA, FP8MFMA16x16x128)
+TEST(MFMA, FP8MFMA16x16x128)
 {
-    auto pass = run_test<float, 1, float, 1, float, 1, float, float, 16, 16, 128>();
+    auto pass = run_test<f8_t, f8_t, half_t, ck::mx_mfma_test::MFMA_F8F6F4::F32_16x16x128>();
+    EXPECT_TRUE(pass);
+}
+
+TEST(MFMA, FP8MFMA32x32x64)
+{
+    auto pass = run_test<f8_t, f8_t, float, ck::mx_mfma_test::MFMA_F8F6F4::F32_32x32x64>();
     EXPECT_TRUE(pass);
 }
...
@@ -70,5 +73,5 @@ TEST(MXMFMA, FP8MFMA16x16x128)
 // EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
 // }

-TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
-TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
+// TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
+// TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
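
Note that the new `run_test` no longer takes explicit M/N/K and vector sizes; the block shape falls out of the `mfma_type` descriptor for the chosen instruction. As a rough cross-check of the arithmetic (values inferred from the instruction names, not stated in this diff):

    // Illustrative only: block shapes implied by the instruction names.
    // V_MFMA_F32_16X16X128_F8F6F4: BLOCK_M = 16, BLOCK_N = 16, and
    //   num_input_blks * k_per_blk should multiply out to BLOCK_K = 128.
    // V_MFMA_F32_32X32X64_F8F6F4:  BLOCK_M = 32, BLOCK_N = 32, BLOCK_K = 64.
    // One 64-lane wave then holds, per lane:
    //   A: BLOCK_M * BLOCK_K / 64 = 32 elements (both shapes)
    //   B: BLOCK_K * BLOCK_N / 64 = 32 elements (both shapes)
    //   C: BLOCK_M * BLOCK_N / 64 = 4 (16x16) or 16 (32x32) elements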
test/mx_mfma_op/mx_mfma_op.hpp

...
@@ -5,6 +5,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
...
@@ -12,114 +13,332 @@
 namespace ck {
 namespace mx_mfma_test {

-template <typename src_vec1, typename src_vec2, typename acc_vec>
-__device__ void builtin_mx_mfma_naive_selector(const src_vec1&, const src_vec2&, acc_vec&)
-{
-}
-
-// Smfmac instructions are using 4:2 structural sparsity, that means that in every contignuous
-// subgroup of 4 elements, atleast 2 must be equal to zero and the position of non-zero elements is
-// stored in idx register to allow selection of corresponding B matrix elements for multiplication.
-// Currently smfmac instructions support only A matrix as sparse
-template <typename src1_t,
-          index_t src1_vec_size,
-          typename src2_t,
-          index_t src2_vec_size,
-          typename acc_t,
-          index_t acc_vec_size,
-          typename dst_t,
-          int32_t M,
-          int32_t N,
-          int32_t K>
-__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c)
-{
-    __shared__ src1_t a_shared[M * K];
-    __shared__ src2_t b_shared[K * N];
-
-    const int lane = threadIdx.x;
-
-    // smfmac's A part is storing only non-zero elements in 2VGPRs
-    // smfmac's B part is storing all elements in 4VGPRs
-    using src1_vec      = typename vector_type<src1_t, src1_vec_size>::type;
-    using src1_full_vec = typename vector_type<src1_t, src1_vec_size * 2>::type;
-    using src2_vec      = typename vector_type<src2_t, src2_vec_size>::type;
-
-    src1_vec a_frag      = {};
-    src2_vec b_frag      = {};
-    src1_full_vec a_temp = {};
-    src2_vec b_temp      = {};
-
-    // initialize c fragment to 0
-    using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_vec_size, true>;
-    acc_vec c_thread_buf_;
-
-    for(int i = 0; i < 8; ++i)
-    {
-        a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K
-    }
-    for(int i = 0; i < 8; ++i)
-    {
-        b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N
-    }
-
-    __syncthreads();
-
-    for(int i = 0; i < 8; ++i)
-    {
-        a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i];
-    }
-    for(int i = 0; i < 8; ++i)
-    {
-        b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i];
-    }
-
-    __syncthreads();
-
-    // Idx must be a 32-bit register and it is storing 4 2-bit indexes of A's non zero elements.
-    // It starts with last two elements of every 4 elements subgroup set as non-zero
-    int32_t idx = 0b11101110;
-    // Bit masks are for zeroing 0-3rd position of idx
-    static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
-
-    src1_t curr_val;
-    int32_t a_pos = 0;
-    for(int j = 0; j < 2; ++j)
-    {
-        a_pos = j * 2;
-        for(int i = 0; i < 4; ++i)
-        {
-            curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i];
-            if(curr_val != 0.0f)
-            {
-                idx &= ~bit_clear_masks[a_pos];
-                idx |= (i % 4) << 2 * a_pos;
-                a_frag[a_pos] = curr_val;
-                a_pos++;
-            }
-        }
-    }
-
-    for(int i = 0; i < 8; ++i)
-    {
-        b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)];
-    }
-
-    builtin_smfmac_naive_selector<src1_vec, src2_vec, acc_vec>(a_frag, b_frag, idx, c_thread_buf_);
-    __syncthreads();
-
-    // store results from unpacked c_thread_buf_ output
-    if constexpr(K == 32)
-    {
-        static_for<0, acc_vec_size, 1>{}([&](auto i) {
-            c[(4 * (lane / 16) + i) * N + lane % 16] =
-                ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
-        });
-    }
-    else
-    {
-        static_for<0, acc_vec_size, 1>{}([&](auto i) {
-            c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] =
-                ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
-        });
-    }
-}
+// MFMA instructions supported in this test
+enum class MFMA_F8F6F4
+{
+    F32_16x16x128 = static_cast<int>(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4
+    F32_32x32x64  = static_cast<int>(MfmaInstr::mfma_f32_32x32x64f8f6f4)   // V_MFMA_F32_32X32X64_F8F6F4
+};
+
+template <typename AFragT, typename BFragT, typename AccumFragT, int32_t BLOCK_M, int32_t BLOCK_N>
+struct mfma_type_selector;
+
+template <typename AFragT, typename BFragT, typename AccumFragT>
+struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
+{
+    __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
+    {
+#if 1
+        auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
+        op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
+#else
+        ignore = fragA;
+        ignore = fragB;
+        ignore = fragAcc;
+#endif
+    }
+};
+
+template <typename AFragT, typename BFragT, typename AccumFragT>
+struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
+{
+    __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
+    {
+#if 1
+        auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
+        op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
+#else
+        ignore = fragA;
+        ignore = fragB;
+        ignore = fragAcc;
+#endif
+    }
+};
+
+template <typename VecT>
+static constexpr int32_t vectorSize(const VecT&)
+{
+    return scalar_type<VecT>::vector_size;
+}
+
+// Define a load function for input A blocks:
+// Size: (BLOCK_M x BLOCK_K)
+// ASSUMPTION:
+// - We want contiguous BLOCK_M sized column neighbors in register.
+// - Data is in col_major format
+// This means:
+// - From A we will load K columns of size BLOCK_M to satisfy our input data
+template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
+__device__ AFragT load_A_col_major(AType const* input_ptr)
+{
+    // Here we want to load a BLOCK_M x BLOCK_K block of data.
+    static constexpr uint32_t VW = vectorSize(AFragT{});
+    using ARawT        = typename scalar_type<AFragT>::type;
+    using AScalarFragT = vector_type<ARawT, VW>::type;
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 32 elements.
+    // We need to know where they start, and where the next elements are.
+    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,         // Row
+                                       (threadIdx.x / BLOCK_M) * VW); // Col
+    auto stepCoord2D  = std::make_pair(0u, 1u);
+
+    // Flatten to 1D col_major offsets.
+    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+    // BLOCK_M is a stride in A matrix
+    auto startOffset = col_major(startCoord2D, BLOCK_M);
+    auto kOffset     = col_major(stepCoord2D, BLOCK_M);
+    // kOffset == BLOCK_M
+    // This means every BLOCK_M element is loaded into output vector
+
+    auto fragA = AScalarFragT{
+        bit_cast<ARawT>(input_ptr[startOffset]),                // XXX v[0]  = Reg 0 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 1 * kOffset]),  // XXX v[1]  = Reg 0 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 2 * kOffset]),  // XXX v[2]  = Reg 0 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 3 * kOffset]),  // XXX v[3]  = Reg 0 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 4 * kOffset]),  // XXX v[4]  = Reg 1 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 5 * kOffset]),  // XXX v[5]  = Reg 1 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 6 * kOffset]),  // XXX v[6]  = Reg 1 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 7 * kOffset]),  // XXX v[7]  = Reg 1 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 8 * kOffset]),  // XXX v[8]  = Reg 2 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 9 * kOffset]),  // XXX v[9]  = Reg 2 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 10 * kOffset]), // XXX v[10] = Reg 2 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 11 * kOffset]), // XXX v[11] = Reg 2 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 12 * kOffset]), // XXX v[12] = Reg 3 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 13 * kOffset]), // XXX v[13] = Reg 3 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 14 * kOffset]), // XXX v[14] = Reg 3 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 15 * kOffset]), // XXX v[15] = Reg 3 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 16 * kOffset]), // XXX v[16] = Reg 4 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 17 * kOffset]), // XXX v[17] = Reg 4 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 18 * kOffset]), // XXX v[18] = Reg 4 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 19 * kOffset]), // XXX v[19] = Reg 4 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 20 * kOffset]), // XXX v[20] = Reg 5 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 21 * kOffset]), // XXX v[21] = Reg 5 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 22 * kOffset]), // XXX v[22] = Reg 5 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 23 * kOffset]), // XXX v[23] = Reg 5 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 24 * kOffset]), // XXX v[24] = Reg 6 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 25 * kOffset]), // XXX v[25] = Reg 6 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 26 * kOffset]), // XXX v[26] = Reg 6 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 27 * kOffset]), // XXX v[27] = Reg 6 [24:31]
+        bit_cast<ARawT>(input_ptr[startOffset + 28 * kOffset]), // XXX v[28] = Reg 7 [0:7]
+        bit_cast<ARawT>(input_ptr[startOffset + 29 * kOffset]), // XXX v[29] = Reg 7 [8:15]
+        bit_cast<ARawT>(input_ptr[startOffset + 30 * kOffset]), // XXX v[30] = Reg 7 [16:23]
+        bit_cast<ARawT>(input_ptr[startOffset + 31 * kOffset])  // XXX v[31] = Reg 7 [24:31]
+    };
+
+    return fragA;
+}
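
The unrolled initializer above pins down the exact register layout (v[0]..v[31]); for readability, the same gather can be written as a loop. A sketch, equivalent by construction (illustrative only, not part of the patch):

    // Equivalent loop form of the unrolled AScalarFragT initializer:
    // element v of the fragment comes from column v of the thread's row,
    // i.e. kOffset (= BLOCK_M) elements further along in col-major memory.
    AScalarFragT fragA{};
    for(uint32_t v = 0; v < VW; ++v)
    {
        fragA[v] = bit_cast<ARawT>(input_ptr[startOffset + v * kOffset]);
    }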
+// Define a load function for input B blocks:
+// Size: (BLOCK_K x BLOCK_N)
+// ASSUMPTION:
+// - We want contiguous BLOCK_N sized row neighbors in register.
+// - Data is in row_major format
+// This means:
+// - From B we will load K rows of size BLOCK_N to satisfy our input data
+template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
+__device__ BFragT load_B_col_major(BType const* input_ptr)
+{
+    // Here we want to load a BLOCK_K x BLOCK_N block of data.
+    static constexpr uint32_t VW = vectorSize(BFragT{});
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 32 elements.
+    // We need to know where they start, and where the next elements are.
+    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
+                                       threadIdx.x % BLOCK_N);       // Col
+    // auto stepCoord2D = std::make_pair(1u, 0u);
+
+    // Flatten to 1D col_major offsets.
+    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+    auto startOffset = col_major(startCoord2D, BLOCK_K);
+    // auto kOffset = col_major(stepCoord2D, BLOCK_K);
+    // kOffset == 1
+
+    auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
+    return *fragPtr;
+}
+// Define a store function for C
+// Size: (BLOCK_M x BLOCK_N)
+// ASSUMPTION:
+// - We want contiguous BLOCK_N sized row neighbors in register.
+// - Data is in col_major format
+// This means:
+// - From C we will load BLOCK_M rows of size BLOCK_N to satisfy our input data
+template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
+struct store_C_col_major;
+
+// Here we want to store a 16x16 block of data.
+//
+//             Size | BLOCK_N  | BLOCK_N   | BLOCK_N   | BLOCK_N   | Vector
+// Register Element | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element
+//             ____________ _____________ _____________ ______________
+//             Reg0 | M0       | M4        | M8        | M12       | v[0]
+//             Reg1 | M1       | M5        | M9        | M13       | v[1]
+//             Reg2 | M2       | M6        | M10       | M14       | v[2]
+//             Reg3 | M3       | M7        | M11       | M15       | v[3]
+template <typename CType, typename CFragT>
+struct store_C_col_major<CType, CFragT, 16, 16>
+{
+    __device__ void operator()(CType* output, CFragT cFrag)
+    {
+        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
+        static constexpr uint32_t Dim = 16;
+
+#if 1
+        for(int i = 0; i < vectorSize(cFrag); ++i)
+        {
+            printf("threadIdx.x = %d; cFrag[%d] = %f\n",
+                   static_cast<int>(threadIdx.x),
+                   i,
+                   static_cast<float>(cFrag[i]));
+        }
+#endif
+
+        // Each thread will load 4 elements.
+        // We need to know where they start, and where the next elements are.
+        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
+                                           threadIdx.x % Dim);       // Col
+        // auto stepCoord2D = std::make_pair(1u, 0u);
+
+        // Flatten to 1D col_major offsets.
+        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+        auto startOffset = col_major(startCoord2D, 16);
+        // auto kOffset = col_major(stepCoord2D, 16); // 1
+        // kOffset == 1
+
+        auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
+        *fragPtr      = cFrag;
+    }
+};
+// Here we want to store a 32x32 block of data.
+// Register Mapping:
+//             Size | BLOCK_N  | BLOCK_N   | Vector
+// Register Element | 0 ... 31 | 32 ... 63 | Element
+//             ____________ _____________
+//             Reg0  | M0       | M4        | v[0]
+//             Reg1  | M1       | M5        | v[1]
+//             Reg2  | M2       | M6        | v[2]
+//             Reg3  | M3       | M7        | v[3]
+//             ____________ _____________
+//             Reg4  | M8       | M12       | v[4]
+//             Reg5  | M9       | M13       | v[5]
+//             Reg6  | M10      | M14       | v[6]
+//             Reg7  | M11      | M15       | v[7]
+//             ____________ _____________
+//             Reg8  | M16      | M20       | v[8]
+//             Reg9  | M17      | M21       | v[9]
+//             Reg10 | M18      | M22       | v[10]
+//             Reg11 | M19      | M23       | v[11]
+//             ____________ _____________
+//             Reg12 | M24      | M28       | v[12]
+//             Reg13 | M25      | M29       | v[13]
+//             Reg14 | M26      | M30       | v[14]
+//             Reg15 | M27      | M31       | v[15]
+template <typename CType, typename CFragT>
+struct store_C_col_major<CType, CFragT, 32, 32>
+{
+    __device__ void operator()(CType* output, CFragT cFrag)
+    {
+        static constexpr uint32_t WAVE_SIZE      = 64;
+        static constexpr uint32_t VW             = 4;
+        static constexpr uint32_t Dim            = 32;
+        static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
+
+#if 1
+        for(int i = 0; i < vectorSize(cFrag); ++i)
+        {
+            printf("threadIdx.x = %d; cFrag[%d] = %f\n",
+                   static_cast<int>(threadIdx.x),
+                   i,
+                   static_cast<float>(cFrag[i]));
+        }
+#endif
+
+        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
+                                           threadIdx.x % Dim);       // Col
+        // Minor step for each 'chunk'
+        // auto minorStepCoord2D = std::make_pair(1u, 0u);
+        // Major step between 'chunks'
+        auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
+
+        // Flatten to 1D col_major offsets.
+        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+        auto startOffset  = col_major(startCoord2D, 32);
+        // auto kMinorOffset = col_major(minorStepCoord2D, 32); // 1
+        auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
+
+        // kMinorOffset == 1.
+        // This means we can vector store 4 contiguous elements at a time.
+        using CRawT        = typename scalar_type<CFragT>::type;
+        using CScalarFragT = vector_type<CRawT, VW>::type;
+
+        union
+        {
+            CFragT frag;
+            CScalarFragT chunks[vectorSize(CFragT{}) / VW];
+        } fragC{cFrag}; // Initialize with input fragment
+
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset))                    = fragC.chunks[0];
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset))     = fragC.chunks[1];
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) = fragC.chunks[2];
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) = fragC.chunks[3];
+    }
+};
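
As a concrete check of this mapping (worked by hand, not taken from the patch): lane 37 gets startCoord2D = ((37 / 32) * 4, 37 % 32) = (4, 5), so startOffset = 4 + 5 * 32 = 164; with kMajorOffset = 8, its four 4-element chunks land at column-major offsets 164, 172, 180 and 188, i.e. rows 4-7, 12-15, 20-23 and 28-31 of column 5, matching the M4/M12/M20/M28 groups in the register-mapping table above.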
+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AccType,
+          int32_t BLOCK_M,
+          int32_t BLOCK_N,
+          int32_t BLOCK_K>
+__global__ void matmul(const AType* a, const BType* b, CType* c)
+{
+    constexpr int WAVE_SIZE = 64;
+    assert(threadIdx.x < WAVE_SIZE);
+    assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1);
+
+    using AFragT = vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
+    using BFragT = vector_type<BType, BLOCK_K * BLOCK_N / WAVE_SIZE>::type;
+    using CFragT = vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
+
+    using AccumFragT    = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
+    using RawAccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
+
+    // Create frags
+    auto fragA   = AFragT{};
+    auto fragB   = BFragT{};
+    auto fragC   = CFragT{};
+    auto fragAcc = AccumFragT{0};
+
+    // Load the inputs.
+    // A = col major, BLOCK_M x BLOCK_K
+    fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
+    // B = col major, BLOCK_K x BLOCK_N
+    fragB = load_B_col_major<BType, BFragT, BLOCK_K, BLOCK_N>(b);
+
+    // Matrix multiply-accumulate using MFMA units
+    // Accumulation intermediate = BLOCK_M x BLOCK_N
+    mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(fragA, fragB, fragAcc);
+
+    for(int i = 0; i < vectorSize(fragC); ++i)
+    {
+        fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
+    }
+
+    auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
+    storeC(c, fragC);
+}
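
Since the kernel asserts a single 64-lane wave covering one BLOCK_M x BLOCK_N tile, a minimal host-side launch would look roughly like the sketch below. This is illustrative only (the test harness drives the kernel through DeviceMem/RunDeviceGEMM instead; the dA/dB/dC names are hypothetical):

    // Sketch: launch one block of one wave for a 16x16x128 FP8 tile.
    using AType = ck::f8_t;
    using BType = ck::f8_t;
    using CType = ck::half_t;
    constexpr int M = 16, N = 16, K = 128;

    AType* dA; BType* dB; CType* dC;
    hipMalloc(&dA, M * K * sizeof(AType));
    hipMalloc(&dB, K * N * sizeof(BType));
    hipMalloc(&dC, M * N * sizeof(CType));
    // ... fill dA/dB from host buffers with hipMemcpy ...

    ck::mx_mfma_test::matmul<AType, BType, CType, float, M, N, K>
        <<<1, 64>>>(dA, dB, dC);
    hipDeviceSynchronize();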
 /**
...
@@ -191,7 +410,7 @@ bool RunDeviceGEMM(KernelType kernel,
     return true;
 }

-template <typename DeviceMXMFMA,
+template <typename DeviceMFMA,
           typename ADataType,
           typename BDataType,
           typename CDataType,
...
@@ -200,14 +419,10 @@ template <typename DeviceMFMA,
           typename ALayout,
           typename BLayout,
           typename CLayout,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CElementwiseOperation,
-          index_t CAccNum,
-          index_t M,
-          index_t N,
-          index_t K>
-struct TestMXMFMA
+          index_t BLOCK_M,
+          index_t BLOCK_N,
+          index_t BLOCK_K>
+struct TestMFMA
 {
     auto PrepareGemmTensors(const GemmParams& params)
     {
...
@@ -234,25 +449,25 @@ struct TestMFMA
         Tensor<CDataType> c_m_n_device_result(
             f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

-        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
+        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
+        b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});

         return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
     }
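
(Aside: the new inputs replace the all-ones fill with a constant 0.015625f = 2^-6 for A, a value exactly representable in 8-bit floating point, and a sequential pattern for B; presumably this exercises distinct products while keeping the reference comparison exact in the float accumulator.)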
-    auto operator()(const DeviceMXMFMA& mfma_kernel)
+    auto operator()(const DeviceMFMA& mfma_kernel)
     {
         std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                   << ", CLayout = " << CLayout{}.name << std::endl;

         // Arrange
         GemmParams params;
-        params.M       = M;
-        params.N       = N;
-        params.K       = K;
-        params.StrideA = K; // M K
-        params.StrideB = N; // K N
-        params.StrideC = N; // M N
+        params.M       = BLOCK_M;
+        params.N       = BLOCK_N;
+        params.K       = BLOCK_K;
+        params.StrideA = BLOCK_K; // M K
+        params.StrideB = BLOCK_N; // K N
+        params.StrideC = BLOCK_N; // M N

         auto host_tensors = PrepareGemmTensors(params);
...
@@ -261,25 +476,27 @@ struct TestMFMA
         Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
         Tensor<CDataType>& c_device = std::get<3>(host_tensors);

-        auto a_element_op = AElementwiseOperation{};
-        auto b_element_op = BElementwiseOperation{};
-        auto c_element_op = CElementwiseOperation{};
+        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+        auto a_element_op = PassThrough{};
+        auto b_element_op = PassThrough{};
+        auto c_element_op = PassThrough{};

         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                 BDataType,
                                                                                 CDataType,
                                                                                 CPUAccDataType,
-                                                                                AElementwiseOperation,
-                                                                                BElementwiseOperation,
-                                                                                CElementwiseOperation>;
+                                                                                PassThrough,
+                                                                                PassThrough,
+                                                                                PassThrough>;

         RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);

         RunDeviceGEMM(mfma_kernel, a, b, c_device);

         bool res = false;
-        if constexpr(std::is_same<CDataType, float>::value)
+        if constexpr(std::is_same<CDataType, float>::value ||
+                     std::is_same<CDataType, half_t>::value)
         {
             res = ck::utils::check_err(c_device.mData, c_host.mData);
             std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
...