Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a7ae4f8e
Commit
a7ae4f8e
authored
Jan 27, 2025
by
Astha Rai
Browse files
Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc
parents
a6055c3c
781005a5
Changes
175
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
495 additions
and
166 deletions
+495
-166
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt
...eration_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt
+4
-2
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
...ed_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
+73
-0
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp
...ed_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp
+76
-0
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
+1
-1
profiler/src/profile_grouped_gemm_fixed_nk.cpp
profiler/src/profile_grouped_gemm_fixed_nk.cpp
+100
-63
test/CMakeLists.txt
test/CMakeLists.txt
+46
-0
test/ck_tile/batched_gemm/test_batched_gemm.cpp
test/ck_tile/batched_gemm/test_batched_gemm.cpp
+1
-1
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+4
-4
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+15
-13
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+26
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+98
-72
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
+1
-1
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
+1
-4
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+1
-0
test/data_type/test_bhalf.cpp
test/data_type/test_bhalf.cpp
+48
-0
No files found.
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt
View file @
a7ae4f8e
...
@@ -8,6 +8,8 @@ list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16
...
@@ -8,6 +8,8 @@ list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp
)
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp
)
add_instance_library
(
device_grouped_gemm_fixed_nk_instance
${
GROUPED_GEMM_FIXED_NK_INSTANCES
}
)
add_instance_library
(
device_grouped_gemm_fixed_nk_instance
${
GROUPED_GEMM_FIXED_NK_INSTANCES
}
)
\ No newline at end of file
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
0 → 100644
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Short names used by the instance table and the registration function below.

// Element types.
using BF16 = ck::bhalf_t;
using F32  = float;

// GEMM matrix layouts. Col is declared for parity with the sibling instance files; the
// declarations in this file only reference Row.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Shorthand for a compile-time integer sequence.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Both Ds descriptors are empty tuples.
using DsLayout   = ck::Tuple<>;
using DsDataType = ck::Tuple<>;

// Element-wise operation passed for all three operation slots of every instance.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// GEMM specialization passed to every instance in this file.
static constexpr auto GemmMNKPadding =
    ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuple of DeviceGroupedGemm_Xdl_Fixed_NK configurations for the mk_kn_mn case: the A, B and E
// layouts are all Row. Every entry uses BF16 for the A, B and E data, F32 for the AccData and
// CShuffle types, empty Ds tuples, PassThrough element-wise operations, GemmMNKPadding and a
// NumGemmK prefetch stage of 1. Only the arguments on the four continuation lines of each
// entry differ between entries.
//
// Positional template arguments, in order (names taken from the original column header):
//   line 1: A Layout, B Layout, Ds Layout, E Layout, AData Type, BData Type, AccData Type,
//           CShuffle DataType, DsData Type, EData Type, A / B / C Elementwise Operation,
//           GEMM Specialization, NumGemmK Prefetch Stage
//   line 2: Block Size, MPer Block, NPer Block, KPer Block, AK1, BK1, MPer XDL, NPer XDL,
//           MXdl Per Wave, NXdl Per Wave
//   line 3: ABlockTransfer ThreadCluster Lengths_K0_M_K1, ThreadCluster ArrangeOrder,
//           SrcAccessOrder, SrcVectorDim, SrcScalar PerVector, DstScalar PerVector_K1,
//           ABlockLds AddExtraM
//   line 4: the same seven settings for B (ThreadCluster Lengths_K0_N_K1 ... BBlockLds AddExtraN)
//   line 5: CShuffle MXdlPerWave PerShuffle, CShuffle NXdlPerWave PerShuffle,
//           CBlockTransferClusterLengths _MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl,
//           CBlockTransfer ScalarPerVector _NWaveNPerXdl
using device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances = std::tuple<
    // clang-format off
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 128, 128, 32, 8, 8, 32, 32, 2, 2,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 128, 64, 32, 8, 2, 32, 32, 2, 1,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 16, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 128, 64, 32, 8, 8, 32, 32, 2, 1,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 64, 128, 32, 8, 2, 32, 32, 1, 2,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 64, 128, 32, 8, 8, 32, 32, 1, 2,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 128, 64, 32, 8, 2, 32, 32, 2, 2,
        S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0,
        1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 128, 64, 32, 8, 8, 32, 32, 2, 2,
        S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1,
        1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 64, 128, 32, 8, 2, 32, 32, 2, 2,
        S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0,
        1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 64, 128, 32, 8, 8, 32, 32, 2, 2,
        S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1,
        1, 1, S<1, 16, 1, 8>, 8>
    // clang-format on
    >;
/// Registers the mk_kn_mn BF16 grouped-GEMM (fixed NK) tile configurations.
///
/// A default-constructed value of the instance tuple declared above is handed to
/// add_device_operation_instances together with the caller's vector.
///
/// @param instances  Caller-owned vector of DeviceGroupedGemmFixedNK base pointers; passed by
///                   reference to add_device_operation_instances (presumably one element is
///                   appended per tuple entry -- confirm against that helper).
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Row,
                                                         DsLayout,
                                                         Row,
                                                         BF16,
                                                         BF16,
                                                         DsDataType,
                                                         BF16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances)
{
    using TileInstances =
        device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances;

    add_device_operation_instances(instances, TileInstances{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp
0 → 100644
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Short names used by the instance table and the registration function below.

// Element types.
using BF16 = ck::bhalf_t;
using F32  = float;

// GEMM matrix layouts: in this file A and E use Row, B uses Col.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Shorthand for a compile-time integer sequence.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Both Ds descriptors are empty tuples.
using DsLayout   = ck::Tuple<>;
using DsDataType = ck::Tuple<>;

// Element-wise operation passed for all three operation slots of every instance.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// GEMM specialization passed to every instance in this file.
static constexpr auto GemmMNKPadding =
    ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Tuple of DeviceGroupedGemm_Xdl_Fixed_NK configurations for the mk_nk_mn case: A and E use the
// Row layout, B uses Col. Every entry uses BF16 for the A, B and E data, F32 for the AccData
// and CShuffle types, empty Ds tuples, PassThrough element-wise operations, GemmMNKPadding, a
// NumGemmK prefetch stage of 1, KPer Block = 64 and AK1 = BK1 = 8. The A and B block-transfer
// lines of an entry are identical to each other in every row of this table.
//
// Positional template arguments, in order (names taken from the original column header):
//   line 1: A Layout, B Layout, Ds Layout, E Layout, AData Type, BData Type, AccData Type,
//           CShuffle DataType, DsData Type, EData Type, A / B / CDE Elementwise Operation,
//           GEMM Specialization, NumGemmK Prefetch Stage
//   line 2: Block Size, MPer Block, NPer Block, KPer Block, AK1, BK1, MPer XDL, NPer XDL,
//           MXdl Per Wave, NXdl Per Wave
//   line 3: ABlockTransfer ThreadCluster Lengths_K0_M_K1, ThreadCluster ArrangeOrder,
//           SrcAccessOrder, SrcVectorDim, SrcScalar PerVector, DstScalar PerVector_K1,
//           ABlockLds AddExtraM
//   line 4: the same seven settings for B (ThreadCluster Lengths_K0_N_K1 ... BBlockLds AddExtraN)
//   line 5: CShuffle MXdlPerWave PerShuffle, CShuffle NXdlPerWave PerShuffle,
//           CBlockTransferClusterLengths _MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl,
//           CBlockTransfer ScalarPerVector _NWaveNPerXdl
using device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances = std::tuple<
    // clang-format off
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 128, 256, 64, 8, 8, 32, 32, 2, 4,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 128, 128, 64, 8, 8, 32, 32, 2, 2,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 128, 64, 64, 8, 8, 32, 32, 2, 1,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        256, 64, 128, 64, 8, 8, 32, 32, 1, 2,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 32, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 128, 128, 64, 8, 8, 32, 32, 4, 2,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 128, 64, 64, 8, 8, 32, 32, 2, 2,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 64, 128, 64, 8, 8, 32, 32, 2, 2,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 128, 32, 64, 8, 8, 32, 32, 2, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 32, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 32, 128, 64, 8, 8, 32, 32, 1, 2,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        128, 32, 256, 64, 8, 8, 32, 32, 1, 4,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 8>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        64, 64, 64, 64, 8, 8, 32, 32, 2, 2,
        S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        64, 64, 32, 64, 8, 8, 32, 32, 2, 1,
        S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 4>, 8>,
    DeviceGroupedGemm_Xdl_Fixed_NK<Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1,
        64, 32, 64, 64, 8, 8, 32, 32, 1, 2,
        S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,
        1, 1, S<1, 16, 1, 4>, 8>
    // clang-format on
    >;
/// Registers the mk_nk_mn BF16 grouped-GEMM (fixed NK) tile configurations.
///
/// A default-constructed value of the instance tuple declared above is handed to
/// add_device_operation_instances together with the caller's vector.
///
/// @param instances  Caller-owned vector of DeviceGroupedGemmFixedNK base pointers; passed by
///                   reference to add_device_operation_instances (presumably one element is
///                   appended per tuple entry -- confirm against that helper).
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
                                                         Col,
                                                         DsLayout,
                                                         Row,
                                                         BF16,
                                                         BF16,
                                                         DsDataType,
                                                         BF16,
                                                         PassThrough,
                                                         PassThrough,
                                                         PassThrough>>>& instances)
{
    using TileInstances =
        device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances;

    add_device_operation_instances(instances, TileInstances{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
View file @
a7ae4f8e
...
@@ -342,7 +342,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
...
@@ -342,7 +342,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
if
(
do_log
)
if
(
do_log
)
{
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
int8_
t
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
floa
t
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_m_n_host_result
.
mData
,
","
)
std
::
cout
<<
"c_host : "
,
c_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
<<
std
::
endl
;
...
...
profiler/src/profile_grouped_gemm_fixed_nk.cpp
View file @
a7ae4f8e
...
@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout
...
@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout
enum
struct
GemmDataType
enum
struct
GemmDataType
{
{
BF16_I8_BF16
,
// 0
BF16_I8_BF16
,
// 0
F16_F16_F16
,
// 1
F16_F16_F16
,
// 1
F16_F8_F16
,
// 2
F16_F8_F16
,
// 2
F16_I8_F16
,
// 3
F16_I8_F16
,
// 3
BF16_BF16_BF16
// 4
};
};
#define OP_NAME "grouped_gemm_fixed_nk"
#define OP_NAME "grouped_gemm_fixed_nk"
...
@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input)
...
@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input)
{
{
out
.
push_back
(
std
::
stoi
(
item
));
out
.
push_back
(
std
::
stoi
(
item
));
}
}
return
out
;
return
out
;
}
}
...
@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const
auto
StrideCs
=
argToIntArray
(
argv
[
13
]);
const
auto
StrideCs
=
argToIntArray
(
argv
[
13
]);
const
int
kbatch
=
argc
>=
15
?
std
::
stoi
(
argv
[
14
])
:
1
;
const
int
kbatch
=
argc
>=
15
?
std
::
stoi
(
argv
[
14
])
:
1
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
#if defined(CK_ENABLE_FP8)
using
F8
=
ck
::
f8_t
;
#endif
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
int
n_warmup
=
1
;
int
n_warmup
=
1
;
int
n_iter
=
10
;
int
n_iter
=
10
;
if
(
argc
==
17
)
if
(
argc
==
17
)
...
@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter
=
std
::
stoi
(
argv
[
16
]);
n_iter
=
std
::
stoi
(
argv
[
16
]);
}
}
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
BF16_I8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
BF16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
half_t
,
I8
,
ck
::
half_t
,
BF16
,
ck
::
half_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
else
if
(
data_type
==
GemmDataType
::
B
F16_
I8_B
F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_
F16_
F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
BF16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
half_t
,
I8
,
ck
::
half_t
,
BF16
,
ck
::
half_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
#endif
#if defined(CK_ENABLE_FP8)
#if defined(CK_ENABLE_FP16)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
F16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
half_t
,
F16
,
ck
::
f8_t
,
F16
,
ck
::
half_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_F
16
_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F
8
_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
F16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
half_t
,
F16
,
ck
::
f8_t
,
F16
,
ck
::
half_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
#endif
#endif
// CK_ENABLE_FP8
#if defined(CK_ENABLE_
FP16) && defined(CK_ENABLE_FP
8)
#if defined(CK_ENABLE_
INT
8)
else
if
(
data_type
==
GemmDataType
::
F16_
F
8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_
I
8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
F16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
half_t
,
F8
,
int8_t
,
F16
,
ck
::
half_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_
F
8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_
I
8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
F16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
half_t
,
F8
,
int8_t
,
F16
,
ck
::
half_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
#endif
#endif
// CK_ENABLE_INT8
#if defined(CK_ENABLE_F
P
16)
&& defined(CK_ENABLE_INT8)
#if defined(CK_ENABLE_
B
F16)
else
if
(
data_type
==
GemmDataType
::
F16_
I8_
F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
B
F16_
BF16_B
F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
F16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
bhalf_t
,
I8
,
ck
::
bhalf_t
,
F16
,
ck
::
bhalf_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
else
if
(
data_type
==
GemmDataType
::
F16_I8_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
#if defined(CK_ENABLE_INT8)
else
if
(
data_type
==
GemmDataType
::
BF16_I8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
F16
,
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
bhalf_t
,
I8
,
int8_t
,
F16
,
ck
::
bhalf_t
,
F32
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
}
else
if
(
data_type
==
GemmDataType
::
BF16_I8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_fixed_nk_impl
<
ck
::
bhalf_t
,
int8_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
...
@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs
,
StrideAs
,
StrideBs
,
StrideBs
,
StrideCs
,
StrideCs
,
1
,
kbatch
,
n_warmup
,
n_warmup
,
n_iter
);
n_iter
);
}
}
#endif
#endif // CK_ENABLE_INT8
#endif // CK_ENABLE_BF16
else
else
{
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
test/CMakeLists.txt
View file @
a7ae4f8e
...
@@ -7,6 +7,34 @@ include(gtest)
...
@@ -7,6 +7,34 @@ include(gtest)
add_custom_target
(
tests
)
add_custom_target
(
tests
)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set
(
REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
function
(
add_test_executable TEST_NAME
)
function
(
add_test_executable TEST_NAME
)
message
(
"adding test
${
TEST_NAME
}
"
)
message
(
"adding test
${
TEST_NAME
}
"
)
set
(
result 1
)
set
(
result 1
)
...
@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME)
...
@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif
()
endif
()
#message("add_test returns ${result}")
#message("add_test returns ${result}")
set
(
result
${
result
}
PARENT_SCOPE
)
set
(
result
${
result
}
PARENT_SCOPE
)
if
(
result EQUAL 0 AND NOT
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
message
(
"adding to SMOKE TEST FILTER
${
TEST_NAME
}
"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"SMOKE_TEST"
)
add_dependencies
(
smoke
${
TEST_NAME
}
)
elseif
(
result EQUAL 0 AND
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
message
(
"Adding to REGRESSION TEST FILTER
${
TEST_NAME
}
"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"REGRESSION_TEST"
)
add_dependencies
(
regression
${
TEST_NAME
}
)
endif
()
endfunction
()
endfunction
()
function
(
add_gtest_executable TEST_NAME
)
function
(
add_gtest_executable TEST_NAME
)
...
@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
...
@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
endif
()
endif
()
#message("add_gtest returns ${result}")
#message("add_gtest returns ${result}")
set
(
result
${
result
}
PARENT_SCOPE
)
set
(
result
${
result
}
PARENT_SCOPE
)
if
(
result EQUAL 0 AND NOT
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"SMOKE_TEST"
)
add_dependencies
(
smoke
${
TEST_NAME
}
)
elseif
(
result EQUAL 0 AND
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"REGRESSION_TEST"
)
add_dependencies
(
regression
${
TEST_NAME
}
)
endif
()
endfunction
()
endfunction
()
add_compile_options
(
-Wno-c++20-extensions
)
add_compile_options
(
-Wno-c++20-extensions
)
...
...
test/ck_tile/batched_gemm/test_batched_gemm.cpp
View file @
a7ae4f8e
...
@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
...
@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
>
,
//
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <sstream>
#include <sstream>
...
@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile
2D
Partitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
CShuffleEpilogue
,
...
@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
kOutputRank
,
kOutputRank
,
1
,
1
,
0
,
0
,
TilePartitioner
::
k
M
,
TilePartitioner
::
M
PerBlock
,
TilePartitioner
::
k
N
>>
,
TilePartitioner
::
N
PerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
...
...
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
a7ae4f8e
...
@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
...
@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Intrawave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
using
Intrawave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
// ck_tile::GemmPipelineScheduler::Interwave>;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// TODO: Re-enable the Memory pipeline once it has been updated to support vector loads on non-K-major tensors.
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
//
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
//
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
//
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
//
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
//
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
//
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
//
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
//
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>
;
>
;
// clang-format on
// clang-format on
...
...
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
View file @
a7ae4f8e
...
@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
...
@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
constexpr
int
K
=
320
;
constexpr
int
K
=
320
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
{
if
constexpr
(
std
::
is_same_v
<
typename
TestFixture
::
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
EXPECT_THROW
((
this
->
Run
(
M
,
N
,
K
)),
std
::
runtime_error
);
else
this
->
Run
(
M
,
N
,
K
);
}
}
}
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
{
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
constexpr
int
N
=
1024
;
constexpr
int
K
=
320
;
constexpr
int
K
=
320
;
constexpr
int
VecLoadSize
=
8
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
{
if
constexpr
(
std
::
is_same_v
<
typename
TestFixture
::
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
// TODO: Can we anyhow deduce used vector load size?
if
(
M
%
VecLoadSize
==
0
)
this
->
Run
(
M
,
N
,
K
);
else
EXPECT_THROW
((
this
->
Run
(
M
,
N
,
K
)),
std
::
runtime_error
);
}
else
{
this
->
Run
(
M
,
N
,
K
);
}
}
}
}
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
{
{
std
::
vector
<
int
>
Ms
{
12
7
};
std
::
vector
<
int
>
Ms
{
12
8
};
constexpr
int
N
=
1024
;
constexpr
int
N
=
1024
;
constexpr
int
K
=
432
;
constexpr
int
K
=
432
;
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
a7ae4f8e
...
@@ -16,6 +16,7 @@ enum struct GemmPipelineType
...
@@ -16,6 +16,7 @@ enum struct GemmPipelineType
Mem
,
Mem
,
Comp
Comp
};
};
template
<
typename
Tuple
>
template
<
typename
Tuple
>
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
{
{
...
@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr
bool
kPadN
=
PadN
;
constexpr
bool
kPadN
=
PadN
;
constexpr
bool
kPadK
=
PadK
;
constexpr
bool
kPadK
=
PadK
;
// TODO: For now - but this should also be a test parameter
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
// ===============================================
// ===============================================
...
@@ -59,20 +63,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -59,20 +63,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile
2D
Partitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
using
BaseGemmPipeline
=
PipelineType
==
GemmPipelineType
::
Mem
,
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
,
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>>
;
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
...
@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
using
GemmPipeline
=
using
UniversalGemmProblem
=
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
BDataType
,
ck_tile
::
GemmPipelineAgBgCrMem
<
AccDataType
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
GemmShape
,
BDataType
,
GemmUniversalTraits
,
AccDataType
,
Scheduler
,
GemmShape
,
has_hot_loop_v
,
Traits
,
tail_number_v
>
;
Scheduler
,
has_hot_loop_v
,
using
GemmPipeline
=
std
::
conditional_t
<
tail_number_v
>>
,
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineAgBgCrMem
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
BDataType
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
AccDataType
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>>
;
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
// Tail pipeline One to Seven
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
One
>
{});
}
else
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
2
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Two
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
ck_tile
::
TailNumber
::
Full
>
{});
}
}
}
else
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
3
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
std
::
ostringstream
err
;
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
err
<<
"For compute pipeline tail number should always be Full, but have
\"
"
ck_tile
::
TailNumber
::
Three
>
{});
<<
tail_num
<<
"
\"
which is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
4
)
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Mem
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Four
)
// Tail pipeline One to Seven
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Four
>
{});
ck_tile
::
TailNumber
::
One
>
{});
}
}
}
else
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
5
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Five
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
F
ive
>
{});
ck_tile
::
TailNumber
::
F
ull
>
{});
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
6
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
2
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Six
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
if
(
tail_num
==
ck_tile
::
TailNumber
::
Two
)
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
{
ck_tile
::
TailNumber
::
Six
>
{});
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
3
)
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
7
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Seven
)
{
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
{
ck_tile
::
TailNumber
::
Seven
>
{});
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
4
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Four
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Four
>
{});
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
5
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Five
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Five
>
{});
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
6
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Six
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Six
>
{});
}
}
if
constexpr
(
BaseGemmPipeline
::
PrefetchStages
>
7
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Seven
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
}
}
}
}
}
...
...
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
View file @
a7ae4f8e
...
@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
...
@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
>
,
//
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
>
//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
...
...
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
View file @
a7ae4f8e
...
@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
...
@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
CodegenGemmShape
,
CodegenGemmShape
,
CodegenGemmTraits
<
ALayout
,
BLayout
,
CLayout
>>
;
CodegenGemmTraits
<
ALayout
,
BLayout
,
CLayout
>>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>
,
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
CodegenGemmPolicy
>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
...
...
test/data_type/CMakeLists.txt
View file @
a7ae4f8e
...
@@ -49,3 +49,4 @@ if(result EQUAL 0)
...
@@ -49,3 +49,4 @@ if(result EQUAL 0)
endif
()
endif
()
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
add_gtest_executable
(
test_bhalf test_bhalf.cpp
)
test/data_type/test_bhalf.cpp
0 → 100644
View file @
a7ae4f8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cmath>
#include <cstdint>
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bhalf_t
;
using
ck
::
type_convert
;
// Checks that converting a float quiet NaN to bhalf_t produces the 0x7FC0 bit pattern
// (sign 0, exponent field all ones, most significant mantissa bit set).
// NOTE(review): EXPECT_EQ on two bhalf_t values only behaves as a bit-pattern comparison
// if bhalf_t is an integer-backed type - confirm against ck/utility/data_type.hpp.
TEST(BHALF_T, Nan)
{
    const uint16_t expected_bits = 0x7FC0;
    const bhalf_t expected       = ck::bit_cast<bhalf_t>(expected_bits);
    const bhalf_t converted      = type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN());

    EXPECT_EQ(expected, converted);
}
// Checks that converting float +infinity to bhalf_t produces the 0x7F80 bit pattern
// (sign 0, exponent field all ones, mantissa zero).
// NOTE(review): as with the NaN test, this presumably compares raw bits - confirm that
// bhalf_t equality is a bitwise comparison.
TEST(BHALF_T, Inf)
{
    const uint16_t expected_bits = 0x7F80;
    const bhalf_t expected       = ck::bit_cast<bhalf_t>(expected_bits);
    const bhalf_t converted      = type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity());

    EXPECT_EQ(expected, converted);
}
// Round-trips a float whose 23 mantissa bits are all set (bits 0x81FFFFFF: sign 1,
// exponent field 0x03, mantissa 0x7FFFFF) through bhalf_t and back, and requires the
// result to stay within an absolute tolerance of 2^-7 of the input.
// NOTE(review): the intent appears to be exercising a rounding carry out of the mantissa
// into the exponent; that depends on the rounding mode used by type_convert - confirm.
TEST(BHALF_T, MantisaOverflow)
{
    // 2^-7, written as an exact compile-time constant rather than std::pow(2, -7):
    // this avoids a run-time double computation and a dependency on <cmath>.
    constexpr float abs_tol = 1.0f / 128.0f;

    const uint32_t val    = 0x81FFFFFF;
    const float float_val = ck::bit_cast<float>(val);

    ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
}
// Round-trips the float with bits 0xFF800000 (sign 1, exponent field all ones, mantissa
// zero, i.e. negative infinity) through bhalf_t and requires the value to come back
// unchanged.
TEST(BHALF_T, ExpOverflow)
{
    const uint32_t input_bits = 0xFF800000;
    const float input         = ck::bit_cast<float>(input_bits);
    const float round_tripped = type_convert<float>(type_convert<bhalf_t>(input));

    ASSERT_EQ(round_tripped, input);
}
// Round-trips the float with every bit set (0xFFFFFFFF, a NaN encoding) through bhalf_t
// and requires the result to still be NaN. The first assertion guards the precondition
// that the input itself is NaN.
TEST(BHALF_T, MantisaExpOverflow)
{
    const uint32_t input_bits = 0xFFFFFFFF;
    const float input         = ck::bit_cast<float>(input_bits);

    ASSERT_TRUE(std::isnan(input));

    const float round_tripped = type_convert<float>(type_convert<bhalf_t>(input));
    ASSERT_TRUE(std::isnan(round_tripped));
}
Prev
1
…
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment