Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7e8230da
"tests/models/unets/test_unet_2d_blocks.py" did not exist on "eadf0e2555cfa19b033e02de53553f71ac33536f"
Commit
7e8230da
authored
Oct 02, 2023
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
56c72035
bd09b5c5
Changes
185
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1353 additions
and
154 deletions
+1353
-154
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
...pu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+13
-11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
...ensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+1
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+86
-12
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+92
-55
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+65
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
.../reference_tensor_operation/cpu/reference_contraction.hpp
+11
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+4
-3
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+2
-0
library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp
..._instance/gpu/contraction/device_contraction_instance.hpp
+292
-0
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
...ry/tensor_operation_instance/gpu/contraction_bilinear.hpp
+389
-35
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp
...brary/tensor_operation_instance/gpu/contraction_scale.hpp
+387
-33
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp
View file @
7e8230da
...
...
@@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ABDataType
,
ABDataType
,
AccDataType
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ABDataType
,
ABDataType
,
AccDataType
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
7e8230da
...
...
@@ -108,7 +108,8 @@ template <typename ALayout,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeType
=
FloatC
>
typename
ComputeTypeA
=
FloatC
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
*
sizeof
(
ComputeType
)
+
b_block_space_size_aligned
*
sizeof
(
ComputeType
)),
return
math
::
max
((
a_block_space_size_aligned
*
sizeof
(
ComputeType
A
)
+
b_block_space_size_aligned
*
sizeof
(
ComputeType
B
)),
c_block_size
*
sizeof
(
FloatCShuffle
));
}
...
...
@@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatA
,
ComputeType
,
ComputeType
A
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
...
...
@@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatB
,
ComputeType
,
ComputeType
B
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
...
...
@@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeType
A
,
MPerXdl
,
NPerXdl
,
ComputeTypeB
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ComputeType
,
ComputeTypeA
,
ComputeTypeB
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
@@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ComputeType
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
ComputeType
A
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ComputeType
*>
(
p_shared
)
+
a_block_space_size_aligned
,
static_cast
<
ComputeType
B
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
7e8230da
...
...
@@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
View file @
7e8230da
...
...
@@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
TileMathThreadGroupSize
,
ABDataType
,
ABDataType
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
7e8230da
...
...
@@ -737,6 +737,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatABAdjusted
,
FloatABAdjusted
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
View file @
7e8230da
...
...
@@ -490,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
7e8230da
...
...
@@ -424,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatABAdjusted
,
FloatABAdjusted
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
...
...
@@ -569,6 +570,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
FloatABAdjusted
,
FloatABAdjusted
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
decltype
(
b_block_desc_k0_n_k1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
7e8230da
...
...
@@ -762,6 +762,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ComputeType
,
ComputeType
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
View file @
7e8230da
...
...
@@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
7e8230da
...
...
@@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
7e8230da
...
...
@@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_k0_m_k1
),
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
7e8230da
...
...
@@ -31,7 +31,9 @@ enum struct MfmaInstr
mfma_i32_16x16x32i8
,
mfma_f64_16x16x4f64
,
mfma_f32_32x32x16f8f8
,
mfma_f32_16x16x32f8f8
mfma_f32_16x16x32f8f8
,
mfma_f32_32x32x16f8bf8
,
mfma_f32_16x16x32f8bf8
};
template
<
MfmaInstr
instr
>
...
...
@@ -502,10 +504,62 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
};
#endif
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8bf8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_f32_32x32x16f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_16x16x32f8bf8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_f32_16x16x32f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
typename
additional_type
=
base_type
>
struct
MfmaSelector
{
template
<
typename
base_type_
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
>
template
<
typename
base_type_
,
index_t
MPerXdlops_
,
index_t
NPerXdlops_
,
typename
additional_type_
=
base_type_
>
static
constexpr
auto
GetMfma
();
template
<
>
...
...
@@ -656,7 +710,22 @@ struct MfmaSelector
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
__host__
__device__
constexpr
MfmaSelector
()
{
...
...
@@ -703,7 +772,8 @@ template <typename base_type,
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
KPack
,
bool
TransposeC
=
false
>
typename
additional_type
=
base_type
,
bool
TransposeC
=
false
>
struct
XdlopsGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -854,14 +924,18 @@ struct XdlopsGemm
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
#if defined CK_ENABLE_FP8
||
is_same
<
base_type
,
f8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
||
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
#endif
,
"base base_type must be double, float, half, bfloat16,
and int
8_t!"
);
,
"base base_type must be double, float, half, bfloat16,
int8_t, f8_t or bf
8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
if
constexpr
(
!
TransposeC
)
...
...
@@ -957,7 +1031,7 @@ struct XdlopsGemm
return
TransposeC
?
CIndex4D
{
blk_td
,
I0
,
blk_id
,
I0
}
:
CIndex4D
{
I0
,
blk_id
,
I0
,
blk_td
};
}
static
constexpr
auto
mfma
=
MfmaSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
>
{};
static
constexpr
auto
mfma
=
MfmaSelector
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
{};
static
constexpr
auto
mfma_instr
=
mfma
.
selected_mfma
;
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
7e8230da
...
...
@@ -1127,37 +1127,53 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8
}
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#endif
}
...
...
@@ -1216,40 +1232,61 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
if
(
dst_thread_element_valid
)
{
#if defined CK_ENABLE_FP8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
}
#endif
...
...
include/ck/utility/amd_xdlops.hpp
View file @
7e8230da
...
...
@@ -419,5 +419,70 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8bf8
;
template
<
>
struct
intrin_mfma_f32_32x32x16f8bf8
<
32
,
32
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
f8x8_t
&
reg_a
,
const
bf8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8
(
bit_cast
<
long
>
(
reg_a
),
bit_cast
<
long
>
(
reg_b
),
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_b_f32
=
type_convert
<
float
>
(
reg_b_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
intrin_mfma_f32_32x32x2f32
<
32
,
32
>::
Run
(
reg_a_f32
,
reg_b_f32
,
reg_c
);
});
#endif
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_16x16x32f8bf8
;
template
<
>
struct
intrin_mfma_f32_16x16x32f8bf8
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
f8x8_t
&
reg_a
,
const
bf8x8_t
&
reg_b
,
FloatC
&
reg_c
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8
(
bit_cast
<
long
>
(
reg_a
),
bit_cast
<
long
>
(
reg_b
),
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
0
,
0
,
0
);
#else
vector_type
<
f8_t
,
8
>
reg_a_v
(
reg_a
);
vector_type
<
bf8_t
,
8
>
reg_b_v
(
reg_b
);
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
k
)
{
float
reg_a_f32
=
type_convert
<
float
>
(
reg_a_v
.
template
AsType
<
f8_t
>()[
Number
<
k
>
{}]);
float
reg_b_f32
=
type_convert
<
float
>
(
reg_b_v
.
template
AsType
<
bf8_t
>()[
Number
<
k
>
{}]);
intrin_mfma_f32_16x16x4f32
<
16
,
16
>::
Run
(
reg_a_f32
,
reg_b_f32
,
reg_c
);
});
#endif
}
};
#endif
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
View file @
7e8230da
...
...
@@ -23,6 +23,7 @@ template <ck::index_t NumDimM,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ComputeDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
ck
::
enable_if_t
<
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
,
bool
>
=
false
>
...
...
@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
for
(
ck
::
index_t
k1
=
0
;
k1
<
K1
;
++
k1
)
{
// Simulate the possible casting when ComputeDataType is different than the
// A/B data types
ComputeDataType
v_a_compute_input
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
));
ComputeDataType
v_b_compute_input
=
ck
::
type_convert
<
ComputeDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
k0
,
k1
));
AccDataType
v_a
;
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
)));
arg
.
b_element_op_
(
v_b
,
ck
::
type_convert
<
const
AccDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
k0
,
k1
)));
arg
.
a_element_op_
(
v_a
,
ck
::
type_convert
<
AccDataType
>
(
v_a_compute_input
));
arg
.
b_element_op_
(
v_b
,
ck
::
type_convert
<
AccDataType
>
(
v_b_compute_input
));
v_acc
+=
v_a
*
v_b
;
}
}
arg
.
c_ms_ns_
(
m0
,
m1
,
n0
,
n1
)
=
v_acc
;
arg
.
c_ms_ns_
(
m0
,
m1
,
n0
,
n1
)
=
ck
::
type_convert
<
CDataType
>
(
v_acc
)
;
};
make_ParallelTensorFunctor
(
f_ms_ns
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
7e8230da
...
...
@@ -21,7 +21,8 @@ template <typename ADataType,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputType
=
ADataType
>
typename
ComputeTypeA
=
ADataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceGemm
:
public
device
::
BaseOperator
{
// Argument
...
...
@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
ComputType
v_a
;
ComputType
v_b
;
Comput
e
Type
A
v_a
;
Comput
e
Type
B
v_b
;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
7e8230da
...
...
@@ -29,6 +29,8 @@ using BF8 = ck::bf8_t;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
BF16_Tuple
=
ck
::
Tuple
<
BF16
>
;
using
F16_Tuple
=
ck
::
Tuple
<
F16
>
;
using
F16_F16_Tuple
=
ck
::
Tuple
<
F16
,
F16
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp
0 → 100644
View file @
7e8230da
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
F64
=
double
;
using
F16_Tuple
=
ck
::
Tuple
<
F16
>
;
using
BF16_Tuple
=
ck
::
Tuple
<
BF16
>
;
using
F32_Tuple
=
ck
::
Tuple
<
F32
>
;
using
F64_Tuple
=
ck
::
Tuple
<
F64
>
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Bilinear
=
ck
::
tensor_operation
::
element_wise
::
Bilinear
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_kk_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
4
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_kn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
4
,
1
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
4
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
1
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
4
,
1
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
4
,
1
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
1
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
1
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_mk_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
1
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
1
,
4
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
1
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
1
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
1
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
1
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
4
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
4
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_mn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
1
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
1
,
1
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
16
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
1
,
1
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
1
,
1
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
1
,
1
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
1
,
1
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
1
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
1
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_f64_kk_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
16
,
2
,
2
,
16
,
16
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
16
,
2
,
2
,
16
,
16
,
2
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
2
,
2
,
16
,
16
,
4
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
16
,
2
,
2
,
16
,
16
,
2
,
4
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_f64_kn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
2
,
1
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
2
,
1
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
2
,
1
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
2
,
1
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
2
,
1
,
16
,
16
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_f64_mk_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
1
,
2
,
16
,
16
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
1
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
1
,
2
,
16
,
16
,
4
,
4
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
2
,
16
,
16
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
2
,
16
,
16
,
2
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
// clang-format on
>
;
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
ComputeDataType
,
typename
AElementwiseOp
,
typename
BElementwiseOp
,
typename
CDEElementwiseOp
>
using
device_contraction_f64_mn_instance
=
std
::
tuple
<
// clang-format off
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
1
,
1
,
16
,
16
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
1
,
1
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
1
,
1
,
16
,
16
,
4
,
4
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
1
,
1
,
16
,
16
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
16
,
2
,
2
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
1
,
1
,
16
,
16
,
2
,
4
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
,
DeviceContractionMultipleD_Xdl_CShuffle
<
2
,
2
,
2
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
ComputeDataType
,
AElementwiseOp
,
BElementwiseOp
,
CDEElementwiseOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
16
,
2
,
2
,
16
,
16
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
1
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
View file @
7e8230da
...
...
@@ -17,7 +17,6 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
#ifdef CK_ENABLE_FP32
// float
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
...
...
@@ -28,7 +27,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -40,7 +40,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -52,7 +53,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -64,10 +66,115 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
F32
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
#endif
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
F16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
F16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
F16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
F16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
BF16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
BF16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
BF16
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
F32_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Bilinear
,
BF16
>>>&
instances
);
#endif // CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64
// double
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
...
...
@@ -78,7 +185,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn
F64
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
,
F64
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -90,7 +198,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn
F64
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
,
F64
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -102,7 +211,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn
F64
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
,
F64
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -114,8 +224,170 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
F64
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
#endif
Bilinear
,
F64
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
F64_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
F64_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
F64_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
F64_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
BF16_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
BF16_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
BF16_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
BF16_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Bilinear
,
F32
>>>&
instances
);
#endif // CK_ENABLE_FP16
// Contraction + Bilinear
template
<
index_t
NumDimM
,
index_t
NumDimN
,
...
...
@@ -123,7 +395,8 @@ template <index_t NumDimM,
typename
ADataType
,
typename
BDataType
,
typename
DDataType
,
typename
EDataType
>
typename
EDataType
,
typename
ComputeDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
...
...
@@ -134,7 +407,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>>
ck
::
tensor_operation
::
element_wise
::
Bilinear
,
ComputeDataType
>>
{
using
DeviceOp
=
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
...
...
@@ -145,45 +419,125 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Bilinear
>
;
ck
::
tensor_operation
::
element_wise
::
Bilinear
,
ComputeDataType
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP32
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
DDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
is_same_v
<
EDataType
,
float
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance
(
op_ptrs
);
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ComputeDataType
,
ck
::
half_t
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ComputeDataType
,
ck
::
bhalf_t
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance
(
op_ptrs
);
}
}
}
#endif
#endif
// CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
is_same_v
<
DDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
is_same_v
<
EDataType
,
double
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
if
constexpr
(
is_same_v
<
ComputeDataType
,
double
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance
(
op_ptrs
);
}
}
}
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
half_t
>
&&
is_same_v
<
BDataType
,
ck
::
half_t
>
&&
is_same_v
<
EDataType
,
ck
::
half_t
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance
(
op_ptrs
);
}
}
}
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
BDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
EDataType
,
ck
::
bhalf_t
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance
(
op_ptrs
);
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance
(
op_ptrs
);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance
(
op_ptrs
);
}
}
}
#endif
#endif
// CK_ENABLE_BF16
return
op_ptrs
;
}
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp
View file @
7e8230da
...
...
@@ -17,7 +17,6 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
#ifdef CK_ENABLE_FP32
// float
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
...
...
@@ -28,7 +27,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
F32
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -40,7 +40,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
F32
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -52,7 +53,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
F32
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -64,10 +66,115 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
F32
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
#endif
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
F16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
F16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
F16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
F16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
BF16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
BF16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
BF16
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
Scale
,
BF16
>>>&
instances
);
#endif // CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64
// double
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
...
...
@@ -78,7 +185,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instanc
F64
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
,
F64
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -90,7 +198,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instanc
F64
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
,
F64
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -102,7 +211,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instanc
F64
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
,
F64
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
...
@@ -114,15 +224,178 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
F64
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
#endif
Scale
,
F64
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
Empty_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
Empty_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
Empty_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F64
,
F64
,
Empty_Tuple
,
F64
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
2
,
2
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
Scale
,
F32
>>>&
instances
);
#endif // CK_ENABLE_FP16
// Contraction + Scale
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
>
typename
EDataType
,
typename
ComputeDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
...
...
@@ -133,7 +406,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Scale
>>
ck
::
tensor_operation
::
element_wise
::
Scale
,
ComputeDataType
>>
{
using
DeviceOp
=
DeviceContractionMultipleD
<
NumDimM
,
NumDimN
,
...
...
@@ -144,7 +418,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
Scale
>
;
ck
::
tensor_operation
::
element_wise
::
Scale
,
ComputeDataType
>
;
static
auto
GetInstances
()
{
...
...
@@ -155,34 +430,113 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance
(
op_ptrs
);
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ComputeDataType
,
ck
::
half_t
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ComputeDataType
,
ck
::
bhalf_t
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance
(
op_ptrs
);
}
}
}
#endif
#endif
// CK_ENABLE_FP32
#ifdef CK_ENABLE_FP64
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance
(
op_ptrs
);
if
constexpr
(
is_same_v
<
ComputeDataType
,
double
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance
(
op_ptrs
);
}
}
}
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
half_t
>
&&
is_same_v
<
BDataType
,
ck
::
half_t
>
&&
is_same_v
<
EDataType
,
ck
::
half_t
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance
(
op_ptrs
);
}
}
}
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
BDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
EDataType
,
ck
::
bhalf_t
>
)
{
if
constexpr
(
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
)
{
if
constexpr
(
is_same_v
<
ComputeDataType
,
float
>
)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance
(
op_ptrs
);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance
(
op_ptrs
);
}
}
}
#endif
#endif
// CK_ENABLE_BF16
return
op_ptrs
;
}
};
...
...
Prev
1
2
3
4
5
6
7
8
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment