Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a2ddbd2b
Unverified
Commit
a2ddbd2b
authored
Oct 14, 2023
by
arai713
Committed by
GitHub
Oct 14, 2023
Browse files
Merge branch 'develop' into transpose_5d
parents
e9ecf8d1
fa753f27
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
396 additions
and
494 deletions
+396
-494
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
...emm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
+83
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
...emm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
+79
-0
profiler/include/profiler/profile_gemm_splitk_impl.hpp
profiler/include/profiler/profile_gemm_splitk_impl.hpp
+6
-3
profiler/src/profile_gemm_splitk.cpp
profiler/src/profile_gemm_splitk.cpp
+41
-19
test/batched_gemm/CMakeLists.txt
test/batched_gemm/CMakeLists.txt
+2
-16
test/batched_gemm/batched_gemm_bf16.cpp
test/batched_gemm/batched_gemm_bf16.cpp
+0
-114
test/batched_gemm/batched_gemm_fp16.cpp
test/batched_gemm/batched_gemm_fp16.cpp
+0
-114
test/batched_gemm/batched_gemm_fp32.cpp
test/batched_gemm/batched_gemm_fp32.cpp
+0
-114
test/batched_gemm/batched_gemm_int8.cpp
test/batched_gemm/batched_gemm_int8.cpp
+0
-114
test/batched_gemm/test_batched_gemm.cpp
test/batched_gemm/test_batched_gemm.cpp
+185
-0
No files found.
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
0 → 100644
View file @
a2ddbd2b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_generic_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
1
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
F8
>
// clang-format on
>
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
192
,
4
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
192
,
64
,
4
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
192
,
4
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
192
,
32
,
4
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
64
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNPadding
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
// clang-format on
>
;
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_generic_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
0 → 100644
View file @
a2ddbd2b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_generic_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
1
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
2
,
F8
>
// clang-format on
>
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
,
DeviceGemmXdlSplitKCShuffle
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
F8
>
// clang-format on
>
;
void
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_generic_instances
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_gemm_splitk_impl.hpp
View file @
a2ddbd2b
...
@@ -30,7 +30,8 @@ template <typename ADataType,
...
@@ -30,7 +30,8 @@ template <typename ADataType,
typename
CDataType
,
typename
CDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
,
typename
ComputeType
=
CDataType
>
bool
profile_gemm_splitk_impl
(
int
do_verification
,
bool
profile_gemm_splitk_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
...
@@ -103,7 +104,8 @@ bool profile_gemm_splitk_impl(int do_verification,
...
@@ -103,7 +104,8 @@ bool profile_gemm_splitk_impl(int do_verification,
CDataType
,
CDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
,
ComputeType
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
@@ -120,7 +122,8 @@ bool profile_gemm_splitk_impl(int do_verification,
...
@@ -120,7 +122,8 @@ bool profile_gemm_splitk_impl(int do_verification,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
,
ComputeType
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
profiler/src/profile_gemm_splitk.cpp
View file @
a2ddbd2b
...
@@ -25,6 +25,7 @@ enum struct GemmDataType
...
@@ -25,6 +25,7 @@ enum struct GemmDataType
INT8_INT8_INT8
,
// 3
INT8_INT8_INT8
,
// 3
F8_F16_F16
,
// 4
F8_F16_F16
,
// 4
F16_F8_F16
,
// 5
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
};
};
#define OP_NAME "gemm_splitk"
#define OP_NAME "gemm_splitk"
...
@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
if
(
argc
!=
15
)
if
(
argc
!=
15
)
{
{
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
auto
c_type
,
auto
c_type
,
auto
a_layout
,
auto
a_layout
,
auto
b_layout
,
auto
b_layout
,
auto
c_layout
)
{
auto
c_layout
,
auto
compute_type
)
{
using
ADataType
=
decltype
(
a_type
);
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
BDataType
=
decltype
(
b_type
);
using
AccDataType
=
decltype
(
acc_type
);
using
AccDataType
=
decltype
(
acc_type
);
...
@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
using
BLayout
=
decltype
(
b_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
CLayout
=
decltype
(
c_layout
);
using
CLayout
=
decltype
(
c_layout
);
using
ComputeType
=
decltype
(
compute_type
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
...
@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
CDataType
,
CDataType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
>
(
CLayout
,
ComputeType
>
(
do_verification
,
do_verification
,
init_method
,
init_method
,
do_log
,
do_log
,
...
@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
...
@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Row
{},
Row
{}
,
F32
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Row
{},
Col
{},
Row
{}
,
F32
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Row
{},
Row
{}
,
F32
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F32
{},
F32
{},
F32
{},
F32
{},
Col
{},
Col
{},
Row
{}
,
F32
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{}
,
F16
{}
);
}
}
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F8_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{}
,
F16
{}
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F16
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Row
{},
Row
{},
F8
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F8
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
Col
{},
Col
{},
Row
{},
F8
{});
}
}
#endif
#endif
else
else
...
...
test/batched_gemm/CMakeLists.txt
View file @
a2ddbd2b
...
@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_fp16 batched_gemm_fp16.cpp
)
add_gtest_executable
(
test_batched_gemm test_batched_gemm.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm PRIVATE utility device_batched_gemm_instance
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance
)
endif
()
add_test_executable
(
test_batched_gemm_fp32 batched_gemm_fp32.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance
)
endif
()
add_test_executable
(
test_batched_gemm_bf16 batched_gemm_bf16.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance
)
endif
()
add_test_executable
(
test_batched_gemm_int8 batched_gemm_int8.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/batched_gemm/batched_gemm_bf16.cpp
deleted
100644 → 0
View file @
e9ecf8d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM bf16: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/batched_gemm_fp16.cpp
deleted
100644 → 0
View file @
e9ecf8d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
512
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM fp16: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/batched_gemm_fp32.cpp
deleted
100644 → 0
View file @
e9ecf8d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM fp32: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/batched_gemm_int8.cpp
deleted
100644 → 0
View file @
e9ecf8d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM int8: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm/test_batched_gemm.cpp
0 → 100644
View file @
a2ddbd2b
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
struct
GemmParams
{
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
BatchCount
;
};
class
TestBatchedGemm
:
public
::
testing
::
Test
{
protected:
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
std
::
vector
<
GemmParams
>
params
;
template
<
typename
DataType
>
void
Run
()
{
using
namespace
ck
::
tensor_operation
::
device
;
bool
pass
=
true
;
for
(
auto
&
param
:
params
)
{
const
auto
M
=
param
.
M
;
const
auto
N
=
param
.
N
;
const
auto
K
=
param
.
K
;
const
auto
BatchCount
=
param
.
BatchCount
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
DataType
,
DataType
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
}
EXPECT_TRUE
(
pass
);
}
};
#ifdef CK_ENABLE_INT8
TEST_F
(
TestBatchedGemm
,
i8
)
{
this
->
params
.
push_back
({
64
,
64
,
64
,
2
});
this
->
params
.
push_back
({
64
,
64
,
64
,
1
});
this
->
params
.
push_back
({
60
,
60
,
60
,
2
});
this
->
params
.
push_back
({
68
,
68
,
68
,
2
});
this
->
params
.
push_back
({
40
,
40
,
40
,
2
});
this
->
params
.
push_back
({
256
,
256
,
128
,
3
});
this
->
template
Run
<
int8_t
>();
}
#endif
#ifdef CK_ENABLE_BF16
TEST_F
(
TestBatchedGemm
,
bf16
)
{
this
->
params
.
push_back
({
64
,
64
,
64
,
2
});
this
->
params
.
push_back
({
64
,
64
,
64
,
1
});
this
->
params
.
push_back
({
60
,
60
,
60
,
2
});
this
->
params
.
push_back
({
68
,
68
,
68
,
2
});
this
->
params
.
push_back
({
40
,
40
,
40
,
2
});
this
->
params
.
push_back
({
256
,
256
,
128
,
3
});
this
->
template
Run
<
ck
::
bhalf_t
>();
}
#endif
#ifdef CK_ENABLE_FP16
TEST_F
(
TestBatchedGemm
,
fp16
)
{
this
->
params
.
push_back
({
64
,
64
,
64
,
2
});
this
->
params
.
push_back
({
64
,
64
,
64
,
1
});
this
->
params
.
push_back
({
60
,
60
,
60
,
2
});
this
->
params
.
push_back
({
68
,
68
,
68
,
2
});
this
->
params
.
push_back
({
40
,
40
,
40
,
2
});
this
->
params
.
push_back
({
256
,
256
,
128
,
3
});
this
->
template
Run
<
ck
::
half_t
>();
}
#endif
#ifdef CK_ENABLE_FP32
TEST_F
(
TestBatchedGemm
,
fp32
)
{
this
->
params
.
push_back
({
64
,
64
,
64
,
2
});
this
->
params
.
push_back
({
64
,
64
,
64
,
1
});
this
->
params
.
push_back
({
60
,
60
,
60
,
2
});
this
->
params
.
push_back
({
68
,
68
,
68
,
2
});
this
->
params
.
push_back
({
40
,
40
,
40
,
2
});
this
->
params
.
push_back
({
256
,
256
,
128
,
3
});
this
->
template
Run
<
float
>();
}
#endif
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment