Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
667cd6ab
Commit
667cd6ab
authored
Nov 05, 2024
by
illsilin
Browse files
merge from public repo
parents
7d50244e
365f39ae
Changes
121
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
563 additions
and
26 deletions
+563
-26
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
...v_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
+12
-8
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
...fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
+19
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp
...ped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp
+37
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
...wd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
+19
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
...or_operation_instance/gpu/grouped_convolution_forward.hpp
+43
-2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
...ion_instance/gpu/grouped_convolution_forward_comp_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
...peration_instance/gpu/grouped_convolution_forward_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
...ance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
+16
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
...nce/gpu/grouped_convolution_forward_xdl_merged_groups.inc
+30
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+30
-6
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
+6
-4
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
...f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
+3
-2
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
...f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
+6
-4
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
..._operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
+11
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
...d_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
+39
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
...d_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
+64
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
...rouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
+38
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp
...rouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp
+62
-0
No files found.
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp
View file @
667cd6ab
...
@@ -53,8 +53,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances = std::tuple<
...
@@ -53,8 +53,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
#if 0 // Enable with dynamic op optimizations (at now generating a lot of virtual functions cause long compilation time)
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
...
@@ -68,6 +68,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances = std::tuple<
...
@@ -68,6 +68,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
#endif
// clang-format on
// clang-format on
>
;
>
;
...
@@ -87,8 +88,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f16_instances = std::tuple<
...
@@ -87,8 +88,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f16_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
#if 0 // Enable with dynamic op optimizations (at now generating a lot of virtual functions cause long compilation time)
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
...
@@ -102,6 +103,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f16_instances = std::tuple<
...
@@ -102,6 +103,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f16_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, Tuple<>, F16, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
#endif
// clang-format on
// clang-format on
>
;
>
;
...
@@ -121,8 +123,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f32_instances = std::tuple<
...
@@ -121,8 +123,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f32_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
// instances for small conv.K and conv.C
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
16
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
16
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
#if 0 // Enable with dynamic op optimizations (at now generating a lot of virtual functions cause long compilation time)
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
...
@@ -136,6 +138,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f32_instances = std::tuple<
...
@@ -136,6 +138,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_f32_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>
#endif
// clang-format on
// clang-format on
>
;
>
;
...
@@ -155,8 +158,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_int8_instances = std::tuple<
...
@@ -155,8 +158,8 @@ using device_grouped_conv_fwd_xdl_dynamic_op_int8_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
#if 0 // Enable with dynamic op optimizations (at now generating a lot of virtual functions cause long compilation time)
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
...
@@ -170,6 +173,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_int8_instances = std::tuple<
...
@@ -170,6 +173,7 @@ using device_grouped_conv_fwd_xdl_dynamic_op_int8_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, Tuple<>, int8_t, PassThrough, PassThrough, DynamicUnaryOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
#endif
// clang-format on
// clang-format on
>
;
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp
View file @
667cd6ab
...
@@ -87,6 +87,25 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
...
@@ -87,6 +87,25 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
DsLayout
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
16
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
>
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_large_tensor_int8_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp
View file @
667cd6ab
...
@@ -154,6 +154,43 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple<
...
@@ -154,6 +154,43 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
>
using
device_grouped_conv_fwd_xdl_int8_mem_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
>
,
// Memory friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
256
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
128
,
32
,
64
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
128
,
16
,
64
,
8
,
8
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
64
,
32
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
64
,
16
,
64
,
8
,
8
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
128
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
4
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
64
,
16
,
16
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
32
,
64
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
64
,
64
,
8
,
8
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
64
,
64
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
16
,
128
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
128
,
32
,
128
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
16
,
256
,
64
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
256
,
32
,
256
,
64
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
View file @
667cd6ab
...
@@ -90,6 +90,25 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
...
@@ -90,6 +90,25 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
ConvolutionForwardSpecialization
ConvSpec
>
using
device_grouped_conv_fwd_xdl_merged_groups_int8_instances
=
std
::
tuple
<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
int8_t
,
int8_t
,
LoopScheduler
::
Default
,
8
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
int8_t
,
int8_t
,
LoopScheduler
::
Default
,
16
>
,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
int8_t
,
int8_t
,
int32_t
,
int8_t
,
DsLayout
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
,
int8_t
,
int8_t
,
LoopScheduler
::
Default
,
32
>
// clang-format on
>
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp
View file @
667cd6ab
...
@@ -122,6 +122,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -122,6 +122,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif // DL_KERNELS
#endif // DL_KERNELS
#ifdef CK_USE_XDL
#ifdef CK_USE_XDL
// 1D
// layout GNWC/GKXC/GNWK
if
constexpr
(
NumDimSpatial
==
1
&&
is_same_v
<
InLayout
,
GNWC
>
&&
if
constexpr
(
NumDimSpatial
==
1
&&
is_same_v
<
InLayout
,
GNWC
>
&&
is_same_v
<
WeiLayout
,
GKXC
>
&&
is_same_v
<
OutLayout
,
GNWK
>
)
is_same_v
<
WeiLayout
,
GKXC
>
&&
is_same_v
<
OutLayout
,
GNWK
>
)
{
{
...
@@ -160,7 +162,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -160,7 +162,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
}
}
// 2D
// layout GNHWC/GKYXC/GNHWK
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
{
...
@@ -191,7 +194,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -191,7 +194,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
#endif
#endif
}
}
// layout NHWGC/GKYXC/NHWGK
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NHWGC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
)
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NHWGK
>
)
{
{
...
@@ -247,8 +250,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -247,8 +250,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances
(
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances
(
op_ptrs
);
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
&&
is_same_v
<
AComputeType
,
int8_t
>
&&
is_same_v
<
BComputeType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
(
op_ptrs
);
}
#endif
#endif
}
}
// layout NGCHW/GKYXC/NGKHW
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NGCHW
>
&&
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
NGCHW
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NGKHW
>
)
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NGKHW
>
)
{
{
...
@@ -282,8 +304,26 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -282,8 +304,26 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
op_ptrs
);
}
}
#endif
#endif
#ifdef CK_ENABLE_INT8
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
&&
is_same_v
<
AComputeType
,
int8_t
>
&&
is_same_v
<
BComputeType
,
int8_t
>
)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances
(
op_ptrs
);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances
(
op_ptrs
);
}
#endif
}
}
// 3D
// layout GNDHWC/GKZYXC/GNDHWK
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
GNDHWC
>
&&
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
GNDHWK
>
)
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
GNDHWK
>
)
{
{
...
@@ -323,6 +363,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -323,6 +363,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
if
constexpr
(
NumDimSpatial
==
3
&&
is_same_v
<
InLayout
,
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NDHWGK
>
)
{
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc
View file @
667cd6ab
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances
(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
(
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc
View file @
667cd6ab
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances
(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
(
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
View file @
667cd6ab
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
(
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
(
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
View file @
667cd6ab
...
@@ -171,6 +171,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
...
@@ -171,6 +171,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances
(
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances
(
...
@@ -204,6 +220,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
...
@@ -204,6 +220,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
View file @
667cd6ab
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
View file @
667cd6ab
...
@@ -85,6 +85,36 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
...
@@ -85,6 +85,36 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
667cd6ab
...
@@ -67,17 +67,41 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -67,17 +67,41 @@ function(add_instance_library INSTANCE_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
message
(
"removing gemm_multiply_multiply_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
message
(
"removing gemm_universal_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
endif
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN
)
set
(
INST_OBJ
)
set
(
INST_OBJ
)
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
if
(
source MATCHES
"_xdl"
)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
elseif
(
ARGN MATCHES
"_wmma"
)
elseif
(
source MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"mha"
)
elseif
(
source MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
if
(
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
if
(
source MATCHES
"gemm_multiply_multiply_f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
endif
()
endif
()
set
(
offload_targets
)
set
(
offload_targets
)
foreach
(
target IN LISTS INST_TARGETS
)
foreach
(
target IN LISTS INST_TARGETS
)
...
@@ -108,7 +132,7 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -108,7 +132,7 @@ function(add_instance_library INSTANCE_NAME)
# flags to compress the library
# flags to compress the library
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 600241132
)
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 600241132
)
message
(
"Adding --offload-compress flag for
${
INSTANCE_NAME
}
"
)
#
message("Adding --offload-compress flag for ${INSTANCE_NAME}")
target_compile_options
(
${
INSTANCE_NAME
}
PRIVATE --offload-compress
)
target_compile_options
(
${
INSTANCE_NAME
}
PRIVATE --offload-compress
)
endif
()
endif
()
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
View file @
667cd6ab
...
@@ -36,12 +36,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
...
@@ -36,12 +36,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template
<
GemmSpecialization
GemmSpec
>
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Compute friendly
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
@@ -58,17 +58,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
...
@@ -58,17 +58,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
@@ -90,6 +91,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std:
...
@@ -90,6 +91,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std:
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
View file @
667cd6ab
...
@@ -62,12 +62,12 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
...
@@ -62,12 +62,12 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
@@ -90,6 +90,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple<
...
@@ -90,6 +90,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
4
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
4
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
View file @
667cd6ab
...
@@ -35,12 +35,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
...
@@ -35,12 +35,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template
<
GemmSpecialization
GemmSpec
>
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Compute friendly
// Compute friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
@@ -57,17 +57,18 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
...
@@ -57,17 +57,18 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
// DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
// DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
@@ -97,6 +98,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
...
@@ -97,6 +98,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
View file @
667cd6ab
...
@@ -9,45 +9,56 @@ add_instance_library(device_grouped_conv2d_fwd_instance
...
@@ -9,45 +9,56 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# NGCHW, GKYXC, NGKHW
# NGCHW, GKYXC, NGKHW
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
# large tensor
# large tensor
# NHWGC, GKYXC, NHWGK
# NHWGC, GKYXC, NHWGK
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# merged groups
# merged groups
# NHWGC, GKYXC, NHWGK
# NHWGC, GKYXC, NHWGK
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# NGCHW, GKYXC, NGKHW
# NGCHW, GKYXC, NGKHW
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp
#mem
#mem
# NHWGC, GKYXC, NHWGK
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp
# NHWGC, GKYXC, NHWGK
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp
# NGCHW, GKYXC, NGKHW
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp
# NGCHW, GKYXC, NGKHW
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instance.cpp
#comp
#comp
# NHWGC, GKYXC, NHWGK
# NHWGC, GKYXC, NHWGK
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
# NGCHW, GKYXC, NGKHW
# NGCHW, GKYXC, NGKHW
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
#dl
#dl
# GNHWC, GKYXC, GNHWK
# GNHWC, GKYXC, GNHWK
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
0 → 100644
View file @
667cd6ab
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_comp_instances
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
ConvFwdDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
0 → 100644
View file @
667cd6ab
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_comp_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_comp_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_comp_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd1x1S1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_comp_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdOddC
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
0 → 100644
View file @
667cd6ab
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_instances
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
ConvFwdDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp
0 → 100644
View file @
667cd6ab
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd1x1S1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdOddC
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment