gaoqiong / composable_kernel_ROCM / Commits

Commit 6778c318, authored Jan 14, 2025 by Andriy Roshchenko

WIP: Introduce MX MFMA test

parent c4a05057
Showing 4 changed files with 393 additions and 0 deletions:
test/CMakeLists.txt             +13  -0
test/mx_mfma_op/CMakeLists.txt   +9  -0
test/mx_mfma_op/mx_mfma_op.cpp  +74  -0
test/mx_mfma_op/mx_mfma_op.hpp +297  -0
test/CMakeLists.txt

@@ -126,18 +126,28 @@ function(add_gtest_executable TEST_NAME)
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
        if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
            message("removing xdl test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
        if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_")
            message("removing microscaling test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    foreach(source IN LISTS ARGN)
        if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma")
            message("removing wmma test ${source}")
            list(REMOVE_ITEM ARGN "${source}")
        endif()
    endforeach()
    # only continue if there are some source files left on the list
    if(ARGN)
        if(ARGN MATCHES "_xdl")
@@ -209,5 +219,8 @@ endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2
    add_subdirectory(smfmac_op)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx950")
    add_subdirectory(mx_mfma_op)
endif()
add_subdirectory(position_embedding)
add_subdirectory(scatter_gather)
test/mx_mfma_op/CMakeLists.txt (new file, mode 100644)
add_custom_target(test_mx_mfma)

add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
if(result EQUAL 0)
    target_link_libraries(test_mx_mfma_op PRIVATE utility)
endif()
add_dependencies(test_mx_mfma test_mx_mfma_op)
test/mx_mfma_op/mx_mfma_op.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"

#include "mx_mfma_op.hpp"

using ck::e8m0_bexp_t;
using ck::f8_ocp_t;
using ck::type_convert;

template <typename Src1Type,
          ck::index_t Src1VecSize,
          typename Src2Type,
          ck::index_t Src2VecSize,
          typename DstType,
          ck::index_t AccVecSize,
          typename AccType,
          typename CPUAccType,
          ck::index_t M,
          ck::index_t N,
          ck::index_t K>
bool run_test()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    bool pass = true;

    const auto mx_mfma_kernel = ck::mx_mfma_test::matmul<Src1Type,
                                                         Src1VecSize,
                                                         Src2Type,
                                                         Src2VecSize,
                                                         AccType,
                                                         AccVecSize,
                                                         DstType,
                                                         M,
                                                         N,
                                                         K>;

    pass = ck::mx_mfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
                                        Src1Type,
                                        Src2Type,
                                        DstType,
                                        AccType,
                                        CPUAccType,
                                        decltype(Row{}),
                                        decltype(Row{}),
                                        decltype(Row{}),
                                        PassThrough,
                                        PassThrough,
                                        PassThrough,
                                        AccVecSize,
                                        M,
                                        N,
                                        K>{}(mx_mfma_kernel);

    return pass;
}

TEST(MXMFMA, FP8MFMA16x16x128)
{
    auto pass = run_test<float, 1, float, 1, float, 1, float, float, 16, 16, 128>();
    EXPECT_TRUE(pass);
}

// TEST(MXMFMA, FP8MFMA32x32x64)
// {
//     EXPECT_TRUE(run_test<f8, 1, f8, 1, float, 1, float, float, 32, 32, 64>());
// }

// TEST(MXMFMA, BF8MFMA16x16x128)
// {
//     EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 16, 16, 128>());
// }

// TEST(MXMFMA, BF8MFMA32x32x64)
// {
//     EXPECT_TRUE(run_test<bf8, 1, bf8, 1, float, 1, float, float, 32, 32, 64>());
// }

TEST(MXMFMA, MXFP8xMXFP8) { EXPECT_TRUE(false) << "Not Implemented\n"; }

TEST(MXMFMA, MXBF8xMXBF8) { EXPECT_TRUE(false) << "Not Implemented\n"; }
test/mx_mfma_op/mx_mfma_op.hpp (new file, mode 100644)
#pragma once

#include "ck/ck.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

namespace ck {
namespace mx_mfma_test {

// Placeholder for the MX MFMA builtin (WIP); the idx argument mirrors the smfmac
// selector that this kernel is patterned after.
template <typename src_vec1, typename src_vec2, typename acc_vec>
__device__ void
builtin_mx_mfma_naive_selector(const src_vec1&, const src_vec2&, const int32_t&, acc_vec&)
{
}

// Smfmac instructions use 4:2 structural sparsity: in every contiguous subgroup of
// 4 elements, at least 2 must be zero, and the positions of the non-zero elements are
// stored in the idx register to allow selection of the corresponding B matrix elements
// for multiplication.
// Currently smfmac instructions support only the A matrix as sparse.
template <typename src1_t,
          index_t src1_vec_size,
          typename src2_t,
          index_t src2_vec_size,
          typename acc_t,
          index_t acc_vec_size,
          typename dst_t,
          int32_t M,
          int32_t N,
          int32_t K>
__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c)
{
    __shared__ src1_t a_shared[M * K];
    __shared__ src2_t b_shared[K * N];

    const int lane = threadIdx.x;

    // smfmac's A part stores only the non-zero elements in 2 VGPRs;
    // smfmac's B part stores all elements in 4 VGPRs
    using src1_vec      = typename vector_type<src1_t, src1_vec_size>::type;
    using src1_full_vec = typename vector_type<src1_t, src1_vec_size * 2>::type;
    using src2_vec      = typename vector_type<src2_t, src2_vec_size>::type;

    src1_vec a_frag      = {};
    src2_vec b_frag      = {};
    src1_full_vec a_temp = {};
    src2_vec b_temp      = {};

    // initialize c fragment to 0
    using acc_vec =
        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_vec_size, true>;
    acc_vec c_thread_buf_;

    for(int i = 0; i < 8; ++i)
    {
        a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K
    }
    for(int i = 0; i < 8; ++i)
    {
        b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N
    }
    __syncthreads();

    for(int i = 0; i < 8; ++i)
    {
        a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i];
    }
    for(int i = 0; i < 8; ++i)
    {
        b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i];
    }
    __syncthreads();

    // idx must be a 32-bit register; it stores four 2-bit indices of A's non-zero
    // elements. It starts with the last two elements of every 4-element subgroup
    // marked as non-zero.
    int32_t idx = 0b11101110;
    // Bit masks for zeroing the 0th-3rd position of idx
    static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
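    // Worked example (illustrative): if this lane's first 4-element subgroup of A is
    // {0, 5, 0, 7}, the non-zeros sit at positions 1 and 3, so the loop below writes
    // 0b01 into bits [1:0] and 0b11 into bits [3:2] of idx (low nibble 0b1101), and
    // a_frag keeps only the values 5 and 7.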
    src1_t curr_val;
    int32_t a_pos = 0;
    for(int j = 0; j < 2; ++j)
    {
        a_pos = j * 2;
        for(int i = 0; i < 4; ++i)
        {
            curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i];
            if(curr_val != 0.0f)
            {
                idx &= ~bit_clear_masks[a_pos];
                idx |= (i % 4) << 2 * a_pos;
                a_frag[a_pos] = curr_val;
                a_pos++;
            }
        }
    }

    for(int i = 0; i < 8; ++i)
    {
        b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)];
    }

    builtin_mx_mfma_naive_selector<src1_vec, src2_vec, acc_vec>(a_frag, b_frag, idx, c_thread_buf_);
    __syncthreads();

    // store results from the unpacked c_thread_buf_ output
    if constexpr(K == 32)
    {
        static_for<0, acc_vec_size, 1>{}([&](auto i) {
            c[(4 * (lane / 16) + i) * N + lane % 16] =
                ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
        });
    }
    else
    {
        static_for<0, acc_vec_size, 1>{}([&](auto i) {
            c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] =
                ck::type_convert<dst_t>(c_thread_buf_[Number<i>{}]);
        });
    }
}
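// Note on matmul's store phase above (illustrative): in the K == 32 branch, lane 17
// with i = 1 writes C row 4 * (17 / 16) + 1 = 5, column 17 % 16 = 1; the else branch
// spreads the acc_vec_size values over a 32-row-tall tile with lane % 32 selecting
// the column.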
/**
 * @brief Structure to hold dimension parameters for GEMM tensors.
 *
 * M       Number of rows in matrix A and matrix C.
 * N       Number of columns in matrix B and matrix C.
 * K       Number of columns in matrix A and number of rows in matrix B.
 * StrideA Stride (leading dimension) of matrix A.
 * StrideB Stride (leading dimension) of matrix B.
 * StrideC Stride (leading dimension) of matrix C.
 */
struct GemmParams
{
    /**
     * @brief This constructor initializes the GEMM parameters with default values:
     *
     * A[16x128] * B[128x16] = C[16x16], all row major.
     */
    GemmParams() : M(16), N(16), K(128), StrideA(128), StrideB(16), StrideC(16) {}

    ck::index_t M;
    ck::index_t N;
    ck::index_t K;
    ck::index_t StrideA;
    ck::index_t StrideB;
    ck::index_t StrideC;
};
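// Usage sketch (mirrors what TestMXMFMA::operator() below does): a 32x32x64
// row-major problem would be described as
//   GemmParams p;
//   p.M = 32; p.N = 32; p.K = 64;
//   p.StrideA = p.K; p.StrideB = p.N; p.StrideC = p.N;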
template <typename GemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<BDataType>& B,
                 Tensor<CDataType>& C,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
{
    auto ref_gemm    = GemmInstance{};
    auto ref_invoker = ref_gemm.MakeInvoker();

    auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
                   const Tensor<ADataType>& A,
                   const Tensor<BDataType>& B,
                   Tensor<CDataType>& C)
{
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(A.mData.data());
    b_n_k_device_buf.ToDevice(B.mData.data());

    kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));

    c_m_n_device_buf.FromDevice(C.mData.data());

    return true;
}
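// The bare triple-chevron launch above runs a single workgroup of 64 threads, i.e.
// one wavefront, which is the granularity MFMA instructions operate at. A sketch of
// the equivalent explicit HIP launch, assuming the default dynamic-shared-memory and
// stream arguments:
//
//   hipLaunchKernelGGL(kernel, dim3(1), dim3(64), 0, nullptr,
//                      a_ptr, b_ptr, c_ptr);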
template <typename DeviceMXMFMA,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GPUAccDataType,
          typename CPUAccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t CAccNum,
          index_t M,
          index_t N,
          index_t K>
struct TestMXMFMA
{
    auto PrepareGemmTensors(const GemmParams& params)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        Tensor<ADataType> a_m_k(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BDataType> b_n_k(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<CDataType> c_m_n_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});

        return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
    }

    auto operator()(const DeviceMXMFMA& mfma_kernel)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;

        // Arrange
        GemmParams params;
        params.M       = M;
        params.N       = N;
        params.K       = K;
        params.StrideA = K; // M K
        params.StrideB = N; // K N
        params.StrideC = N; // M N

        auto host_tensors = PrepareGemmTensors(params);

        const Tensor<ADataType>& a  = std::get<0>(host_tensors);
        const Tensor<BDataType>& b  = std::get<1>(host_tensors);
        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
        Tensor<CDataType>& c_device = std::get<3>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                      BDataType,
                                                      CDataType,
                                                      CPUAccDataType,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation>;
        RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);

        RunDeviceGEMM(mfma_kernel, a, b, c_device);

        bool res = false;
        if constexpr(std::is_same<CDataType, float>::value)
        {
            res = ck::utils::check_err(c_device.mData, c_host.mData);
            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
        }
        else
        {
            std::cout << "UNSUPPORTED CDataType" << std::endl;
        }

        return res;
    }
};

} // namespace mx_mfma_test
} // namespace ck
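The MXFP8xMXFP8 and MXBF8xMXBF8 cases remain EXPECT_TRUE(false) stubs: what the MX formats add over plain FP8/BF8 is a shared E8M0 scale per block of elements, which is presumably why the .cpp already imports ck::e8m0_bexp_t. A minimal host-side sketch of that decode step, assuming the OCP MX convention scale = 2^(e - 127) and a hypothetical helper name (not part of this commit):

    #include <cmath>
    #include <cstdint>

    // Hypothetical illustration: scale one already-converted element by its block's
    // shared E8M0 exponent byte (bias 127 per the OCP MX spec).
    inline float mx_apply_scale(float elem, std::uint8_t shared_exp)
    {
        return elem * std::exp2(static_cast<int>(shared_exp) - 127);
    }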