Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b134b7d6
Commit
b134b7d6
authored
May 16, 2022
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into cpu_avx2
parents
090ba885
9f71ff48
Changes
211
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
343 additions
and
168 deletions
+343
-168
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
+22
-24
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
+22
-24
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
+22
-24
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
+19
-21
library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt
...c/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt
+3
-3
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt
...sor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt
...r_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt
+1
-1
library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt
...c/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt
...peration_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt
...tion_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt
...stance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt
...c/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt
+1
-2
library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt
...sor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt
+1
-1
library/src/tensor_operation_instance/gpu/device_conv2d.cpp
library/src/tensor_operation_instance/gpu/device_conv2d.cpp
+201
-0
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+1
-2
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
...m_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
+20
-18
library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt
.../tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt
...nsor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt
..._operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt
+1
-3
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
+22
-24
No files found.
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
View file @
b134b7d6
...
...
@@ -2,7 +2,7 @@
#include "config.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "
element_wise_
reduc
e
_operat
ion
.hpp"
#include "reduc
tion
_operat
or
.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -19,9 +19,9 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
Reduce
Square
Sum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -31,33 +31,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1
|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1|
D1
EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquareSum
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
Square
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
View file @
b134b7d6
...
...
@@ -2,7 +2,7 @@
#include "config.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "
element_wise_
reduc
e
_operat
ion
.hpp"
#include "reduc
tion
_operat
or
.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -19,9 +19,9 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
Reduce
Square
Sum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -31,33 +31,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1
|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1|
D1
EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquareSum
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
Square
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
View file @
b134b7d6
...
...
@@ -2,7 +2,7 @@
#include "config.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "
element_wise_
reduc
e
_operat
ion
.hpp"
#include "reduc
tion
_operat
or
.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -19,9 +19,9 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
Reduce
Square
Sum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -31,33 +31,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1|
D1EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquareSum
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
Square
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
View file @
b134b7d6
...
...
@@ -2,7 +2,7 @@
#include "config.hpp"
#include "device_batched_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "
element_wise_
reduc
e
_operat
ion
.hpp"
#include "reduc
tion
_operat
or
.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -19,9 +19,9 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
Reduce
Square
Sum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -31,30 +31,28 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1| D1EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
,
DeviceBatchedGemmReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquareSum
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
Square
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -6,9 +6,9 @@ set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp;
)
add_library
(
device_conv1d_fwd_instance
SHARED
${
DEVICE_CONV1D_FWD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv1d_fwd_instance PUBLIC
)
add_library
(
device_conv1d_fwd_instance
OBJECT
${
DEVICE_CONV1D_FWD_INSTANCE_SOURCE
}
)
#
target_compile_features(device_conv1d_fwd_instance PUBLIC)
set_target_properties
(
device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib
)
#
install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib)
clang_tidy_check
(
device_conv1d_fwd_instance
)
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -6,9 +6,7 @@ set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
)
add_library
(
device_conv2d_bwd_data_instance SHARED
${
DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_bwd_data_instance PUBLIC
)
add_library
(
device_conv2d_bwd_data_instance OBJECT
${
DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
}
)
set_target_properties
(
device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv2d_bwd_data_instance
)
library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -3,7 +3,7 @@ set(DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
)
add_library
(
device_conv2d_bwd_weight_instance
SHARED
${
DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE
}
)
add_library
(
device_conv2d_bwd_weight_instance
OBJECT
${
DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_bwd_weight_instance PUBLIC
)
set_target_properties
(
device_conv2d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_bwd_weight_instance LIBRARY DESTINATION lib
)
...
...
library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -6,9 +6,7 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv2d_fwd_instance SHARED
${
DEVICE_CONV2D_FWD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_instance PUBLIC
)
add_library
(
device_conv2d_fwd_instance OBJECT
${
DEVICE_CONV2D_FWD_INSTANCE_SOURCE
}
)
set_target_properties
(
device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv2d_fwd_instance
)
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -2,9 +2,7 @@
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv2d_fwd_bias_relu_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_bias_relu_instance PUBLIC
)
add_library
(
device_conv2d_fwd_bias_relu_instance OBJECT
${
DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
}
)
set_target_properties
(
device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv2d_fwd_bias_relu_instance
)
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -2,9 +2,7 @@
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv2d_fwd_bias_relu_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_bias_relu_add_instance PUBLIC
)
add_library
(
device_conv2d_fwd_bias_relu_add_instance OBJECT
${
DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
set_target_properties
(
device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv2d_fwd_bias_relu_add_instance
)
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -3,9 +3,7 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv2d_fwd_bias_relu_atomic_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC
)
add_library
(
device_conv2d_fwd_bias_relu_atomic_add_instance OBJECT
${
DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
}
)
set_target_properties
(
device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv2d_fwd_bias_relu_atomic_add_instance
)
library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -5,9 +5,8 @@ set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
)
add_library
(
device_conv3d_fwd_instance
SHARED
${
DEVICE_CONV3D_FWD_INSTANCE_SOURCE
}
)
add_library
(
device_conv3d_fwd_instance
OBJECT
${
DEVICE_CONV3D_FWD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv3d_fwd_instance PUBLIC
)
set_target_properties
(
device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv3d_fwd_instance
)
library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -14,7 +14,7 @@ set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
)
add_library
(
device_convnd_bwd_data_instance
SHARED
${
DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
}
)
add_library
(
device_convnd_bwd_data_instance
OBJECT
${
DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
}
)
target_compile_features
(
device_convnd_bwd_data_instance PUBLIC
)
set_target_properties
(
device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib
)
...
...
library/src/tensor_operation_instance/gpu/device_conv2d.cpp
0 → 100644
View file @
b134b7d6
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_instance
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
}
// namespace device_conv2d_fwd_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
struct
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
{
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseArgument
>
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
const
{
return
el
->
MakeArgumentPointer
(
in_ptr
,
wei_ptr
,
out_ptr
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseInvoker
>
MakeInvokerPointer
()
const
{
return
el
->
MakeInvokerPointer
();
}
std
::
string
GetTypeString
()
{
return
el
->
GetTypeString
();
}
bool
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg
)
{
return
el
->
IsSupportedArgument
(
arg
);
}
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
el
;
};
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
()
:
pImpl
(
nullptr
)
{}
DeviceConvFwdPtr_t
::~
DeviceConvFwdPtr_t
()
=
default
;
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
&&
)
=
default
;
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
&
other
)
:
pImpl
(
std
::
make_unique
<
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
>
(
std
::
move
(
other
)))
{
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseArgument
>
DeviceConvFwdPtr_t
::
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
const
{
return
pImpl
->
MakeArgumentPointer
(
in_ptr
,
wei_ptr
,
out_ptr
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseInvoker
>
DeviceConvFwdPtr_t
::
MakeInvokerPointer
()
const
{
return
pImpl
->
MakeInvokerPointer
();
}
std
::
string
DeviceConvFwdPtr_t
::
GetTypeString
()
{
return
pImpl
->
GetTypeString
();
}
bool
DeviceConvFwdPtr_t
::
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg_ptr
)
{
return
pImpl
->
IsSupportedArgument
(
arg_ptr
);
}
using
namespace
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
;
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
}
return
;
}
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -35,10 +35,9 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
)
add_library
(
device_gemm_instance
SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_instance
OBJECT
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
target_compile_features
(
device_gemm_instance PUBLIC
)
set_target_properties
(
device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_gemm_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_instance
)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
b134b7d6
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c
_
shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
...
...
@@ -20,26 +20,28 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//#####################| AData| BData| CData| AccData| CShuffle|
ALayout| BLayout| CLayout|
A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle|
CBlockTransferClusterLengths| CBlockTransfer|
Num|
//#####################| Type| Type| Type| Type| DataType|
| | |
Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|
_MBlock_MXdlPerWave
_MWaveMPerXdl| ScalarPerVector|
Prefetch|
//#####################|
|
|
| | |
| | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle|
_NBlock_NXdlPerWave
_NWaveNPerXdl| _NWaveNPerXdl|
|
//#####################| | | | | |
| |
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| |
|
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
4
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
8
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
,
2
>
,
DeviceGemmXdl_C
_
Shuffle
<
F16
,
F16
,
F16
,
F32
,
F16
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
16
,
1
,
1
,
4
>
,
8
,
2
>
//#####################|
ALayout| BLayout| CLayout|
AData| BData| CData| AccData| CShuffle| A| B| C|
GEMM| NumGemmK|
Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################|
| | |
Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|
Spacialization| Prefetch|
Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave|
_MBlock
_MWaveMPerXdl| ScalarPerVector|
//#####################|
|
| |
|
| | |
| Operation| Operation| Operation|
| Stage|
| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle|
_NBlock
_NWaveNPerXdl| _NWaveNPerXdl|
//#####################|
| | |
| | | | |
| | |
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGemm
_
Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
2
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -10,9 +10,7 @@ set(DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp;
)
add_library
(
device_gemm_bias2d_instance SHARED
${
DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE
}
)
target_compile_features
(
device_gemm_bias2d_instance PUBLIC
)
add_library
(
device_gemm_bias2d_instance OBJECT
${
DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE
}
)
set_target_properties
(
device_gemm_bias2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_gemm_bias2d_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_bias2d_instance
)
library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp;
)
add_library
(
device_gemm_bias_relu_instance SHARED
${
DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
}
)
target_compile_features
(
device_gemm_bias_relu_instance PUBLIC
)
add_library
(
device_gemm_bias_relu_instance OBJECT
${
DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
}
)
set_target_properties
(
device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_bias_relu_instance
)
library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt
View file @
b134b7d6
...
...
@@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp;
)
add_library
(
device_gemm_bias_relu_add_instance SHARED
${
DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_gemm_bias_relu_add_instance PUBLIC
)
add_library
(
device_gemm_bias_relu_add_instance OBJECT
${
DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
}
)
set_target_properties
(
device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_bias_relu_add_instance
)
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
View file @
b134b7d6
...
...
@@ -2,7 +2,7 @@
#include "config.hpp"
#include "device_gemm_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "
element_wise_
reduc
e
_operat
ion
.hpp"
#include "reduc
tion
_operat
or
.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -19,9 +19,9 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSum
;
using
Reduce
Square
Sum
=
ck
::
tensor_operation
::
element_wise
::
ReduceSquareSum
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
...
@@ -30,33 +30,31 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// d1[m] = reduce1(c[m, n])
using
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0|
D1| D1EleOp|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquare
Sum
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
Reduce
Sum
,
Square
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
ReduceSum
,
ReduceSquareSum
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
Square
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
{});
...
...
Prev
1
…
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment