Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5e8b5703
Commit
5e8b5703
authored
Dec 11, 2024
by
chenjun
Browse files
add multiple multiple add profiler
parent
c1f8d53c
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3553 additions
and
2 deletions
+3553
-2
example/65_gemm_multiply_multiply/CMakeLists.txt
example/65_gemm_multiply_multiply/CMakeLists.txt
+2
-1
example/65_gemm_multiply_multiply/gemm_multiply_multiply_add_xdl_int8.cpp
...multiply_multiply/gemm_multiply_multiply_add_xdl_int8.cpp
+349
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
...pu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
+1
-1
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+29
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_bshuffle.hpp
...u/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_bshuffle.hpp
+2155
-0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+1
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_add.hpp
...sor_operation_instance/gpu/gemm_multiply_multiply_add.hpp
+168
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
+36
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/CMakeLists.txt
...on_instance/gpu/gemm_multiply_multiply_add/CMakeLists.txt
+15
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp
...ce_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp
+135
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
...ply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...ly_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...y_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
..._add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...y_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
..._add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
+33
-0
profiler/include/profiler/profile_gemm_multiply_multiply_add_impl.hpp
...lude/profiler/profile_gemm_multiply_multiply_add_impl.hpp
+373
-0
profiler/int8_gmm_profiler.sh
profiler/int8_gmm_profiler.sh
+24
-0
profiler/int8_gmma_mb_profiler.sh
profiler/int8_gmma_mb_profiler.sh
+33
-0
profiler/int8_gmma_profiler.sh
profiler/int8_gmma_profiler.sh
+36
-0
No files found.
example/65_gemm_multiply_multiply/CMakeLists.txt
View file @
5e8b5703
...
@@ -2,3 +2,4 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
...
@@ -2,3 +2,4 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_mult
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_add_xdl_int8 gemm_multiply_multiply_add_xdl_int8.cpp
)
\ No newline at end of file
example/65_gemm_multiply_multiply/gemm_multiply_multiply_add_xdl_int8.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
I8
=
int8_t
;
using
I32
=
int
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
FP8
=
ck
::
f8_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
A0DataType
=
I8
;
using
B0DataType
=
I8
;
using
AccDataType
=
I32
;
using
CShuffleDataType
=
I32
;
using
EDataType
=
BF16
;
using
D0DataType
=
F32
;
using
D1DataType
=
F32
;
using
D2DataType
=
EDataType
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
,
D2DataType
>
;
using
A0Layout
=
Row
;
using
B0Layout
=
Col
;
using
D0Layout
=
Row
;
using
D1Layout
=
Col
;
using
D2Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
,
D1Layout
,
D2Layout
>
;
using
ELayout
=
Row
;
struct
MultiplyMultiplyAdd
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
,
typename
D2
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
,
const
D2
&
d2
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
int
,
float
,
float
,
ck
::
half_t
>
(
ck
::
half_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
,
const
ck
::
half_t
&
d2
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
)
+
ck
::
type_convert
<
float
>
(
d2
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
bhalf_t
,
int
,
float
,
float
,
ck
::
bhalf_t
>
(
ck
::
bhalf_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
,
const
ck
::
bhalf_t
&
d2
)
const
{
const
ck
::
bhalf_t
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
)
+
ck
::
type_convert
<
float
>
(
d2
);
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
};
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
MultiplyMultiplyAdd
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::KPadding;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |S<C, D0, D1,D2>|
///###### RRR
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>;
///###### RCR
// [M = 128 N = 1280 K = 8192]
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
// [M = 128 N = 8192 K = 1024]
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
128
,
128
,
256
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
I8
>
;
// [M = 4096 N = 1280 K = 8192]
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1, 8>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
// [M = 4096 N = 8192 K = 1024]
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape
ck
::
index_t
M
=
128
;
ck
::
index_t
N
=
8192
;
ck
::
index_t
K
=
1024
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
K
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
ck
::
index_t
KBatch
=
1
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
12
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideD
=
std
::
stoi
(
argv
[
9
]);
StrideE
=
std
::
stoi
(
argv
[
10
]);
KBatch
=
std
::
stoi
(
argv
[
11
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch
\n
"
);
exit
(
0
);
}
// do_verification = false;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
A0DataType
>
a0_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
A0Layout
{}));
Tensor
<
B0DataType
>
b0_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
B0Layout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD
,
D0Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD
,
D1Layout
{}));
Tensor
<
D2DataType
>
d2_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD
,
D2Layout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
std
::
cout
<<
"a0_m_k: "
<<
a0_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_k_n: "
<<
b0_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d2_m_n: "
<<
d2_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_m_n: "
<<
d1_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_m_n: "
<<
d0_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
0
,
2
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
0
,
2
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D2DataType
>
{
0
,
2
});
break
;
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
127
,
127
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
127
,
127
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
-
0.5
,
0.5
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
-
0.5
,
0.5
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D2DataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d2_device_buf
(
sizeof
(
D2DataType
)
*
d2_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_k_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d2_device_buf
.
ToDevice
(
d2_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
constexpr
ck
::
index_t
NumDTensor
=
DsDataType
::
Size
();
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
a0_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
NumDTensor
>
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
(),
d2_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{
I0
,
I0
,
I0
},
StrideE
,
KBatch
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
,
20
});
//, 50, 200, true, 200});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
if
(
do_verification
)
{
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
Tensor
<
CShuffleDataType
>
c_m_n
({
M
,
N
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
A0DataType
,
B0DataType
,
CShuffleDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a0_m_k
,
b0_k_n
,
c_m_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d0_m_n
(
m
,
n
),
d1_m_n
(
m
,
n
),
d2_m_n
(
m
,
n
));
}
}
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
print_tensor
=
[](
auto
&
tensor
,
auto
&
name
)
{
size_t
colID
=
0
;
size_t
rowID
=
0
;
std
::
cout
<<
"
\n
"
<<
name
<<
"
\n
"
<<
rowID
++
<<
": "
;
for
(
auto
el
:
tensor
.
mData
)
{
// std::cout << el << " ";
std
::
cout
<<
ck
::
type_convert
<
float
>
(
el
)
<<
" "
;
if
(
colID
<
((
tensor
.
GetLengths
()[
1
]
-
1
)
<
20
?
(
tensor
.
GetLengths
()[
1
]
-
1
)
:
20
))
{
colID
++
;
}
else
{
colID
=
0
;
std
::
cout
<<
std
::
endl
<<
rowID
++
<<
": "
;
}
if
(
rowID
>
5
)
{
break
;
}
}
std
::
cout
<<
std
::
endl
;
};
print_tensor
(
a0_m_k
,
"a0_m_k"
);
print_tensor
(
b0_k_n
,
"b0_k_n"
);
print_tensor
(
e_m_n_device_result
,
"e_m_n_device_result"
);
print_tensor
(
e_m_n_host_result
,
"e_m_n_host_result"
);
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
"Error: Incorrect results!"
,
5e-2
,
1
)
?
0
:
1
;
}
return
0
;
}
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
View file @
5e8b5703
...
@@ -190,7 +190,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
...
@@ -190,7 +190,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
});
});
ck
::
utility
::
RotatingMemWrapperMultiD
<
Argument
,
DsDataType
>
rotating_mem
(
ck
::
utility
::
RotatingMemWrapperMultiD
<
Argument
,
DsDataType
>
rotating_mem
(
arg_
,
stream_config
.
rotating_count
,
size_a_buffer
,
size_b_buffer
,
DsSize
);
arg_
,
stream_config
.
rotating_count
,
size_a_buffer
,
size_b_buffer
,
DsSize
);
rotating_mem
.
Print
();
//
rotating_mem.Print();
auto
run_flush_cache
=
[
&
]()
{
auto
run_flush_cache
=
[
&
]()
{
// flush icache
// flush icache
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
5e8b5703
...
@@ -294,6 +294,35 @@ struct MultiplyMultiply
...
@@ -294,6 +294,35 @@ struct MultiplyMultiply
}
}
};
};
struct
MultiplyMultiplyAdd
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
,
typename
D2
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
,
const
D2
&
d2
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
int
,
float
,
float
,
ck
::
half_t
>
(
ck
::
half_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
,
const
ck
::
half_t
&
d2
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
)
+
ck
::
type_convert
<
float
>
(
d2
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
bhalf_t
,
int
,
float
,
float
,
ck
::
bhalf_t
>
(
ck
::
bhalf_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
,
const
ck
::
bhalf_t
&
d2
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
)
+
ck
::
type_convert
<
float
>
(
d2
);
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
};
struct
MultiplyAddFastGelu
struct
MultiplyAddFastGelu
{
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_bshuffle.hpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#define DEBUG_LOG 0
namespace
ck
{
// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
// kernel function Blockers:
// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_multi_d
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_ds_grid
,
karg
.
p_c_grid
,
p_shared
,
karg
,
karg
.
a_element_op
,
karg
.
b_element_op
,
karg
.
c_element_op
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
}
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_ds_grid
,
karg
.
p_c_grid
,
p_shared_0
,
p_shared_1
,
karg
,
karg
.
a_element_op
,
karg
.
b_element_op
,
karg
.
c_element_op
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CDEShuffleBlockTransferScalarPerVectors
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v4
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ADataType
,
typename
LDSTypeB
=
BDataType
>
struct
GridwiseGemmMultiD_xdl_cshuffle_v3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
CShuffleBlockTransferScalarPerVector_NPerBlock
=
CDEShuffleBlockTransferScalarPerVectors
{}[
I0
];
// K1 should be Number<...>
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
MakeDsGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
static_cast
<
const
DDataType
*>
(
nullptr
);
},
Number
<
NumDTensor
>
{});
}
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeB
>::
selected_mfma
.
k_per_blk
);
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
return
std
::
make_tuple
(
Block2CTileMapDefault
::
CalculateGridSize
(
M
,
N
),
1
,
KBatch
);
}
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
math
::
integer_least_multiple
(
M
,
MPerBlock
);
}
__host__
__device__
static
auto
CalculateNPadded
(
index_t
N
)
{
return
math
::
integer_least_multiple
(
N
,
NPerBlock
);
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
)
{
return
math
::
integer_divide_ceil
(
K
,
KPerBlock
)
*
KPerBlock
;
}
__host__
__device__
static
auto
CalculateAK0Padded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K_t
=
K_Batch
*
KPerBlock
;
return
(
K
+
K_t
-
1
)
/
K_t
*
(
KPerBlock
/
AK1Value
);
}
__host__
__device__
static
auto
CalculateBK0Padded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K_t
=
K_Batch
*
KPerBlock
;
return
(
K
+
K_t
-
1
)
/
K_t
*
(
KPerBlock
/
BK1Value
);
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K_t
=
K_Batch
*
KPerBlock
;
return
(
K
+
K_t
-
1
)
/
K_t
*
KPerBlock
;
}
__host__
__device__
static
auto
CalculateKRead
(
index_t
K
,
index_t
K_Batch
=
1
)
{
constexpr
auto
KReadVec
=
math
::
lcm
(
AK1Number
,
BK1Number
);
auto
K_t
=
K_Batch
*
KReadVec
;
return
(
K
+
K_t
-
1
)
/
K_t
*
KReadVec
;
}
__host__
__device__
static
auto
CalculateMBlock
(
index_t
M
)
{
return
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
}
__host__
__device__
static
auto
CalculateNBlock
(
index_t
N
)
{
return
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
}
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
TileDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor
(
const
TileDesc_K0_MN_K1
&
)
{
constexpr
index_t
K0
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
0
>
{});
constexpr
index_t
K1
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
2
>
{});
return
transform_tensor_descriptor
(
TileDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
// const auto b_grid_desc_nraw_kraw = [&]() {
// if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
// }
// else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
// {
// return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
// }
// }();
const
auto
b_grid_desc_nraw_kraw_shuffle_n0_k0_k1_n1_k2
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
BK0Number
,
BK1Number
),
make_tuple
(
StrideB
,
BK1Number
,
I1
));
}
}();
const
auto
b_grid_desc_nraw_kraw_n_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
BK0Number
,
BK1Number
),
make_tuple
(
StrideB
,
BK1Number
,
I1
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
N
,
NPad
-
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeAMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
return
MakeGemmMmaTileDescriptor
<
MXdlPerWave
,
MWaves
,
MPerXdl
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeBMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
template
<
typename
ELayout
>
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
MakeCGridDescriptor_M_N
<
DLayout
>
(
M
,
MPad
,
N
,
NPad
,
StrideDs
[
i
]);
},
Number
<
NumDTensor
>
{});
}
template
<
typename
DsGridDesc
>
__device__
static
constexpr
auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
DsGridDesc
&
ds_grid_desc_m_n
,
index_t
MBlock
,
index_t
NBlock
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
[
i
],
MBlock
,
NBlock
);
},
Number
<
NumDTensor
>
{});
}
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
(
0
,
0
,
0
,
0
,
{}))
>
;
struct
Problem
{
__host__
__device__
Problem
(
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideC_
,
index_t
KBatch_
)
:
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideDs
{
StrideDs_
},
StrideC
{
StrideC_
},
KBatch
{
KBatch_
},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)},
KRead
{
CalculateKRead
(
K_
,
KBatch_
)},
KPadded
{
CalculateKPadded
(
K_
,
KBatch_
)},
AK0
{
CalculateAK0Padded
(
K_
,
KBatch_
)},
BK0
{
CalculateBK0Padded
(
K_
,
KBatch_
)},
MBlock
{
CalculateMBlock
(
M_
)},
NBlock
{
CalculateNBlock
(
N_
)}
{
}
__host__
void
Print
()
const
{
std
::
cout
<<
"problem {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KRead:"
<<
KRead
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"AK0:"
<<
AK0
<<
", "
<<
"BK0:"
<<
BK0
<<
", "
<<
"MBlock: "
<<
MBlock
<<
", "
<<
"NBlock: "
<<
NBlock
<<
"}"
<<
std
::
endl
;
}
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideC
;
index_t
KBatch
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KRead
;
index_t
KPadded
;
index_t
AK0
;
index_t
BK0
;
index_t
MBlock
;
index_t
NBlock
;
};
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
,
public
Problem
{
__host__
Argument
(
const
ADataType
*
p_a_grid_
,
const
BDataType
*
p_b_grid_
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
,
CDataType
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideC_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CElementwiseOperation
c_element_op_
)
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideDs_
,
StrideC_
,
k_batch_
},
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_ds_grid
{},
p_c_grid
{
p_c_grid_
},
a_element_op
{
a_element_op_
},
b_element_op
{
b_element_op_
},
c_element_op
{
c_element_op_
}
{
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType_
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid
(
i
)
=
static_cast
<
const
DDataType_
*>
(
p_ds_grid_
[
i
]);
});
}
const
ADataType
*
p_a_grid
;
const
BDataType
*
p_b_grid
;
DsGridPointer
p_ds_grid
;
CDataType
*
p_c_grid
;
const
AElementwiseOperation
a_element_op
;
const
BElementwiseOperation
b_element_op
;
const
CElementwiseOperation
c_element_op
;
};
struct
SplitKBatchOffset
{
__device__
SplitKBatchOffset
(
Argument
&
karg
)
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
}
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
))
{
karg
.
K
=
karg
.
KRead
;
}
else
{
karg
.
K
=
karg
.
K
-
karg
.
KRead
*
(
karg
.
KBatch
-
1
);
}
}
index_t
a_k_split_offset
;
index_t
b_k_split_offset
;
};
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0Number
,
Number
<
MPerBlock
>
{},
AK1Number
),
make_tuple
(
AK1Number
,
Number
<
KPerBlock
+
ABlockLdsExtraM
>
{},
I1
));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
// in some cases.
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
LDSTypeA
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
LDSTypeA
);
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
AK0Number
*
Number
<
MLdsLayer
>
{},
Number
<
MPerBlock
/
MLdsLayer
>
{},
AK1Number
),
make_tuple
(
AK1Number
,
Number
<
KPerBlock
*
MLdsLayer
>
{},
I1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_xor_with_modulo_transform
(
make_tuple
(
Number
<
MPerBlock
/
MLdsLayer
>
{},
Number
<
AK0Number
*
MLdsLayer
>
{})),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
constexpr
auto
a_lds_block_desc_ak0_mldslayer_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0Number
,
Number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
Number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{}));
constexpr
auto
a_lds_block_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_ak0_mldslayer_m_ak1
,
make_tuple
(
make_pass_through_transform
(
AK0Number
),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
MPerBlock
/
MLdsLayer
>
{},
Number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
a_lds_block_desc_ak0_m_ak1
;
}
else
// ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
M0
=
ABlockTransferThreadClusterLengths_AK0_M_AK1
{}.
At
(
I1
);
constexpr
auto
M1
=
MPerBlock
/
M0
;
constexpr
auto
KThreadWrite
=
ABlockTransferThreadClusterLengths_AK0_M_AK1
{}.
At
(
I0
);
constexpr
auto
K0PerThreadWrite
=
AK0Number
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
MPerXdl
;
constexpr
auto
K0PerThreadRead
=
AK0Number
/
KThreadRead
;
constexpr
auto
kfold
=
(
AK1Number
*
M0
*
sizeof
(
LDSTypeA
)
>
128
)
?
1
:
128
/
(
AK1Number
*
M0
*
sizeof
(
LDSTypeA
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mpair<=n0
constexpr
auto
mpair
=
(
AK1Number
*
MPerXdl
*
sizeof
(
LDSTypeA
)
>
128
)
?
1
:
((
128
/
(
AK1Number
*
MPerXdl
*
sizeof
(
LDSTypeA
)))
>
M0
?
M0
:
128
/
(
AK1Number
*
MPerXdl
*
sizeof
(
LDSTypeA
)));
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
Number
<
K0PerThreadWrite
>
{},
Number
<
KThreadReadPerm
*
M1
>
{},
Number
<
kfold
*
M0
/
mpair
>
{},
Number
<
mpair
>
{},
AK1Number
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_xor_with_modulo_transform
(
make_tuple
(
Number
<
KThreadReadPerm
*
M1
>
{},
Number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
Number
<
mpair
>
{}),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
a_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
KThreadReadPerm
>
{},
Number
<
M1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
kfold
>
{},
Number
<
M0
/
mpair
>
{})),
make_pass_through_transform
(
Number
<
mpair
>
{}),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
0
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}));
constexpr
auto
a_lds_block_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
KThreadReadPerm
>
{},
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
Number
<
kfold
>
{},
Number
<
K0PerThreadWrite
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
M0
/
mpair
>
{},
Number
<
mpair
>
{},
Number
<
M1
>
{})),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
0
,
1
,
4
,
2
>
{},
Sequence
<
5
,
6
,
3
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
a_lds_block_desc_ak0_m_ak1
;
}
}
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
make_tuple
(
BK1Number
,
Number
<
KPerBlock
+
BBlockLdsExtraN
>
{},
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
LDSTypeB
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
LDSTypeB
);
;
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
BK0Number
*
Number
<
NLdsLayer
>
{},
Number
<
NPerBlock
/
NLdsLayer
>
{},
BK1Number
),
make_tuple
(
BK1Number
,
Number
<
KPerBlock
*
NLdsLayer
>
{},
I1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_xor_with_modulo_transform
(
make_tuple
(
Number
<
NPerBlock
/
NLdsLayer
>
{},
Number
<
BK0Number
*
NLdsLayer
>
{})),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_bk0_nldslayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0Number
,
Number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
Number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_nldslayer_n_bk1
,
make_tuple
(
make_pass_through_transform
(
BK0Number
),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
NPerBlock
/
NLdsLayer
>
{},
Number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
b_lds_block_desc_bk0_n_bk1
;
}
else
// RowMajor B
{
constexpr
auto
N0
=
BBlockTransferThreadClusterLengths_BK0_N_BK1
{}.
At
(
I1
);
constexpr
auto
N1
=
NPerBlock
/
N0
;
constexpr
auto
KThreadWrite
=
BBlockTransferThreadClusterLengths_BK0_N_BK1
{}.
At
(
I0
);
constexpr
auto
K0PerThreadWrite
=
BK0Number
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
NPerXdl
;
constexpr
auto
K0PerThreadRead
=
BK0Number
/
KThreadRead
;
constexpr
auto
kfold
=
(
BK1Number
*
N0
*
sizeof
(
LDSTypeB
)
>
128
)
?
1
:
128
/
(
BK1Number
*
N0
*
sizeof
(
LDSTypeB
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=n0
constexpr
auto
npair
=
(
BK1Number
*
NPerXdl
*
sizeof
(
LDSTypeB
)
>
128
)
?
1
:
((
128
/
(
BK1Number
*
NPerXdl
*
sizeof
(
LDSTypeB
)))
>
N0
?
N0
:
128
/
(
BK1Number
*
NPerXdl
*
sizeof
(
LDSTypeB
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
Number
<
K0PerThreadWrite
>
{},
Number
<
KThreadReadPerm
*
N1
>
{},
Number
<
kfold
*
N0
/
npair
>
{},
Number
<
npair
>
{},
BK1Number
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_xor_with_modulo_transform
(
make_tuple
(
Number
<
KThreadReadPerm
*
N1
>
{},
Number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
Number
<
npair
>
{}),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
KThreadReadPerm
>
{},
Number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
kfold
>
{},
Number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
Number
<
npair
>
{}),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
0
,
3
>
{},
Sequence
<
4
,
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}));
constexpr
auto
b_lds_block_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
KThreadReadPerm
>
{},
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
Number
<
kfold
>
{},
Number
<
K0PerThreadWrite
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
N0
/
npair
>
{},
Number
<
npair
>
{},
Number
<
N1
>
{})),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
0
,
1
,
4
,
2
>
{},
Sequence
<
5
,
6
,
3
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
b_lds_block_desc_bk0_n_bk1
;
}
}
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
using
BlockwiseGemmPipe
=
remove_cvref_t
<
decltype
(
BlockGemmPipeline_Selector
<
BlkGemmPipelineVer
,
BlkGemmPipeSched
,
BlockSize
,
LDSTypeA
,
LDSTypeB
,
ComputeTypeA
,
AccDataType
,
decltype
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()),
decltype
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()),
decltype
(
MakeAMmaTileDescriptor_M0_M1_M2_K
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
())),
decltype
(
MakeBMmaTileDescriptor_N0_N1_N2_K
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
())),
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
>
())
>
;
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
*
sizeof
(
LDSTypeA
)
+
b_block_space_size_aligned
*
sizeof
(
LDSTypeB
)),
c_block_size
*
sizeof
(
CShuffleDataType
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
&&
!
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
#if DEBUG_LOG
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
&&
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
#if DEBUG_LOG
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
#if DEBUG_LOG
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
else
{
constexpr
auto
KReadVec
=
math
::
lcm
(
AK1Number
,
BK1Number
);
auto
K_t
=
karg
.
KBatch
*
KReadVec
;
auto
KReadPadSplited
=
math
::
integer_divide_ceil
(
karg
.
K
,
K_t
)
*
KReadVec
;
if
((
KReadPadSplited
*
(
karg
.
KBatch
-
1
))
>=
karg
.
K
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
ABlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
else
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
ABlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
BBlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
else
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
BBlockTransferSrcScalarPerVector
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
else
{
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<<
CShuffleBlockTransferScalarPerVector_NPerBlock
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
karg
.
AK0
/
(
KPerBlock
/
AK1Value
);
if
constexpr
(
BlkGemmPipelineVer
!=
BlockGemmPipelineVersion
::
v1
)
{
if
(
num_k_loop
<=
BlockwiseGemmPipe
::
PrefetchStages
)
{
return
false
;
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
return
BlockwiseGemmPipe
::
BlockHasHotloop
(
num_loop
);
}
__host__
__device__
static
constexpr
TailNumber
CalculateKBlockLoopTailNum
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
return
BlockwiseGemmPipe
::
BlockLoopTailNum
(
num_loop
);
}
template
<
typename
CGridDesc
>
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_grid_desc_m_n
,
index_t
MBlock
,
index_t
NBlock
)
{
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using
Block2CTileMapDefault
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
DsGridPointer
&
p_ds_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared
,
const
Problem
&
problem
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
const
auto
block_2_ctile_map
=
Block2CTileMapDefault
{
problem
.
M
,
problem
.
N
,
4
};
Run
<
Block2CTileMapDefault
,
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>
(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_c_grid
,
p_shared
,
problem
,
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
);
}
template
<
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
DsGridPointer
&
p_ds_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared
,
const
Problem
&
problem
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0Number
,
MPerBlock
,
AK1Number
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ADataType
,
LDSTypeA
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0Number
,
NPerBlock
,
BK1Number
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BDataType
,
LDSTypeB
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
// Cast after lds
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeA
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeB
*>
(
p_shared
)
+
a_block_space_size_aligned
*
sizeof
(
LDSTypeA
)
/
sizeof
(
LDSTypeB
),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1Number
,
0
,
0
);
// Blockwise GEMM pipeline
static_assert
(
std
::
is_default_constructible_v
<
BlockwiseGemmPipe
>
);
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
blockwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
,
TailNum
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C and write out
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm_pipeline
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm_pipeline
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
CShuffleDataType
*>
(
p_shared
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm_pipeline
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
CShuffleDataType
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
using
EDataType
=
CDataType
;
const
auto
ds_grid_desc_m_n
=
MakeDsGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideDs
);
const
auto
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
ds_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
i
],
ds_grid_desc_m_n
[
i
].
GetElementSpaceSize
());
},
Number
<
NumDTensor
>
{});
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_desc_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
generate_tie
([
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_buf_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_buf
),
generate_tie
([
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_buf
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of starting index of C/Ds blockwise copy
const
auto
idx_c_ds_block_begin
=
container_concat
(
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
)),
generate_tuple
(
[
&
](
auto
)
{
return
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
);
},
Number
<
NumDTensor
>
{}));
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
using
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
=
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
Tuple
<
EDataType
>
,
decltype
(
c_ds_desc_refs
),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
CElementwiseOperation
,
Sequence
<
static_cast
<
index_t
>
(
EGlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
// support arbitray type
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename SrcDimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DstDimAccessOrder,
3
,
// index_t SrcVectorDim,
3
,
// index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
sequence_merge_t
<
Sequence
<
true
>
,
uniform_sequence_gen_t
<
NumDTensor
,
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
)),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
// space filling curve for shuffled blockwise C/D/E
constexpr
auto
sfc_cde_block
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global
.
Run
(
c_ds_desc_refs
,
c_ds_buf_refs
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
c_grid_buf
));
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
cde_lds_and_global_step
=
sfc_cde_block
.
GetForwardStep
(
access_id
);
// move on Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
c_ds_desc_refs
,
i
+
I1
,
cde_lds_and_global_step
);
});
// move on E
cde_block_copy_lds_and_global
.
MoveDstSliceWindow
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
cde_lds_and_global_step
);
}
});
}
}
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run_2Lds
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
DsGridPointer
&
p_ds_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_1
,
const
Problem
&
problem
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
// divide block work by [M, N]
const
auto
block_2_ctile_map
=
Block2CTileMapDefault
{
problem
.
M
,
problem
.
N
,
4
};
Run_2Lds
<
Block2CTileMapDefault
,
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>
(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_c_grid
,
p_shared_0
,
p_shared_1
,
problem
,
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
);
}
template
<
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
TailNumber
TailNum
=
TailNumber
::
Odd
>
__device__
static
void
Run_2Lds
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
DsGridPointer
&
p_ds_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_1
,
const
Problem
&
problem
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0Number
,
MPerBlock
,
AK1Number
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ADataType
,
LDSTypeA
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0Number
,
NPerBlock
,
BK1Number
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BDataType
,
LDSTypeB
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeA
*>
(
p_shared_0
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeB
*>
(
p_shared_0
)
+
a_block_space_size_aligned
*
sizeof
(
LDSTypeA
)
/
sizeof
(
LDSTypeB
),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeA
*>
(
p_shared_1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeB
*>
(
p_shared_1
)
+
a_block_space_size_aligned
*
sizeof
(
LDSTypeA
)
/
sizeof
(
LDSTypeB
),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_bufs
=
make_tuple
(
a_block_buf_ping
,
a_block_buf_pong
);
auto
b_block_bufs
=
make_tuple
(
b_block_buf_ping
,
b_block_buf_pong
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1Number
,
0
,
0
);
// Blockwise GEMM pipeline
static_assert
(
std
::
is_default_constructible_v
<
BlockwiseGemmPipe
>
);
auto
blockwise_gemm_pipeline
=
BlockwiseGemmPipe
{};
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
blockwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
,
TailNum
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_bufs
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_bufs
,
b_block_slice_copy_step
,
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C and write out
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm_pipeline
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm_pipeline
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
CShuffleDataType
*>
(
p_shared_0
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm_pipeline
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
CShuffleDataType
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
using
EDataType
=
CDataType
;
const
auto
ds_grid_desc_m_n
=
MakeDsGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideDs
);
const
auto
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
ds_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
i
],
ds_grid_desc_m_n
[
i
].
GetElementSpaceSize
());
},
Number
<
NumDTensor
>
{});
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_desc_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
generate_tie
([
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_buf_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_buf
),
generate_tie
([
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_buf
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of starting index of C/Ds blockwise copy
const
auto
idx_c_ds_block_begin
=
container_concat
(
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
)),
generate_tuple
(
[
&
](
auto
)
{
return
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
);
},
Number
<
NumDTensor
>
{}));
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
using
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
=
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
Tuple
<
EDataType
>
,
decltype
(
c_ds_desc_refs
),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
CElementwiseOperation
,
Sequence
<
static_cast
<
index_t
>
(
EGlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
// support arbitray type
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename SrcDimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DstDimAccessOrder,
3
,
// index_t SrcVectorDim,
3
,
// index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
sequence_merge_t
<
Sequence
<
true
>
,
uniform_sequence_gen_t
<
NumDTensor
,
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
)),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
// space filling curve for shuffled blockwise C/D/E
constexpr
auto
sfc_cde_block
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global
.
Run
(
c_ds_desc_refs
,
c_ds_buf_refs
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
c_grid_buf
));
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
cde_lds_and_global_step
=
sfc_cde_block
.
GetForwardStep
(
access_id
);
// move on Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
c_ds_desc_refs
,
i
+
I1
,
cde_lds_and_global_step
);
});
// move on E
cde_block_copy_lds_and_global
.
MoveDstSliceWindow
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
cde_lds_and_global_step
);
}
});
}
}
};
}
// namespace ck
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
5e8b5703
...
@@ -117,6 +117,7 @@ using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu
...
@@ -117,6 +117,7 @@ using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu
using
AddMultiply
=
ck
::
tensor_operation
::
element_wise
::
AddMultiply
;
using
AddMultiply
=
ck
::
tensor_operation
::
element_wise
::
AddMultiply
;
using
MultiplyAdd
=
ck
::
tensor_operation
::
element_wise
::
MultiplyAdd
;
using
MultiplyAdd
=
ck
::
tensor_operation
::
element_wise
::
MultiplyAdd
;
using
MultiplyMultiply
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
;
using
MultiplyMultiply
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
;
using
MultiplyMultiplyAdd
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiplyAdd
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_add.hpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
);
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
);
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
);
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
);
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
);
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
);
#endif
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
DsDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitK
<
ALayout
,
BLayout
,
Tuple
<
Row
,
Col
,
Row
>
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiplyAdd
>>
{
using
DeviceOp
=
DeviceGemmMultipleDSplitK
<
ALayout
,
BLayout
,
Tuple
<
Row
,
Col
,
Row
>
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiplyAdd
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
View file @
5e8b5703
...
@@ -45,6 +45,41 @@ using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances = std
...
@@ -45,6 +45,41 @@ using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances = std
// Compute friendly
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
32
,
32
,
2
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
192
,
128
,
16
,
16
,
32
,
32
,
4
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
160
,
128
,
16
,
16
,
32
,
32
,
2
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
128
,
128
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
96
,
128
,
16
,
16
,
32
,
32
,
2
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
64
,
128
,
16
,
16
,
32
,
32
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
256
,
128
,
16
,
16
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
224
,
128
,
16
,
16
,
32
,
32
,
1
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
192
,
128
,
16
,
16
,
32
,
32
,
2
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
160
,
128
,
16
,
16
,
32
,
32
,
1
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
256
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
96
,
256
,
16
,
16
,
32
,
32
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
256
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
224
,
128
,
16
,
16
,
16
,
16
,
2
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
192
,
256
,
16
,
16
,
32
,
32
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
192
,
128
,
16
,
16
,
32
,
32
,
1
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
160
,
256
,
16
,
16
,
16
,
16
,
2
,
5
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
96
,
256
,
16
,
16
,
16
,
16
,
2
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
512
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
224
,
256
,
16
,
16
,
16
,
16
,
1
,
7
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
192
,
256
,
16
,
16
,
16
,
16
,
1
,
6
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
160
,
256
,
16
,
16
,
16
,
16
,
1
,
5
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
96
,
256
,
16
,
16
,
16
,
16
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
64
,
512
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
16
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
192
,
256
,
16
,
16
,
16
,
16
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
128
,
256
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
64
,
512
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
...
@@ -59,6 +94,7 @@ using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances = std
...
@@ -59,6 +94,7 @@ using device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances = std
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
// clang-format oI
// clang-format oI
>
;
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/CMakeLists.txt
0 → 100644
View file @
5e8b5703
# ONLY XDL_KERNELS
set
(
GEMM_MULTIPLY_MULTIPLY_ADD_INSTANCES
)
list
(
APPEND GEMM_MULTIPLY_MULTIPLY_ADD_INSTANCES
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
)
set_source_files_properties
(
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
set_source_files_properties
(
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS
";-mllvm;-greedy-reverse-local-assignment=1"
)
add_instance_library
(
device_gemm_multiply_multiply_add_instance
${
GEMM_MULTIPLY_MULTIPLY_ADD_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
I8
=
int8_t
;
using
I32
=
int
;
using
BF16
=
bhalf_t
;
using
F32
=
float
;
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
index_t
...
Is
>
using
S
=
Sequence
<
Is
...
>
;
using
PassThrough
=
element_wise
::
PassThrough
;
using
MultiplyMultiplyAdd
=
element_wise
::
MultiplyMultiplyAdd
;
static
constexpr
auto
GemmDefault
=
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmKPadding
=
GemmSpecialization
::
KPadding
;
static
constexpr
auto
GemmMNPadding
=
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
Intrawave
=
BlockGemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
Interwave
=
BlockGemmPipelineScheduler
::
Interwave
;
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |S<C, D0, D1,D2>| | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
32
,
32
,
2
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
192
,
128
,
16
,
16
,
32
,
32
,
4
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
160
,
128
,
16
,
16
,
32
,
32
,
2
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
128
,
128
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
96
,
128
,
16
,
16
,
32
,
32
,
2
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
64
,
128
,
16
,
16
,
32
,
32
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
256
,
128
,
16
,
16
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
224
,
128
,
16
,
16
,
32
,
32
,
1
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
192
,
128
,
16
,
16
,
32
,
32
,
2
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
160
,
128
,
16
,
16
,
32
,
32
,
1
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
128
,
256
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
96
,
256
,
16
,
16
,
32
,
32
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
64
,
256
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
224
,
128
,
16
,
16
,
16
,
16
,
2
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
192
,
256
,
16
,
16
,
32
,
32
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
192
,
128
,
16
,
16
,
32
,
32
,
1
,
3
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
160
,
256
,
16
,
16
,
16
,
16
,
2
,
5
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
96
,
256
,
16
,
16
,
16
,
16
,
2
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
64
,
512
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
224
,
256
,
16
,
16
,
16
,
16
,
1
,
7
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
192
,
256
,
16
,
16
,
16
,
16
,
1
,
6
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
160
,
256
,
16
,
16
,
16
,
16
,
1
,
5
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
96
,
256
,
16
,
16
,
16
,
16
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
64
,
512
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
16
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
16
,
192
,
256
,
16
,
16
,
16
,
16
,
1
,
3
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
,
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
16
,
128
,
256
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
16
,
64
,
512
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
,
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
224
,
256
,
128
,
16
,
16
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
// DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col, Row>, Row, I8, I8, Tuple<F32, F32, BF16>, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiplyAdd, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
256
,
64
,
16
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
// clang-format oI
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
,
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
16
,
32
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
// Memory friendly, Col
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
32
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
256
,
16
,
128
,
16
,
16
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
2
,
2
,
1
,
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
128
,
32
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
128
,
16
,
128
,
16
,
16
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
,
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
64
,
32
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
64
,
16
,
128
,
16
,
16
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
,
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
,
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
64
,
16
,
16
,
64
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
16
,
32
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
32
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
,
8
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
,
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
,
8
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
// clang-format oI
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_add/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16/device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
,
Row
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
,
BF16
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_add_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_gemm_multiply_multiply_add_impl.hpp
0 → 100644
View file @
5e8b5703
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
D2DataType
,
typename
EDataType
,
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
D1Layout
,
typename
D2Layout
,
typename
ELayout
>
bool
profile_gemm_multiply_multiply_add_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
bool
time_kernel
,
int
M
,
int
N
,
int
K
,
int
StrideA
,
int
StrideB
,
int
StrideD0
,
int
StrideD1
,
int
StrideD2
,
int
StrideE
,
int
KBatch
,
int
n_warmup
,
int
n_iter
,
uint64_t
rotating
=
0
)
{
bool
pass
=
true
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
is_same
<
decltype
(
layout
),
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D0Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
D2DataType
>
d2_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD2
,
D2Layout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
int
total_gemm_needed
=
a_m_k
.
GetElementSpaceSizeInBytes
()
+
b_k_n
.
GetElementSpaceSizeInBytes
()
+
d0_m_n
.
GetElementSpaceSizeInBytes
()
+
d1_m_n
.
GetElementSpaceSizeInBytes
()
+
d2_m_n
.
GetElementSpaceSizeInBytes
();
int
rotating_count
=
std
::
max
(
1
,
std
::
min
(
n_iter
,
static_cast
<
int
>
(
std
::
ceil
(
static_cast
<
double
>
(
rotating
)
/
total_gemm_needed
))));
// std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
// std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
// std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
// std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
// std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
// std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
// std::cout << "rotating count: " << rotating_count << std::endl;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
1
,
2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
1
,
2
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
5
,
5
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
1
,
1
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D2DataType
>
{
-
1
,
1
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D2DataType
>
{
0.0
,
1.0
});
}
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
MultiplyMultiplyAdd
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiplyAdd
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
MultiplyMultiplyAdd
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d2_device_buf
(
sizeof
(
D2DataType
)
*
d2_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d2_device_buf
.
ToDevice
(
d2_m_n
.
mData
.
data
());
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitK
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
,
D2Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
,
D2DataType
>
,
EDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
// std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if
(
do_verification
)
{
Tensor
<
AccDataType
>
c_m_n
({
M
,
N
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
,
ComputeDataType
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
c_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d0_m_n
(
m
,
n
),
d1_m_n
(
m
,
n
),
d2_m_n
(
m
,
n
));
}
}
}
std
::
string
best_op_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_kbatch
=
0
;
// profile device GEMM instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
// Seems like when performance measurement has bug when spiltK is large
// std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38};
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
16
};
if
(
KBatch
>
0
)
{
kbatch_list
=
{
KBatch
};
}
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
{
auto
kbatch_curr
=
kbatch_list
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
3
>
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
(),
d2_device_buf
.
GetDeviceBuffer
()},
static_cast
<
EDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
std
::
array
<
ck
::
index_t
,
3
>
{
StrideD0
,
StrideD1
,
StrideD2
},
StrideE
,
kbatch_curr
,
a_element_op
,
b_element_op
,
c_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
,
0
,
n_warmup
,
n_iter
});
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
e_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
e_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
hipStream_t
stream
;
hip_check_error
(
hipStreamCreate
(
&
stream
));
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
stream
,
time_kernel
,
0
,
n_warmup
,
n_iter
,
rotating_count
>
1
,
rotating_count
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
// std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
// << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
// << kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
// set softer tolerances for fp8
if
constexpr
((
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
EDataType
,
f8_t
>
)
||
(
is_same_v
<
ADataType
,
int8_t
>
||
is_same_v
<
BDataType
,
int8_t
>
||
is_same_v
<
EDataType
,
int8_t
>
))
{
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-1
;
double
atol
=
1e-1
;
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
msg
,
rtol
,
atol
);
}
else
{
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
}
#endif
if
(
tflops
>
best_tflops
&&
ave_time
>
1e-10
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
}
}
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
}
if
constexpr
(
is_same
<
EDataType
,
float
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = f32"
;
}
else
if
constexpr
(
is_same
<
EDataType
,
half_t
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = f16"
;
}
else
if
constexpr
(
is_same
<
EDataType
,
bhalf_t
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = bf16"
;
}
else
if
constexpr
(
is_same
<
EDataType
,
int8_t
>::
value
)
{
std
::
cout
<<
"Best Perf for datatype = int8"
;
}
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
std
::
cout
<<
" ALayout = RowMajor"
;
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
)
{
std
::
cout
<<
" ALayout = ColumnMajor"
;
}
if
constexpr
(
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
std
::
cout
<<
" BLayout = RowMajor"
;
}
else
if
constexpr
(
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
)
{
std
::
cout
<<
" BLayout = ColumnMajor"
;
}
std
::
cout
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
<<
" StrideB = "
<<
StrideB
<<
" StrideE = "
<<
StrideE
<<
" KBatch = "
<<
best_kbatch
<<
" : "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
}
}
// namespace profiler
}
// namespace ck
profiler/int8_gmm_profiler.sh
0 → 100644
View file @
5e8b5703
EXE
=
"
$(
find
.
-name
ckProfiler
-type
f |
head
-n
1
)
"
op
=
"gemm_multiply_multiply"
loopFunc
()
{
N
=
$1
K
=
$2
# $EXE $op 8 1 0 2 0 1 1 $N $K -1 -1 0 0 -1 1 40 500 4096
# for ((M=32; M<=20480;M*=2))
# do
# # echo "M = $M, N = $N, K = $K"
# $EXE $op 8 1 0 2 0 1 $M $N $K -1 -1 0 0 -1 1 40 500 4096
# done
$EXE
$op
8 1 1 2 0 1 128
$N
$K
-1
-1
0 0
-1
1 20 200 4096
}
N
=
1280
K
=
8192
loopFunc
$N
$K
# N=8192
# K=1024
# loopFunc $N $K
profiler/int8_gmma_mb_profiler.sh
0 → 100644
View file @
5e8b5703
EXE
=
"
$(
find
.
-name
ckProfiler
-type
f |
head
-n
1
)
"
op
=
"gemm_multiply_multiply_add"
loopFunc
()
{
N
=
$1
K
=
$2
$EXE
$op
8 1 0 2 0 1 1
$N
$K
-1
-1
0 0
-1
1 20 50 4096
for
((
M
=
32
;
M<
=
32768
;
M
*
=
2
))
do
# echo "M = $M, N = $N, K = $K"
$EXE
$op
8 1 0 2 0 1
$M
$N
$K
-1
-1
0 0
-1
1 20 50 4096
done
# $EXE $op 8 1 0 2 0 1 $M $N $K -1 -1 0 0 -1 1 20 50 4096
}
N
=
1280
K
=
8192
loopFunc
$N
$K
N
=
8192
K
=
1024
loopFunc
$N
$K
# M=4096
# N=1280
# K=8192
# loopFunc $M $N $K
# M=4096
# N=8192
# K=1024
# loopFunc $M $N $K
profiler/int8_gmma_profiler.sh
0 → 100644
View file @
5e8b5703
EXE
=
"
$(
find
.
-name
ckProfiler
-type
f |
head
-n
1
)
"
op
=
"gemm_multiply_multiply_add"
loopFunc
()
{
M
=
$1
N
=
$2
K
=
$3
# $EXE $op 8 1 0 2 0 1 1 $N $K -1 -1 0 0 -1 1 40 500 4096
# for ((M=32; M<=20480;M*=2))
# do
# # echo "M = $M, N = $N, K = $K"
# $EXE $op 8 1 0 2 0 1 $M $N $K -1 -1 0 0 -1 1 40 500 4096
# done
$EXE
$op
8 1 0 2 0 1
$M
$N
$K
-1
-1
0 0
-1
1 20 50 4096
}
# M=128
# N=1280
# K=8192
# loopFunc $M $N $K
M
=
128
N
=
8192
K
=
1024
loopFunc
$M
$N
$K
# M=4096
# N=1280
# K=8192
# loopFunc $M $N $K
# M=4096
# N=8192
# K=1024
# loopFunc $M $N $K
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment