Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f64b1375
Commit
f64b1375
authored
Feb 17, 2025
by
coderfeli
Browse files
merge haocong branch
parents
88412f9e
f18cfec4
Changes
124
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
75 additions
and
1082 deletions
+75
-1082
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+0
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+8
-36
library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc
...lude/ck/library/tensor_operation_instance/gpu/gemm_dl.inc
+30
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
.../tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
+1
-109
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
.../library/tensor_operation_instance/gpu/gemm_universal.hpp
+0
-33
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp
.../tensor_operation_instance/gpu/gemm_universal_streamk.hpp
+0
-500
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
+9
-69
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
...ed_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
+4
-12
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
...wd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
+2
-8
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+1
-17
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
..._instance/gpu/grouped_convolution_backward_weight_xdl.inc
+1
-97
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...or_operation_instance/gpu/grouped_convolution_forward.hpp
+1
-17
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
...ion_instance/gpu/grouped_convolution_forward_comp_xdl.inc
+0
-16
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
+0
-16
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
+0
-16
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
...peration_instance/gpu/grouped_convolution_forward_xdl.inc
+0
-16
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
...nce/gpu/grouped_convolution_forward_xdl_merged_groups.inc
+0
-14
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
...y/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
+0
-47
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+17
-54
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
...ice_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
+1
-4
No files found.
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
f64b1375
...
...
@@ -22,7 +22,6 @@ using I8 = int8_t;
using
I32
=
int32_t
;
using
F8
=
ck
::
f8_t
;
using
BF8
=
ck
::
bf8_t
;
using
I4
=
ck
::
pk_i4_t
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
f64b1375
...
...
@@ -15,9 +15,6 @@
#ifdef DL_KERNELS
#include "gemm_dl.inc"
#endif
#ifdef DPP_KERNELS
#include "gemm_dpp.inc"
#endif
#ifdef CK_USE_WMMA
#include "gemm_wmma.inc"
#endif
...
...
@@ -95,24 +92,32 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances
(
op_ptrs
);
}
}
#endif
...
...
@@ -148,39 +153,6 @@ struct DeviceOperationInstanceFactory<
#endif
#endif // DL_KERNELS
#ifdef DPP_KERNELS
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances
(
op_ptrs
);
}
}
#endif
#endif // DPP_KERNELS
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc
View file @
f64b1375
...
...
@@ -28,6 +28,16 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -38,6 +48,16 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -48,6 +68,16 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
View file @
f64b1375
...
...
@@ -18,7 +18,7 @@ namespace device {
namespace
instance
{
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances
_part1
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
...
...
@@ -174,86 +174,6 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
PassThrough
,
MultiplyMultiply
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP16
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
F16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
F16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
F16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
F16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
F16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
F16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_FP16
void
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1
(
...
...
@@ -573,34 +493,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
<<<<<<<
HEAD
=======
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
}
}
#endif
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
>>>>>>>
origin
/
develop
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
View file @
f64b1375
...
...
@@ -166,22 +166,11 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
I4
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
I4
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
...
@@ -821,28 +810,6 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
pk_i4_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
}
}
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
pk_i4_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp
View file @
f64b1375
...
...
@@ -238,403 +238,6 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm_Streamk_V2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#if(defined(CK_ENABLE_FP8))
void
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
...
...
@@ -924,109 +527,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances
(
op_ptrs
);
}
}
#endif
#if(defined(CK_ENABLE_FP8))
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
f8_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
View file @
f64b1375
// SPDX-License-Identifier: MIT
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
...
@@ -41,13 +41,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
>
;
...
...
@@ -60,13 +58,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
...
...
@@ -79,28 +75,6 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_irregular_instances
=
std
::
tuple
<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
48
,
64
,
32
,
8
,
16
,
16
,
3
,
4
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
3
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
48
,
32
,
8
,
16
,
16
,
4
,
3
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
3
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
80
,
32
,
8
,
16
,
16
,
4
,
5
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
5
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
112
,
32
,
8
,
16
,
16
,
4
,
7
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
7
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
208
,
32
,
8
,
16
,
16
,
4
,
13
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
13
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
...
...
@@ -110,13 +84,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
>
;
...
...
@@ -129,13 +101,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
>
,
...
...
@@ -148,28 +118,6 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_irregular_instances
=
std
::
tuple
<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
48
,
64
,
32
,
8
,
16
,
16
,
3
,
4
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
3
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
48
,
32
,
8
,
16
,
16
,
4
,
3
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
3
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
80
,
32
,
8
,
16
,
16
,
4
,
5
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
5
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
112
,
32
,
8
,
16
,
16
,
4
,
7
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
7
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
208
,
32
,
8
,
16
,
16
,
4
,
13
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
2
,
0
,
1
>
,
1
,
13
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
>
// clang-format on
>
;
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
...
...
@@ -179,13 +127,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
F16
,
F16
,
1
,
1
>
// clang-format on
>
;
...
...
@@ -199,13 +145,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
F16
,
F16
,
1
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
F16
,
F16
,
2
,
2
>
,
...
...
@@ -241,13 +185,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
BF16
,
BF16
,
1
,
1
>
// clang-format on
>
;
...
...
@@ -260,13 +202,11 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
#if defined(CK_USE_AMD_MFMA_GFX950)
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
BF16
,
BF16
,
1
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
BF16
,
BF16
,
2
,
2
>
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
View file @
f64b1375
...
...
@@ -56,13 +56,11 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_bf16_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Compute friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
...
...
@@ -81,7 +79,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
64
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
64
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
64
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
@@ -92,13 +90,11 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_f16_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
...
...
@@ -113,7 +109,6 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
@@ -143,13 +138,11 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_int8_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
...
...
@@ -160,7 +153,6 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
View file @
f64b1375
...
...
@@ -40,18 +40,15 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
DsLayout
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
BF16
,
BF16
,
LoopScheduler
::
Default
,
32
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
@@ -62,18 +59,15 @@ template <index_t NDimSpatial,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx950__)
#else
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
DsLayout
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
16
,
16
,
4
,
4
,
16
,
16
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
F16
,
F16
,
LoopScheduler
::
Default
,
32
>
#endif // defined(__gfx950__)
// clang-format on
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
f64b1375
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -358,10 +358,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
...
...
@@ -387,10 +383,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances
(
op_ptrs
);
}
#endif
}
...
...
@@ -486,10 +478,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
...
...
@@ -515,10 +503,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances
(
op_ptrs
);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
View file @
f64b1375
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -149,30 +149,6 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
...
...
@@ -258,30 +234,6 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
...
...
@@ -432,30 +384,6 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
...
...
@@ -541,30 +469,6 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @
f64b1375
...
...
@@ -304,23 +304,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
AComputeType
,
ck
::
bhalf_t
>
&&
is_same_v
<
BComputeType
,
ck
::
bhalf_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances
(
op_ptrs
);
}
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
&&
is_same_v
<
AComputeType
,
int8_t
>
&&
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
View file @
f64b1375
...
...
@@ -90,22 +90,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
View file @
f64b1375
...
...
@@ -90,22 +90,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_inter_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
View file @
f64b1375
...
...
@@ -90,22 +90,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
View file @
f64b1375
...
...
@@ -204,22 +204,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
View file @
f64b1375
...
...
@@ -23,20 +23,6 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP16
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp
View file @
f64b1375
...
...
@@ -126,35 +126,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
PassThrough
>>>&
instances
);
#endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
void
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Row
,
Empty_Tuple
,
Row
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Col
,
Empty_Tuple
,
Row
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif // CK_ENABLE_BF16
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
...
...
@@ -256,24 +227,6 @@ struct DeviceOperationInstanceFactory<
}
#endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
EDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances
(
op_ptrs
);
}
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances
(
op_ptrs
);
}
}
#endif // CK_ENABLE_BF16
return
op_ptrs
;
}
};
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
100755 → 100644
View file @
f64b1375
...
...
@@ -39,13 +39,6 @@ function(add_instance_library INSTANCE_NAME)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
# Do not build DPP instances if DPP_KERNELS macro is not set
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DPP_KERNELS AND source MATCHES
"_dpp"
)
message
(
"removing dpp instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
# Do not build DL instances if DL_KERNELS macro is not set
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
...
@@ -69,7 +62,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach
()
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND NOT INST_TARGETS MATCHES
"gfx90a"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"mha"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND NOT INST_TARGETS MATCHES
"gfx90a"
AND source MATCHES
"mha"
)
message
(
"removing mha instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -77,25 +70,25 @@ function(add_instance_library INSTANCE_NAME)
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
message
(
"removing gemm_multiply_multiply_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
message
(
"removing gemm_universal_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"batched_gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"batched_gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
message
(
"removing batched_gemm_universal_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx95"
AND
source MATCHES
"gemm_xdl_universal_streamk"
AND source MATCHES
"_f8_"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_xdl_universal_streamk"
AND source MATCHES
"_f8_"
)
message
(
"removing gemm_universal_streamk_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -109,7 +102,7 @@ function(add_instance_library INSTANCE_NAME)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic
)
elseif
(
source MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
gfx950
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
source MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic
)
endif
()
...
...
@@ -190,10 +183,6 @@ FOREACH(subdir_path ${dir_list})
message
(
"bf8 instance found!"
)
set
(
add_inst 1
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"_bf16"
OR
"
${
cmake_instance
}
"
MATCHES
"_b16"
)
AND DTYPES MATCHES
"bf16"
)
message
(
"bf16 instance found!"
)
set
(
add_inst 1
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"_fp16"
OR
"
${
cmake_instance
}
"
MATCHES
"_f16"
)
AND DTYPES MATCHES
"fp16"
)
message
(
"fp16 instance found!"
)
set
(
add_inst 1
)
...
...
@@ -206,6 +195,10 @@ FOREACH(subdir_path ${dir_list})
message
(
"fp64 instance found!"
)
set
(
add_inst 1
)
endif
()
if
(
"
${
cmake_instance
}
"
MATCHES
"_bf16"
AND DTYPES MATCHES
"bf16"
)
message
(
"bf16 instance found!"
)
set
(
add_inst 1
)
endif
()
if
((
"
${
cmake_instance
}
"
MATCHES
"_int8"
OR
"
${
cmake_instance
}
"
MATCHES
"_i8"
)
AND DTYPES MATCHES
"int8"
)
message
(
"int8 instance found!"
)
set
(
add_inst 1
)
...
...
@@ -285,14 +278,9 @@ ENDFOREACH()
if
(
CK_DEVICE_OTHER_INSTANCES
)
add_library
(
device_other_operations
${
CK_DEVICE_OTHER_INSTANCES
}
)
add_library
(
device_other_operations
STATIC
${
CK_DEVICE_OTHER_INSTANCES
}
)
add_library
(
composablekernels::device_other_operations ALIAS device_other_operations
)
set_target_properties
(
device_other_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
set_target_properties
(
device_other_operations
PROPERTIES
VERSION
${
CMAKE_PROJECT_VERSION
}
SOVERSION
${
CMAKE_PROJECT_VERSION_MAJOR
}
)
target_include_directories
(
device_other_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/utility>
...
...
@@ -321,15 +309,10 @@ if(CK_DEVICE_OTHER_INSTANCES)
)
endif
()
if
(
CK_DEVICE_GEMM_INSTANCES
)
add_library
(
device_gemm_operations
${
CK_DEVICE_GEMM_INSTANCES
}
)
add_library
(
device_gemm_operations
STATIC
${
CK_DEVICE_GEMM_INSTANCES
}
)
add_library
(
composablekernels::device_gemm_operations ALIAS device_gemm_operations
)
target_compile_features
(
device_gemm_operations PUBLIC
)
set_target_properties
(
device_gemm_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
set_target_properties
(
device_gemm_operations
PROPERTIES
VERSION
${
CMAKE_PROJECT_VERSION
}
SOVERSION
${
CMAKE_PROJECT_VERSION_MAJOR
}
)
target_include_directories
(
device_gemm_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu>
)
...
...
@@ -342,15 +325,10 @@ if(CK_DEVICE_GEMM_INSTANCES)
)
endif
()
if
(
CK_DEVICE_CONV_INSTANCES
)
add_library
(
device_conv_operations
${
CK_DEVICE_CONV_INSTANCES
}
)
add_library
(
device_conv_operations
STATIC
${
CK_DEVICE_CONV_INSTANCES
}
)
add_library
(
composablekernels::device_conv_operations ALIAS device_conv_operations
)
target_compile_features
(
device_conv_operations PUBLIC
)
set_target_properties
(
device_conv_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
set_target_properties
(
device_conv_operations
PROPERTIES
VERSION
${
CMAKE_PROJECT_VERSION
}
SOVERSION
${
CMAKE_PROJECT_VERSION_MAJOR
}
)
target_include_directories
(
device_conv_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange>
...
...
@@ -368,13 +346,8 @@ if(CK_DEVICE_CONV_INSTANCES)
endif
()
if
(
CK_DEVICE_MHA_INSTANCES
)
set
(
gpu_list
${
INST_TARGETS
}
)
if
(
gpu_list MATCHES
"gfx94"
OR gpu_list MATCHES
"gfx90a"
OR gpu_list MATCHES
"gfx95"
)
add_library
(
device_mha_operations
${
CK_DEVICE_MHA_INSTANCES
}
)
set_target_properties
(
device_mha_operations
PROPERTIES
VERSION
${
CMAKE_PROJECT_VERSION
}
SOVERSION
${
CMAKE_PROJECT_VERSION_MAJOR
}
)
if
(
gpu_list MATCHES
"gfx94"
OR gpu_list MATCHES
"gfx90a"
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
target_compile_features
(
device_mha_operations PUBLIC
)
set_target_properties
(
device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
...
...
@@ -389,15 +362,10 @@ if(CK_DEVICE_MHA_INSTANCES)
endif
()
endif
()
if
(
CK_DEVICE_CONTRACTION_INSTANCES
)
add_library
(
device_contraction_operations
${
CK_DEVICE_CONTRACTION_INSTANCES
}
)
add_library
(
device_contraction_operations
STATIC
${
CK_DEVICE_CONTRACTION_INSTANCES
}
)
add_library
(
composablekernels::device_contraction_operations ALIAS device_contraction_operations
)
target_compile_features
(
device_contraction_operations PUBLIC
)
set_target_properties
(
device_contraction_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
set_target_properties
(
device_contraction_operations
PROPERTIES
VERSION
${
CMAKE_PROJECT_VERSION
}
SOVERSION
${
CMAKE_PROJECT_VERSION_MAJOR
}
)
target_include_directories
(
device_contraction_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu/contraction>
...
...
@@ -411,15 +379,10 @@ if(CK_DEVICE_CONTRACTION_INSTANCES)
)
endif
()
if
(
CK_DEVICE_REDUCTION_INSTANCES
)
add_library
(
device_reduction_operations
${
CK_DEVICE_REDUCTION_INSTANCES
}
)
add_library
(
device_reduction_operations
STATIC
${
CK_DEVICE_REDUCTION_INSTANCES
}
)
add_library
(
composablekernels::device_reduction_operations ALIAS device_reduction_operations
)
target_compile_features
(
device_reduction_operations PUBLIC
)
set_target_properties
(
device_reduction_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
set_target_properties
(
device_reduction_operations
PROPERTIES
VERSION
${
CMAKE_PROJECT_VERSION
}
SOVERSION
${
CMAKE_PROJECT_VERSION_MAJOR
}
)
target_include_directories
(
device_reduction_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/library/tensor_operation_instance/gpu/reduce>
)
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
View file @
f64b1375
...
...
@@ -27,15 +27,12 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
#if defined(CK_USE_AMD_MFMA_GFX950)
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
#endif // defined(CK_USE_AMD_MFMA_GFX950)
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
DeviceBatchedGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
,
1
,
LoopScheduler
::
Default
,
PipelineVersion
::
v1
>
,
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment