Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
457c024d
Commit
457c024d
authored
Nov 20, 2021
by
Chao Liu
Browse files
update ckProfiler
parent
2066a3d4
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
356 additions
and
189 deletions
+356
-189
device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp
...n/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp
+21
-18
device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp
...n/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp
+21
-18
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp
+16
-13
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp
+16
-13
device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp
+16
-13
device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp
...eration/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp
+21
-18
device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp
+16
-13
device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp
+16
-13
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
+16
-13
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp
+21
-18
device_operation/include/device_conv.hpp
device_operation/include/device_conv.hpp
+38
-6
device_operation/include/device_conv_fwd_xdl.hpp
device_operation/include/device_conv_fwd_xdl.hpp
+3
-0
device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp
..._operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp
+51
-8
device_operation/include/device_conv_instance.hpp
device_operation/include/device_conv_instance.hpp
+13
-3
device_operation/include/device_gemm_instance.hpp
device_operation/include/device_gemm_instance.hpp
+5
-1
device_operation/include/element_wise_operation.hpp
device_operation/include/element_wise_operation.hpp
+20
-0
example/1_gemm_xdl/gemm_xdl.cpp
example/1_gemm_xdl/gemm_xdl.cpp
+6
-6
profiler/include/profile_conv.hpp
profiler/include/profile_conv.hpp
+17
-4
profiler/include/profile_gemm.hpp
profiler/include/profile_gemm.hpp
+23
-11
No files found.
device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -18,32 +19,34 @@ using NHWK = ck::tensor_layout::convolution::NHWK;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk
=
std
::
tuple
<
// clang-format off
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout|
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F16
,
F16
,
F16
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_conv_fwd_instance
<
2
,
F16
,
F16
,
F16
,
NHWC
,
KYXC
,
NHWK
>
(
std
::
vector
<
DeviceConvFwdPtr
>&
device_conv_instances
)
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_conv_instances
)
{
using
DeviceConvs
=
device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk
;
...
...
device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -18,32 +19,34 @@ using NHWK = ck::tensor_layout::convolution::NHWK;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk
=
std
::
tuple
<
// clang-format off
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout|
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceConvFwdXdl
<
2
,
F32
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_conv_fwd_instance
<
2
,
F32
,
F32
,
F32
,
NHWC
,
KYXC
,
NHWK
>
(
std
::
vector
<
DeviceConvFwdPtr
>&
device_conv_instances
)
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_conv_instances
)
{
using
DeviceConvs
=
device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk
;
...
...
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_km_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Col
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_km_kn_mn
;
...
...
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_km_nk_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Col
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_km_nk_mn
;
...
...
device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Row
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn
;
...
...
device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,32 +18,34 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F16
,
F16
,
F16
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F16
,
F16
,
F16
,
Row
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn
;
...
...
device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f32_f32_f32_km_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F32
,
F32
,
F32
,
Col
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f32_f32_f32_km_kn_mn
;
...
...
device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_xdl_instance_f32_f32_f32_km_nk_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F32
,
F32
,
F32
,
Col
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f32_f32_f32_km_nk_mn
;
...
...
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F32
,
F32
,
F32
,
Row
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn
;
...
...
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -17,32 +18,34 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout|
A| B| C|
Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | |
Elementwise| Elementwise| Elementwise|
Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | |
Operation| Operation| Operation|
| | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
7
,
1
,
true
,
true
>
// clang-format on
>
;
template
<
>
void
add_device_gemm_instance
<
F32
,
F32
,
F32
,
Row
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
>&
device_op_instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn
;
...
...
device_operation/include/device_conv.hpp
View file @
457c024d
...
...
@@ -8,6 +8,9 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvFwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
@@ -23,11 +26,17 @@ struct DeviceConvFwd : public BaseOperator
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
=
0
;
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvBwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
@@ -43,11 +52,17 @@ struct DeviceConvBwd : public BaseOperator
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
=
0
;
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
struct
DeviceConvWrw
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
@@ -63,14 +78,31 @@ struct DeviceConvWrw : public BaseOperator
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
=
0
;
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
using
DeviceConvFwdPtr
=
std
::
unique_ptr
<
DeviceConvFwd
>
;
using
DeviceConvBwdPtr
=
std
::
unique_ptr
<
DeviceConvBwd
>
;
using
DeviceConvWrwPtr
=
std
::
unique_ptr
<
DeviceConvWrw
>
;
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvFwdPtr
=
std
::
unique_ptr
<
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvBwdPtr
=
std
::
unique_ptr
<
DeviceConvBwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
template
<
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
using
DeviceConvWrwPtr
=
std
::
unique_ptr
<
DeviceConvWrw
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
device_operation/include/device_conv_fwd_xdl.hpp
View file @
457c024d
...
...
@@ -23,6 +23,9 @@ template <ck::index_t NDimSpatial,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
...
...
device_operation/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
457c024d
...
...
@@ -22,6 +22,9 @@ template <typename InDataType,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
...
...
@@ -58,6 +61,9 @@ struct DeviceConvFwdXdl<
ck
::
tensor_layout
::
convolution
::
NHWC
,
// typename InLayout,
ck
::
tensor_layout
::
convolution
::
KYXC
,
// typename WeiLayout,
ck
::
tensor_layout
::
convolution
::
NHWK
,
// typename OutLayout,
InElementwiseOperation
,
// typename InElementwiseOperation,
WeiElementwiseOperation
,
// typename WeiElementwiseOperation,
OutElementwiseOperation
,
// typename OutElementwiseOperation,
BlockSize
,
// ck::index_t BlockSize,
MPerBlock
,
// ck::index_t MPerBlock,
NPerBlock
,
// ck::index_t NPerBlock,
...
...
@@ -87,7 +93,8 @@ struct DeviceConvFwdXdl<
CThreadTransferDstScalarPerVector
,
// ck::index_t CThreadTransferDstScalarPerVector,
ABlockLdsAddExtraM
,
// bool ABlockLdsAddExtraM,
BBlockLdsAddExtraN
// bool BBlockLdsAddExtraN>
>
:
public
DeviceConvFwd
>
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
ADataType
=
InDataType
;
using
BDataType
=
WeiDataType
;
...
...
@@ -293,6 +300,9 @@ struct DeviceConvFwdXdl<
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
...
...
@@ -351,7 +361,10 @@ struct DeviceConvFwdXdl<
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
)
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
:
p_a_grid_
{
p_in_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_out_grid
},
...
...
@@ -361,7 +374,10 @@ struct DeviceConvFwdXdl<
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
}
N01_
{
N01
},
in_element_op_
{
in_element_op
},
wei_element_op_
{
wei_element_op
},
out_element_op_
{
out_element_op
}
{
const
auto
descs
=
DeviceConvFwdXdl
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
N
,
...
...
@@ -400,6 +416,9 @@ struct DeviceConvFwdXdl<
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
InElementwiseOperation
in_element_op_
;
WeiElementwiseOperation
wei_element_op_
;
OutElementwiseOperation
out_element_op_
;
};
// Invoker
...
...
@@ -449,6 +468,9 @@ struct DeviceConvFwdXdl<
remove_reference_t
<
DeviceConvFwdXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceConvFwdXdl
::
Block2CTileMap
>
,
true
>
;
...
...
@@ -463,6 +485,9 @@ struct DeviceConvFwdXdl<
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -474,6 +499,9 @@ struct DeviceConvFwdXdl<
remove_reference_t
<
DeviceConvFwdXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceConvFwdXdl
::
Block2CTileMap
>
,
false
>
;
...
...
@@ -488,6 +516,9 @@ struct DeviceConvFwdXdl<
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
block_2_ctile_map_
);
}
...
...
@@ -534,7 +565,10 @@ struct DeviceConvFwdXdl<
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
...
...
@@ -550,7 +584,10 @@ struct DeviceConvFwdXdl<
input_left_pads
,
input_right_pads
,
1
,
1
};
1
,
in_element_op
,
wei_element_op
,
out_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -569,7 +606,10 @@ struct DeviceConvFwdXdl<
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
override
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
...
...
@@ -585,7 +625,10 @@ struct DeviceConvFwdXdl<
input_left_pads
,
input_right_pads
,
1
,
1
);
1
,
in_element_op
,
wei_element_op
,
out_element_op
);
}
// polymorphic
...
...
@@ -593,7 +636,7 @@ struct DeviceConvFwdXdl<
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
};
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
...
...
device_operation/include/device_conv_instance.hpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#define DEVICE_CONV_INSTANTCE_HPP
#include "device_conv.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -15,7 +16,10 @@ template <ck::index_t NDimSpatial,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
void
add_device_conv_fwd_instance
(
std
::
vector
<
DeviceConvFwdPtr
>&
);
void
add_device_conv_fwd_instance
(
std
::
vector
<
DeviceConvFwdPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
...
...
@@ -24,7 +28,10 @@ template <ck::index_t NDimSpatial,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
void
add_device_conv_bwd_instance
(
std
::
vector
<
DeviceConvBwdPtr
>&
);
void
add_device_conv_bwd_instance
(
std
::
vector
<
DeviceConvBwdPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
...
...
@@ -33,7 +40,10 @@ template <ck::index_t NDimSpatial,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
void
add_device_conv_wrw_instance
(
std
::
vector
<
DeviceConvWrwPtr
>&
);
void
add_device_conv_wrw_instance
(
std
::
vector
<
DeviceConvWrwPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
}
// namespace device_conv_instance
}
// namespace device
...
...
device_operation/include/device_gemm_instance.hpp
View file @
457c024d
...
...
@@ -2,6 +2,7 @@
#define DEVICE_GEMM_INSTANTCE_HPP
#include "device_gemm.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -14,7 +15,10 @@ template <typename ADataType,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
add_device_gemm_instance
(
std
::
vector
<
DeviceGemmPtr
>&
);
void
add_device_gemm_instance
(
std
::
vector
<
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
}
// namespace device_gemm_instance
}
// namespace device
...
...
device_operation/include/element_wise_operation.hpp
0 → 100644
View file @
457c024d
#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
struct
PassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
;
}
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
#endif
example/1_gemm_xdl/gemm_xdl.cpp
View file @
457c024d
...
...
@@ -14,7 +14,7 @@
#include "device_base.hpp"
#include "device_gemm_xdl.hpp"
struct
Equal
struct
PassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
...
...
@@ -204,8 +204,8 @@ int main(int argc, char* argv[])
ALayout
,
BLayout
,
CLayout
,
Equal
,
Equal
,
PassThrough
,
PassThrough
,
Relu
>::
type
{};
auto
invoker
=
gemm
.
MakeInvoker
();
...
...
@@ -218,8 +218,8 @@ int main(int argc, char* argv[])
StrideA
,
StrideB
,
StrideC
,
Equal
{},
Equal
{},
PassThrough
{},
PassThrough
{},
Relu
{});
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
...
...
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
host_gemm_mk_kn_mn
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
Equal
{},
Equal
{},
Relu
{});
host_gemm_mk_kn_mn
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
PassThrough
{},
PassThrough
{},
Relu
{});
check_error
(
c_m_n_host_result
,
c_m_n_device_result
);
}
...
...
profiler/include/profile_conv.hpp
View file @
457c024d
...
...
@@ -8,12 +8,17 @@
#include "device_tensor.hpp"
#include "device_conv.hpp"
#include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv_instance
{
using
DeviceConvFwdNoOpPtr
=
DeviceConvFwdPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
template
<
>
void
add_device_conv_fwd_instance
<
2
,
float
,
...
...
@@ -22,7 +27,7 @@ void add_device_conv_fwd_instance<2,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
>&
);
std
::
vector
<
DeviceConvFwd
NoOp
Ptr
>&
);
template
<
>
void
add_device_conv_fwd_instance
<
2
,
...
...
@@ -32,7 +37,7 @@ void add_device_conv_fwd_instance<2,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
>&
);
std
::
vector
<
DeviceConvFwd
NoOp
Ptr
>&
);
}
// namespace device_conv_instance
}
// namespace device
...
...
@@ -133,8 +138,13 @@ void profile_conv(int do_verification,
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
;
// add device Conv instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
>
conv_ptrs
;
std
::
vector
<
DeviceConvFwd
NoOp
Ptr
>
conv_ptrs
;
ck
::
tensor_operation
::
device
::
device_conv_instance
::
add_device_conv_fwd_instance
<
2
,
InDataType
,
...
...
@@ -170,7 +180,10 @@ void profile_conv(int do_verification,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
conv_ptr
->
MakeInvokerPointer
();
...
...
profiler/include/profile_gemm.hpp
View file @
457c024d
...
...
@@ -6,13 +6,17 @@ namespace tensor_operation {
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmNoOpPtr
=
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
...
...
@@ -20,7 +24,7 @@ void add_device_gemm_instance<float,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
...
...
@@ -28,7 +32,7 @@ void add_device_gemm_instance<float,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
...
...
@@ -36,7 +40,7 @@ void add_device_gemm_instance<float,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
...
...
@@ -44,7 +48,7 @@ void add_device_gemm_instance<ck::half_t,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
...
...
@@ -52,7 +56,7 @@ void add_device_gemm_instance<ck::half_t,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
...
...
@@ -60,7 +64,7 @@ void add_device_gemm_instance<ck::half_t,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
...
...
@@ -68,7 +72,7 @@ void add_device_gemm_instance<ck::half_t,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemm
NoOp
Ptr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
...
...
@@ -132,7 +136,12 @@ void profile_gemm(int do_verification,
if
(
do_verification
)
{
host_gemm_mk_kn_mn
(
a_m_k
,
b_k_n
,
c_m_n_host_result
);
host_gemm_mk_kn_mn
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
...
...
@@ -144,7 +153,7 @@ void profile_gemm(int do_verification,
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemm
NoOp
Ptr
>
gemm_ptrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
...
...
@@ -171,7 +180,10 @@ void profile_gemm(int do_verification,
K
,
StrideA
,
StrideB
,
StrideC
);
StrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment