Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b93575ca
Commit
b93575ca
authored
Aug 28, 2023
by
Jing Zhang
Browse files
merge develop
parents
54df59bf
c8a8385f
Changes
352
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
223 additions
and
38 deletions
+223
-38
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
...tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
+2
-0
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
...mm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
+41
-24
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt
...eration_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt
+2
-0
library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt
...ensor_operation_instance/gpu/normalization/CMakeLists.txt
+12
-8
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp
...tance/gpu/normalization/device_groupnorm_f16_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp
...tance/gpu/normalization/device_groupnorm_f32_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
...ation/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp
...gpu/normalization/device_groupnorm_swish_f16_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp
...gpu/normalization/device_groupnorm_swish_f32_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp
...nce/gpu/normalization/device_layernorm2d_f16_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp
...nce/gpu/normalization/device_layernorm2d_f32_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp
...nce/gpu/normalization/device_layernorm4d_f16_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp
...nce/gpu/normalization/device_layernorm4d_f32_instance.cpp
+2
-0
library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp
...tance/gpu/normalization/normalization_instance_common.hpp
+79
-0
library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt
...c/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt
+10
-0
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
...u/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
+3
-1
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
...u/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
+3
-1
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp
...u/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp
+6
-2
library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp
...u/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp
+6
-2
library/src/tensor_operation_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp
...tion_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp
+41
-0
No files found.
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
View file @
b93575ca
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_instance_library
(
device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
...
...
@@ -8,3 +9,4 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
)
endif
()
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
View file @
b93575ca
...
...
@@ -41,30 +41,47 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
192
,
64
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
16
,
128
,
32
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
16
,
128
,
32
,
8
,
8
,
16
,
16
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
16
,
16
,
32
,
8
,
8
,
16
,
16
,
1
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
192
,
64
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v1
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
192
,
64
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
48
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
192
,
32
,
8
,
8
,
32
,
32
,
1
,
3
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
24
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
192
,
32
,
32
,
8
,
8
,
32
,
32
,
3
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
4
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
128
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
,
DeviceGroupedGemmXdlSplitKCShuffle
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmMNKPadding
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
PipelineVersion
::
v2
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt
View file @
b93575ca
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_instance_library
(
device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
)
endif
()
library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt
View file @
b93575ca
add_instance_library
(
device_normalization_instance
device_layernorm2d_f16_instance.cpp
device_layernorm2d_f
32
_instance.cpp
set
(
DEVICE_NORMALIZATION_INSTANCES
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_NORMALIZATION_INSTANCES
device_layernorm2d_f
16
_instance.cpp
device_layernorm4d_f16_instance.cpp
device_layernorm4d_f32_instance.cpp
device_groupnorm_f16_instance.cpp
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f32_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
)
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp
device_layernorm4d_f32_instance.cpp
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f32_instance.cpp
)
endif
()
add_instance_library
(
device_normalization_instance
${
DEVICE_NORMALIZATION_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f16_instances(
add_device_operation_instances
(
instances
,
device_normalization_f16_generic_instance
<
Pass
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f16_instances
<
Pass
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f16_instances
<
Pass
,
5
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_f32_instances(
add_device_operation_instances
(
instances
,
device_normalization_f32_generic_instance
<
Pass
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f32_instances
<
Pass
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f32_instances
<
Pass
,
5
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
View file @
b93575ca
...
...
@@ -18,6 +18,8 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
instances
,
device_normalization_f16_f32_f32_f16_generic_instance
<
Swish
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f16_f32_f32_f16_instances
<
Swish
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f16_f32_f32_f16_instances
<
Swish
,
5
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
add_device_operation_instances
(
instances
,
device_normalization_f16_generic_instance
<
Swish
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f16_instances
<
Swish
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f16_instances
<
Swish
,
5
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_5_3_swish_f32_instances(
add_device_operation_instances
(
instances
,
device_normalization_f32_generic_instance
<
Swish
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f32_instances
<
Swish
,
5
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f32_instances
<
Swish
,
5
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f16_instances(
add_device_operation_instances
(
instances
,
device_normalization_f16_generic_instance
<
Pass
,
2
,
1
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f16_instances
<
Pass
,
2
,
1
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f16_instances
<
Pass
,
2
,
1
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_2_1_f32_instances(
add_device_operation_instances
(
instances
,
device_normalization_f32_generic_instance
<
Pass
,
2
,
1
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f32_instances
<
Pass
,
2
,
1
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f32_instances
<
Pass
,
2
,
1
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f16_instances(
add_device_operation_instances
(
instances
,
device_normalization_f16_generic_instance
<
Pass
,
4
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f16_instances
<
Pass
,
4
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f16_instances
<
Pass
,
4
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp
View file @
b93575ca
...
...
@@ -17,6 +17,8 @@ void add_device_normalization_rank_4_3_f32_instances(
add_device_operation_instances
(
instances
,
device_normalization_f32_generic_instance
<
Pass
,
4
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_f32_instances
<
Pass
,
4
,
3
>
{});
add_device_operation_instances
(
instances
,
device_normalization_splitk_f32_instances
<
Pass
,
4
,
3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp
View file @
b93575ca
...
...
@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -43,6 +44,32 @@ using device_normalization_f16_instances =
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_splitk_f16_instances
=
// clang-format off
std
::
tuple
<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
64
,
1
,
64
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
2
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_f16_generic_instance
=
std
::
tuple
<
// clang-format off
...
...
@@ -76,6 +103,32 @@ using device_normalization_f32_instances = std::tuple<
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_splitk_f32_instances
=
std
::
tuple
<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
2
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
2
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_f32_generic_instance
=
std
::
tuple
<
// clang-format off
...
...
@@ -109,6 +162,32 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple<
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_splitk_f16_f32_f32_f16_instances
=
std
::
tuple
<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// irregular size
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
2
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
2
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationSplitKImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_f16_f32_f32_f16_generic_instance
=
std
::
tuple
<
// clang-format off
...
...
library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt
0 → 100644
View file @
b93575ca
set
(
DEVICE_POOL3D_FWD_INSTANCES
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
list
(
APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
)
endif
()
add_instance_library
(
device_pool3d_fwd_instance
${
DEVICE_POOL3D_FWD_INSTANCES
}
)
library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/pool
3d
_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
View file @
b93575ca
...
...
@@ -11,7 +11,9 @@ namespace instance {
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
void
add_device_pool3d_fwd_ndhwc_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F16
,
F16
,
I32
,
ReduceOpId
,
false
>>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F16
,
F16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F16
,
F16
,
I32
,
F32
,
ReduceOpId
,
false
>
{});
...
...
library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
→
library/src/tensor_operation_instance/gpu/pool
3d
_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
View file @
b93575ca
...
...
@@ -11,7 +11,9 @@ namespace instance {
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
void
add_device_pool3d_fwd_ndhwc_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F32
,
F32
,
I32
,
ReduceOpId
,
false
>>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F32
,
F32
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F32
,
F32
,
I32
,
F32
,
ReduceOpId
,
false
>
{});
...
...
library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/pool
3d
_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp
View file @
b93575ca
...
...
@@ -11,14 +11,18 @@ namespace instance {
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
void
add_device_pool3d_fwd_ndhwc_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F16
,
F16
,
I32
,
ReduceOpId
,
false
>>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F16
,
F16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F16
,
F16
,
I32
,
F16
,
ReduceOpId
,
false
>
{});
}
void
add_device_pool3d_fwd_ndhwc_index_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F16
,
F16
,
I32
,
ReduceOpId
,
true
>>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F16
,
F16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
true
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F16
,
F16
,
I32
,
F16
,
ReduceOpId
,
true
>
{});
...
...
library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp
→
library/src/tensor_operation_instance/gpu/pool
3d
_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp
View file @
b93575ca
...
...
@@ -11,14 +11,18 @@ namespace instance {
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
void
add_device_pool3d_fwd_ndhwc_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F32
,
F32
,
I32
,
ReduceOpId
,
false
>>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F32
,
F32
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F32
,
F32
,
I32
,
F32
,
ReduceOpId
,
false
>
{});
}
void
add_device_pool3d_fwd_ndhwc_index_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F32
,
F32
,
I32
,
ReduceOpId
,
true
>>>&
instances
)
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
5
,
3
,
F32
,
F32
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
true
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_pool3d_fwd_ndhwc_instances
<
F32
,
F32
,
I32
,
F32
,
ReduceOpId
,
true
>
{});
...
...
library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp
→
library/src/tensor_operation_instance/gpu/pool
3d
_fwd/pool_fwd_instance_common.hpp
View file @
b93575ca
...
...
@@ -15,24 +15,10 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
using
I32
=
int32_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
typename
InDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
ComputeDataType
,
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
>
using
device_pool2d_fwd_nhwc_instances
=
// clang-format off
std
::
tuple
<
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
1
,
1
,
1
>
,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
2
,
1
,
2
>
,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
4
,
1
,
4
>
// clang-format on
>
;
using
I32
=
int32_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
NDHWC
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
template
<
typename
InDataType
,
typename
OutDataType
,
...
...
@@ -43,9 +29,9 @@ template <typename InDataType,
using
device_pool3d_fwd_ndhwc_instances
=
// clang-format off
std
::
tuple
<
DevicePool3dFwd_
Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_
C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
1
,
1
,
1
>
,
DevicePool3dFwd_
Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_
C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
2
,
1
,
2
>
,
DevicePool3dFwd_
Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_
C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
4
,
1
,
4
>
DevicePool3dFwd_
NDHWC_NDHW
C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
1
,
1
,
1
>
,
DevicePool3dFwd_
NDHWC_NDHW
C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
2
,
1
,
2
>
,
DevicePool3dFwd_
NDHWC_NDHW
C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
4
,
1
,
4
>
// clang-format on
>
;
...
...
Prev
1
…
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment