Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cc8df39e
Commit
cc8df39e
authored
Mar 30, 2022
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into cpu_avx2
parents
0b9fe840
98e1e2d0
Changes
39
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2725 additions
and
253 deletions
+2725
-253
example/01_gemm/gemm_xdl_int8.cpp
example/01_gemm/gemm_xdl_int8.cpp
+7
-7
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp
+1
-0
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
...quant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
+15
-15
example/17_convnd_bwd_data_xdl/CMakeLists.txt
example/17_convnd_bwd_data_xdl/CMakeLists.txt
+1
-0
example/17_convnd_bwd_data_xdl/README.md
example/17_convnd_bwd_data_xdl/README.md
+80
-0
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
+415
-0
example/CMakeLists.txt
example/CMakeLists.txt
+1
-0
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
...ckward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
+55
-55
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
...ward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
+53
-53
include/ck/tensor_operation/gpu/device/conv_utils.hpp
include/ck/tensor_operation/gpu/device/conv_utils.hpp
+23
-1
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+55
-56
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
+1543
-0
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+0
-1
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
...kward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
+6
-6
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+194
-59
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt
...sor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt
+22
-0
library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp
.../device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp
+84
-0
library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
...a/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
+86
-0
library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
...a/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
+83
-0
No files found.
example/01_gemm/gemm_xdl_int8.cpp
View file @
cc8df39e
...
@@ -53,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -53,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
256
,
// BlockSize
256
,
// BlockSize
256
,
// MPerBlock
256
,
// MPerBlock
128
,
// NPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
64
,
// KPerBlock
8
,
// AK1
16
,
// AK1
8
,
// BK1
16
,
// BK1
32
,
// MPerXDL
32
,
// MPerXDL
32
,
// NPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
4
,
// MXdlPerWave
...
@@ -64,15 +64,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -64,15 +64,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
16
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
16
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
16
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
16
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
...
...
example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp
View file @
cc8df39e
...
@@ -68,6 +68,7 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
...
@@ -68,6 +68,7 @@ using DeviceConvBwdDataInstance = ck::tensor_operation::device::
using
ReferenceConvBwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdData
<
InDataType
,
using
ReferenceConvBwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdData
<
InDataType
,
WeiDataType
,
WeiDataType
,
OutDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
>
;
OutElementOp
>
;
...
...
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
View file @
cc8df39e
...
@@ -32,7 +32,7 @@ using ADataType = int8_t;
...
@@ -32,7 +32,7 @@ using ADataType = int8_t;
using
BDataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
AccDataType
=
int32_t
;
using
ShuffleDataType
=
int32_t
;
using
C
ShuffleDataType
=
int32_t
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -44,7 +44,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -44,7 +44,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
BDataType
,
// BDataType
BDataType
,
// BDataType
CDataType
,
// CDataType
CDataType
,
// CDataType
AccDataType
,
// AccDataType
AccDataType
,
// AccDataType
ShuffleDataType
,
// ShuffleDataType
C
ShuffleDataType
,
//
C
ShuffleDataType
ALayout
,
// ALayout
ALayout
,
// ALayout
BLayout
,
// BLayout
BLayout
,
// BLayout
CLayout
,
// CLayout
CLayout
,
// CLayout
...
@@ -54,9 +54,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -54,9 +54,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
256
,
// BlockSize
256
,
// BlockSize
256
,
// MPerBlock
256
,
// MPerBlock
128
,
// NPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
64
,
// KPerBlock
8
,
// AK1
16
,
// AK1
8
,
// BK1
16
,
// BK1
32
,
// MPerXDL
32
,
// MPerXDL
32
,
// NPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
4
,
// MXdlPerWave
...
@@ -65,20 +65,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
...
@@ -65,20 +65,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
16
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
16
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
16
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
16
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
S
<
1
,
1
,
64
,
1
,
1
,
4
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
16
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
example/17_convnd_bwd_data_xdl/CMakeLists.txt
0 → 100644
View file @
cc8df39e
add_example_executable
(
example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp
)
example/17_convnd_bwd_data_xdl/README.md
0 → 100644
View file @
cc8df39e
# Instructions for ```convnd_bwd_data_xdl``` Example
## Docker script
```
bash
docker run
\
-it
\
--rm
\
--privileged
\
--group-add
sudo
\
-w
/root/workspace
\
-v
${
PATH_TO_LOCAL_WORKSPACE
}
:/root/workspace
\
rocm/tensorflow:rocm4.3.1-tf2.6-dev
\
/bin/bash
```
## Build ```convnd_bwd_data_xdl```
```
bash
mkdir
build
&&
cd
build
```
```
bash
# Need to specify target ID, example below is gfx908
cmake
\
-D
BUILD_DEV
=
OFF
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_CXX_FLAGS
=
"-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 "
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
..
```
```
bash
make
-j
convnd_bwd_data_xdl
```
## Run ```example_convnd_bwd_data_xdl```
```
bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: num_dim_spatial(1|2|3)
#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, S[z,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPy,] [RightPy,] RightPx
./bin/convnd_bwd_data_xdl 0 1 5
```
Result
```
in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128}
wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128}
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{128, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{32, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 )
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s
```
example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
0 → 100644
View file @
cc8df39e
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "conv_utils.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "element_wise_operation.hpp"
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
#include "reference_conv_bwd_data.hpp"
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvBwdDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Default
;
using
DeviceConvBwdDataBasePtr
=
ck
::
tensor_operation
::
device
::
DeviceConvBwdDataPtr
<
InElementOp
,
WeiElementOp
,
OutElementOp
>
;
template
<
ck
::
index_t
NumDimSpatial
>
using
DeviceConvNDBwdDataInstance
=
ck
::
tensor_operation
::
device
::
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
InDataType
,
// InDataType
WeiDataType
,
// WeiDataType
OutDataType
,
// OutDataType
AccDataType
,
// AccDataType
InElementOp
,
// InElementwiseOperation
WeiElementOp
,
// WeiElementwiseOperation
OutElementOp
,
// OutElementwiseOperation
ConvBwdDefault
,
// ConvolutionBackwardDataSpecialization_t
NumDimSpatial
,
// NumDimSpatial
256
,
// BlockSize
128
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
2
,
1
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
7
,
1
>
;
// GemmCThreadTransferDstScalarPerVector
template
<
ck
::
index_t
NumDimSpatial
>
using
ReferenceConvBwdDataInstance
=
ck
::
tensor_operation
::
host
::
ReferenceConvBwdData
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
NumDimSpatial
>
;
void
PrintUseMsg
()
{
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=random value, 2= init to 1 )
\n
"
<<
"arg3: run kernel # of times (>1)
\n
"
<<
"arg4: N spatial dimensions (default 2)
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
" N, K, C,
\n
"
<<
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
<<
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
<<
" <strides>, (ie Sy, Sx for 2D)
\n
"
<<
" <dilations>, (ie Dy, Dx for 2D)
\n
"
<<
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
<<
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
<<
std
::
endl
;
}
ck
::
conv_util
::
ConvParams
ParseConvParams
(
int
num_dim_spatial
,
char
*
argv
[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck
::
conv_util
::
ConvParams
params
;
int
arg_idx
=
5
;
params
.
num_dim_spatial
=
num_dim_spatial
;
params
.
N
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
K
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
C
=
std
::
stoi
(
argv
[
arg_idx
++
]);
params
.
filter_spatial_lengths
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
filter_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_spatial_lengths
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_spatial_lengths
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_strides
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_strides
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
conv_filter_dilations
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
conv_filter_dilations
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_left_pads
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_left_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
params
.
input_right_pads
.
resize
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
params
.
input_right_pads
[
i
]
=
std
::
stoi
(
argv
[
arg_idx
++
]);
}
return
params
;
}
HostTensorDescriptor
GetInputHostTensorDescriptor
(
const
std
::
vector
<
std
::
size_t
>&
dims
,
int
num_dim_spatial
=
2
)
{
namespace
tl
=
ck
::
tensor_layout
::
convolution
;
switch
(
num_dim_spatial
)
{
case
3
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
NDHWC
{});
}
case
2
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
NHWC
{});
}
case
1
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
NWC
{});
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
HostTensorDescriptor
GetFiltersHostTensorDescriptor
(
const
std
::
vector
<
std
::
size_t
>&
dims
,
int
num_dim_spatial
=
2
)
{
namespace
tl
=
ck
::
tensor_layout
::
convolution
;
switch
(
num_dim_spatial
)
{
case
3
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
KZYXC
{});
}
case
2
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
KYXC
{});
}
case
1
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
KXC
{});
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
HostTensorDescriptor
GetOutputHostTensorDescriptor
(
const
std
::
vector
<
std
::
size_t
>&
dims
,
int
num_dim_spatial
=
2
)
{
namespace
tl
=
ck
::
tensor_layout
::
convolution
;
switch
(
num_dim_spatial
)
{
case
3
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
NDHWK
{});
}
case
2
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
NHWK
{});
}
case
1
:
{
return
ck
::
conv_util
::
GetHostTensorDescriptor
(
dims
,
tl
::
NWK
{});
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
DeviceConvBwdDataBasePtr
GetConvInstance
(
int
num_dim_spatial
)
{
switch
(
num_dim_spatial
)
{
case
3
:
{
return
std
::
make_unique
<
DeviceConvNDBwdDataInstance
<
3
>>
();
}
case
2
:
{
return
std
::
make_unique
<
DeviceConvNDBwdDataInstance
<
2
>>
();
}
case
1
:
{
return
std
::
make_unique
<
DeviceConvNDBwdDataInstance
<
1
>>
();
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
0
;
int
init_method
=
0
;
int
nrepeat
=
5
;
int
num_dim_spatial
=
2
;
ck
::
conv_util
::
ConvParams
params
;
params
.
C
=
128
;
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
>
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
// check args number
int
conv_args
=
3
+
num_dim_spatial
*
6
;
int
cmdline_nargs
=
conv_args
+
5
;
if
(
cmdline_nargs
!=
argc
)
{
PrintUseMsg
();
exit
(
1
);
}
params
=
ParseConvParams
(
num_dim_spatial
,
argv
);
}
else
if
(
argc
!=
1
)
{
PrintUseMsg
();
exit
(
1
);
}
std
::
vector
<
std
::
size_t
>
input_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N
),
static_cast
<
std
::
size_t
>
(
params
.
C
)};
input_dims
.
insert
(
std
::
end
(
input_dims
),
std
::
begin
(
params
.
input_spatial_lengths
),
std
::
end
(
params
.
input_spatial_lengths
));
std
::
vector
<
std
::
size_t
>
filter_dims
{
static_cast
<
std
::
size_t
>
(
params
.
K
),
static_cast
<
std
::
size_t
>
(
params
.
C
)};
filter_dims
.
insert
(
std
::
end
(
filter_dims
),
std
::
begin
(
params
.
filter_spatial_lengths
),
std
::
end
(
params
.
filter_spatial_lengths
));
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
=
params
.
GetOutputSpatialLengths
();
std
::
vector
<
std
::
size_t
>
output_dims
{
static_cast
<
std
::
size_t
>
(
params
.
N
),
static_cast
<
std
::
size_t
>
(
params
.
K
)};
output_dims
.
insert
(
std
::
end
(
output_dims
),
std
::
begin
(
output_spatial_lengths
),
std
::
end
(
output_spatial_lengths
));
Tensor
<
InDataType
>
in_n_c_hi_wi_host_result
(
GetInputHostTensorDescriptor
(
input_dims
,
num_dim_spatial
));
Tensor
<
InDataType
>
in_n_c_hi_wi_device_result
(
GetInputHostTensorDescriptor
(
input_dims
,
num_dim_spatial
));
Tensor
<
WeiDataType
>
wei_k_c_y_x
(
GetFiltersHostTensorDescriptor
(
filter_dims
,
num_dim_spatial
));
Tensor
<
OutDataType
>
out_n_k_ho_wo
(
GetOutputHostTensorDescriptor
(
output_dims
,
num_dim_spatial
));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.2
,
0.2
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.2
,
0.2
});
break
;
default:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{
1
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{
1
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
out_device_buf
.
ToDevice
(
out_n_k_ho_wo
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
// reset input to zero
in_n_c_hi_wi_device_result
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
0
});
in_device_buf
.
ToDevice
(
in_n_c_hi_wi_device_result
.
mData
.
data
());
// do GEMM
auto
conv
=
GetConvInstance
(
num_dim_spatial
);
auto
invoker
=
conv
->
MakeInvokerPointer
();
auto
argument
=
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N
,
params
.
K
,
params
.
C
,
params
.
input_spatial_lengths
,
params
.
filter_spatial_lengths
,
output_spatial_lengths
,
params
.
conv_filter_strides
,
params
.
conv_filter_dilations
,
params
.
input_left_pads
,
params
.
input_right_pads
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
);
}
float
ave_time
=
invoker
->
Run
(
argument
.
get
(),
nrepeat
);
std
::
size_t
flop
=
ck
::
conv_util
::
GetFlops
(
params
.
N
,
params
.
C
,
params
.
K
,
params
.
filter_spatial_lengths
,
output_spatial_lengths
);
std
::
size_t
num_btype
=
ck
::
conv_util
::
GetBtype
<
InDataType
,
WeiDataType
,
OutDataType
>
(
params
.
N
,
params
.
C
,
params
.
K
,
params
.
input_spatial_lengths
,
params
.
filter_spatial_lengths
,
output_spatial_lengths
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
if
(
do_verification
)
{
auto
verify_f
=
[
&
](
const
auto
&
ref_conv
)
{
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in_n_c_hi_wi_host_result
,
wei_k_c_y_x
,
out_n_k_ho_wo
,
params
.
conv_filter_strides
,
params
.
conv_filter_dilations
,
params
.
input_left_pads
,
params
.
input_right_pads
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{});
ref_invoker
.
Run
(
ref_argument
);
in_device_buf
.
FromDevice
(
in_n_c_hi_wi_device_result
.
mData
.
data
());
check_error
(
in_n_c_hi_wi_host_result
,
in_n_c_hi_wi_device_result
);
};
switch
(
num_dim_spatial
)
{
case
3
:
{
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
3
>
();
verify_f
(
ref_conv
);
break
;
}
case
2
:
{
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
2
>
();
verify_f
(
ref_conv
);
break
;
}
case
1
:
{
auto
ref_conv
=
ReferenceConvBwdDataInstance
<
1
>
();
verify_f
(
ref_conv
);
break
;
}
default:
{
throw
std
::
runtime_error
(
"Unsupported number of spatial dimensions provided!"
);
}
}
}
}
example/CMakeLists.txt
View file @
cc8df39e
...
@@ -39,5 +39,6 @@ add_subdirectory(11_conv2d_bwd_wgt)
...
@@ -39,5 +39,6 @@ add_subdirectory(11_conv2d_bwd_wgt)
add_subdirectory
(
12_reduce
)
add_subdirectory
(
12_reduce
)
add_subdirectory
(
13_pool2d_fwd
)
add_subdirectory
(
13_pool2d_fwd
)
add_subdirectory
(
14_gemm_xdl_requant_relu_requant
)
add_subdirectory
(
14_gemm_xdl_requant_relu_requant
)
add_subdirectory
(
17_convnd_bwd_data_xdl
)
add_subdirectory
(
15_grouped_gemm
)
add_subdirectory
(
15_grouped_gemm
)
add_subdirectory
(
16_gemm_reduce
)
add_subdirectory
(
16_gemm_reduce
)
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp
View file @
cc8df39e
...
@@ -7,9 +7,9 @@
...
@@ -7,9 +7,9 @@
namespace
ck
{
namespace
ck
{
// Number of GEMMs = YTild
a
* XTild
a
// Number of GEMMs = YTild
e
* XTild
e
// GemmM = C
// GemmM = C
// GemmN = N * HTild
a
Slice * WTild
a
Slice
// GemmN = N * HTild
e
Slice * WTild
e
Slice
// GemmK = K * YDotSlice * XDotSlice
// GemmK = K * YDotSlice * XDotSlice
template
<
typename
...
Wei
,
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
...
@@ -18,8 +18,8 @@ template <typename... Wei,
...
@@ -18,8 +18,8 @@ template <typename... Wei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
,
typename
InRightPads
,
index_t
IYTild
a
Value
,
index_t
IYTild
e
Value
,
index_t
IXTild
a
Value
,
index_t
IXTild
e
Value
,
index_t
GemmK1Value
>
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk
(
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk
(
...
@@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -30,8 +30,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
IYTild
a
Value
>
,
Number
<
IYTild
e
Value
>
,
Number
<
IXTild
a
Value
>
,
Number
<
IXTild
e
Value
>
,
Number
<
GemmK1Value
>
)
Number
<
GemmK1Value
>
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -40,8 +40,8 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
constexpr
auto
GemmK1
=
Number
<
GemmK1Value
>
{};
constexpr
auto
IYTild
a
=
Number
<
IYTild
a
Value
>
{};
constexpr
auto
IYTild
e
=
Number
<
IYTild
e
Value
>
{};
constexpr
auto
IXTild
a
=
Number
<
IXTild
a
Value
>
{};
constexpr
auto
IXTild
e
=
Number
<
IXTild
e
Value
>
{};
const
auto
N
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I0
);
const
auto
C
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I3
);
const
auto
C
=
in_n_hi_wi_c_grid_desc
.
GetLength
(
I3
);
...
@@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -71,55 +71,55 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTild
a
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
YTild
e
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTild
a
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XTild
e
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTild
a
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTild
e
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTild
a
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTild
e
);
const
auto
HTild
a
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
HTild
e
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTild
a
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
const
auto
WTild
e
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTild
a
and WTild
a
that contribute to non-padding area of input tensor
// only work on HTild
e
and WTild
e
that contribute to non-padding area of input tensor
const
auto
IHTild
a
SliceBegin
=
math
::
integer_divide_floor
(
const
auto
IHTild
e
SliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTild
a
-
I1
)),
ConvStrideH
);
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTild
e
-
I1
)),
ConvStrideH
);
const
auto
IWTild
a
SliceBegin
=
math
::
integer_divide_floor
(
const
auto
IWTild
e
SliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTild
a
-
I1
)),
ConvStrideW
);
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTild
e
-
I1
)),
ConvStrideW
);
const
auto
IHTild
a
SliceEnd
=
const
auto
IHTild
e
SliceEnd
=
math
::
min
(
HTild
a
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
math
::
min
(
HTild
e
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTild
a
SliceEnd
=
const
auto
IWTild
e
SliceEnd
=
math
::
min
(
WTild
a
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
math
::
min
(
WTild
e
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTild
a
Slice
=
IHTild
a
SliceEnd
-
IHTild
a
SliceBegin
;
const
auto
HTild
e
Slice
=
IHTild
e
SliceEnd
-
IHTild
e
SliceBegin
;
const
auto
WTild
a
Slice
=
IWTild
a
SliceEnd
-
IWTild
a
SliceBegin
;
const
auto
WTild
e
Slice
=
IWTild
e
SliceEnd
-
IWTild
e
SliceBegin
;
// GemmK is different for each GEMM
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
IYTild
a
,
YTild
a
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
IYTild
e
,
YTild
e
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
IXTild
a
,
XTild
a
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
IXTild
e
,
XTild
e
);
const
auto
K1
=
GemmK1
;
const
auto
K1
=
GemmK1
;
const
auto
K0
=
K
/
K1
;
const
auto
K0
=
K
/
K1
;
// weight tensor
// weight tensor
const
auto
wei_k_ydot_ytild
a
_xdot_xtild
a
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_k_ydot_ytild
e
_xdot_xtild
e
_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTild
a
),
make_embed_transform
(
make_tuple
(
YDot
,
YTild
e
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTild
a
),
make_embed_transform
(
make_tuple
(
XDot
,
XTild
e
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytild
a
_xdot_xtild
a
_c_grid_desc
,
transform_tensor_descriptor
(
wei_k_ydot_ytild
e
_xdot_xtild
e
_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
IYTild
a
),
make_freeze_transform
(
IYTild
e
),
make_freeze_transform
(
IXTild
a
),
make_freeze_transform
(
IXTild
e
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -163,25 +163,25 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htild
a
_xdot_wtild
a
_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_ydot_htild
e
_xdot_wtild
e
_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTild
a
),
make_embed_transform
(
make_tuple
(
YDot
,
HTild
e
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTild
a
),
make_embed_transform
(
make_tuple
(
XDot
,
WTild
e
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
=
const
auto
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htild
a
_xdot_wtild
a
_k_grid_desc
,
out_n_ydot_htild
e
_xdot_wtild
e
_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTild
a
,
IHTild
a
SliceBegin
,
HTild
a
Slice
),
make_slice_transform
(
HTild
e
,
IHTild
e
SliceBegin
,
HTild
e
Slice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTild
a
,
IWTild
a
SliceBegin
,
WTild
a
Slice
),
make_slice_transform
(
WTild
e
,
IWTild
e
SliceBegin
,
WTild
e
Slice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -198,17 +198,17 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#if 1
#if 1
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
,
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
#else
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
,
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
5
,
1
,
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
5
,
1
,
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
...
@@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -224,24 +224,24 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytild
a
_htild
a
_xtild
a
_wtild
a
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_ytild
e
_htild
e
_xtild
e
_wtild
e
_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTild
a
,
HTild
a
),
make_embed_transform
(
make_tuple
(
YTild
e
,
HTild
e
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTild
a
,
WTild
a
),
make_embed_transform
(
make_tuple
(
XTild
e
,
WTild
e
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htild
a
slice_wtild
a
slice_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_htild
e
slice_wtild
e
slice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytild
a
_htild
a
_xtild
a
_wtild
a
_c_grid_desc
,
in_n_ytild
e
_htild
e
_xtild
e
_wtild
e
_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
IYTild
a
),
make_freeze_transform
(
IYTild
e
),
make_slice_transform
(
HTild
a
,
IHTild
a
SliceBegin
,
HTild
a
Slice
),
make_slice_transform
(
HTild
e
,
IHTild
e
SliceBegin
,
HTild
e
Slice
),
make_freeze_transform
(
IXTild
a
),
make_freeze_transform
(
IXTild
e
),
make_slice_transform
(
WTild
a
,
IWTild
a
SliceBegin
,
WTild
a
Slice
),
make_slice_transform
(
WTild
e
,
IWTild
e
SliceBegin
,
WTild
e
Slice
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
...
@@ -257,9 +257,9 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
Sequence
<
3
>
{}));
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htild
a
slice_wtild
a
slice_c_grid_desc
,
in_n_htild
e
slice_wtild
e
slice_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
C
),
make_tuple
(
make_pass_through_transform
(
C
),
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
))),
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
))),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp
View file @
cc8df39e
...
@@ -10,8 +10,8 @@ namespace ck {
...
@@ -10,8 +10,8 @@ namespace ck {
// A: out
// A: out
// B: wei
// B: wei
// C: in
// C: in
// Number of GEMMs = YTild
a
* XTild
a
// Number of GEMMs = YTild
e
* XTild
e
// GemmM = N * HTild
a
Slice * WTild
a
Slice
// GemmM = N * HTild
e
Slice * WTild
e
Slice
// GemmN = C
// GemmN = C
// GemmK = K * YDotSlice * XDotSlice
// GemmK = K * YDotSlice * XDotSlice
template
<
typename
...
Wei
,
template
<
typename
...
Wei
,
...
@@ -21,8 +21,8 @@ template <typename... Wei,
...
@@ -21,8 +21,8 @@ template <typename... Wei,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
,
typename
InRightPads
,
typename
IYTild
a
,
typename
IYTild
e
,
typename
IXTild
a
,
typename
IXTild
e
,
index_t
GemmK1Value
>
index_t
GemmK1Value
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
...
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...
@@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
IYTild
a
i_ytild
a
,
IYTild
e
i_ytild
e
,
IXTild
a
i_xtild
a
,
IXTild
e
i_xtild
e
,
Number
<
GemmK1Value
>
)
Number
<
GemmK1Value
>
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...
@@ -72,32 +72,32 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTild
a
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
YTild
e
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTild
a
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XTild
e
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTild
a
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTild
e
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTild
a
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTild
e
);
const
auto
HTild
a
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
HTild
e
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTild
a
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
const
auto
WTild
e
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTild
a
and WTild
a
that contribute to non-padding area of input tensor
// only work on HTild
e
and WTild
e
that contribute to non-padding area of input tensor
const
auto
IHTild
a
SliceBegin
=
math
::
integer_divide_floor
(
const
auto
IHTild
e
SliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTild
a
-
I1
)),
ConvStrideH
);
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTild
e
-
I1
)),
ConvStrideH
);
const
auto
IWTild
a
SliceBegin
=
math
::
integer_divide_floor
(
const
auto
IWTild
e
SliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTild
a
-
I1
)),
ConvStrideW
);
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTild
e
-
I1
)),
ConvStrideW
);
const
auto
IHTild
a
SliceEnd
=
const
auto
IHTild
e
SliceEnd
=
math
::
min
(
HTild
a
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
math
::
min
(
HTild
e
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTild
a
SliceEnd
=
const
auto
IWTild
e
SliceEnd
=
math
::
min
(
WTild
a
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
math
::
min
(
WTild
e
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTild
a
Slice
=
IHTild
a
SliceEnd
-
IHTild
a
SliceBegin
;
const
auto
HTild
e
Slice
=
IHTild
e
SliceEnd
-
IHTild
e
SliceBegin
;
const
auto
WTild
a
Slice
=
IWTild
a
SliceEnd
-
IWTild
a
SliceBegin
;
const
auto
WTild
e
Slice
=
IWTild
e
SliceEnd
-
IWTild
e
SliceBegin
;
// GemmK is different for each GEMM
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytild
a
,
YTild
a
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytild
e
,
YTild
e
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtild
a
,
XTild
a
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtild
e
,
XTild
e
);
const
auto
K1
=
GemmK1
;
const
auto
K1
=
GemmK1
;
const
auto
K0
=
K
/
K1
;
const
auto
K0
=
K
/
K1
;
...
@@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...
@@ -113,25 +113,25 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htild
a
_xdot_wtild
a
_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_ydot_htild
e
_xdot_wtild
e
_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTild
a
),
make_embed_transform
(
make_tuple
(
YDot
,
HTild
e
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTild
a
),
make_embed_transform
(
make_tuple
(
XDot
,
WTild
e
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
=
const
auto
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htild
a
_xdot_wtild
a
_k_grid_desc
,
out_n_ydot_htild
e
_xdot_wtild
e
_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTild
a
,
IHTild
a
SliceBegin
,
HTild
a
Slice
),
make_slice_transform
(
HTild
e
,
IHTild
e
SliceBegin
,
HTild
e
Slice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTild
a
,
IWTild
a
SliceBegin
,
WTild
a
Slice
),
make_slice_transform
(
WTild
e
,
IWTild
e
SliceBegin
,
WTild
e
Slice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...
@@ -148,41 +148,41 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#if 1
#if 1
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
,
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#else
#else
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
,
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
YDotSlice
,
XDotSlice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
5
,
1
,
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
5
,
1
,
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
#endif
#endif
// B: weight tensor
// B: weight tensor
const
auto
wei_k_ydot_ytild
a
_xdot_xtild
a
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_k_ydot_ytild
e
_xdot_xtild
e
_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTild
a
),
make_embed_transform
(
make_tuple
(
YDot
,
YTild
e
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTild
a
),
make_embed_transform
(
make_tuple
(
XDot
,
XTild
e
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytild
a
_xdot_xtild
a
_c_grid_desc
,
transform_tensor_descriptor
(
wei_k_ydot_ytild
e
_xdot_xtild
e
_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytild
a
),
make_freeze_transform
(
i_ytild
e
),
make_freeze_transform
(
i_xtild
a
),
make_freeze_transform
(
i_xtild
e
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...
@@ -225,24 +225,24 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytild
a
_htild
a
_xtild
a
_wtild
a
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_ytild
e
_htild
e
_xtild
e
_wtild
e
_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTild
a
,
HTild
a
),
make_embed_transform
(
make_tuple
(
YTild
e
,
HTild
e
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTild
a
,
WTild
a
),
make_embed_transform
(
make_tuple
(
XTild
e
,
WTild
e
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htild
a
slice_wtild
a
slice_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_htild
e
slice_wtild
e
slice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytild
a
_htild
a
_xtild
a
_wtild
a
_c_grid_desc
,
in_n_ytild
e
_htild
e
_xtild
e
_wtild
e
_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ytild
a
),
make_freeze_transform
(
i_ytild
e
),
make_slice_transform
(
HTild
a
,
IHTild
a
SliceBegin
,
HTild
a
Slice
),
make_slice_transform
(
HTild
e
,
IHTild
e
SliceBegin
,
HTild
e
Slice
),
make_freeze_transform
(
i_xtild
a
),
make_freeze_transform
(
i_xtild
e
),
make_slice_transform
(
WTild
a
,
IWTild
a
SliceBegin
,
WTild
a
Slice
),
make_slice_transform
(
WTild
e
,
IWTild
e
SliceBegin
,
WTild
e
Slice
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...
@@ -258,8 +258,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
Sequence
<
3
>
{}));
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htild
a
slice_wtild
a
slice_c_grid_desc
,
in_n_htild
e
slice_wtild
e
slice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
...
include/ck/tensor_operation/gpu/device/conv_utils.hpp
View file @
cc8df39e
...
@@ -108,6 +108,28 @@ struct ConvParams
...
@@ -108,6 +108,28 @@ struct ConvParams
input_right_pads
(
2
,
1
)
input_right_pads
(
2
,
1
)
{
{
}
}
ConvParams
(
ck
::
index_t
n_dim_spatial
,
ck
::
index_t
n
,
ck
::
index_t
k
,
ck
::
index_t
c
,
std
::
vector
<
ck
::
index_t
>
filter_lengths
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_strides
,
std
::
vector
<
ck
::
index_t
>
conv_dilations
,
std
::
vector
<
ck
::
index_t
>
left_pads
,
std
::
vector
<
ck
::
index_t
>
right_pads
)
:
num_dim_spatial
(
n_dim_spatial
),
N
(
n
),
K
(
k
),
C
(
c
),
filter_spatial_lengths
(
filter_lengths
),
input_spatial_lengths
(
input_lengths
),
conv_filter_strides
(
conv_strides
),
conv_filter_dilations
(
conv_dilations
),
input_left_pads
(
left_pads
),
input_right_pads
(
right_pads
)
{
}
ck
::
index_t
num_dim_spatial
;
ck
::
index_t
num_dim_spatial
;
ck
::
index_t
N
;
ck
::
index_t
N
;
...
@@ -206,7 +228,7 @@ HostTensorDescriptor GetHostTensorDescriptor(const std::vector<std::size_t>& dim
...
@@ -206,7 +228,7 @@ HostTensorDescriptor GetHostTensorDescriptor(const std::vector<std::size_t>& dim
return
HostTensorDescriptor
(
return
HostTensorDescriptor
(
dims
,
dims
,
std
::
vector
<
std
::
size_t
>
{
std
::
vector
<
std
::
size_t
>
{
C
*
dims
[
2
]
*
dims
[
3
]
*
dims
[
4
],
1
,
C
*
dims
[
3
]
*
dims
[
4
]
,
C
*
dims
[
4
],
C
});
C
*
dims
[
2
]
*
dims
[
3
]
*
dims
[
4
],
1
,
dims
[
3
]
*
dims
[
4
]
*
C
,
dims
[
4
]
*
C
,
C
});
}
}
std
::
stringstream
err_msg
;
std
::
stringstream
err_msg
;
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
cc8df39e
...
@@ -95,8 +95,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -95,8 +95,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
index_t
i_ytild
a
,
index_t
i_ytild
e
,
index_t
i_xtild
a
)
index_t
i_xtild
e
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -177,34 +177,34 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -177,34 +177,34 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTild
a
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
YTild
e
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTild
a
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XTild
e
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTild
a
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTild
e
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTild
a
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTild
e
);
const
auto
HTild
a
=
const
auto
HTild
e
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTild
a
=
const
auto
WTild
e
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTild
a
and WTild
a
that contribute to non-padding area of input tensor
// only work on HTild
e
and WTild
e
that contribute to non-padding area of input tensor
const
auto
IHTild
a
SliceBegin
=
math
::
integer_divide_floor
(
const
auto
IHTild
e
SliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTild
a
-
I1
)),
ConvStrideH
);
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTild
e
-
I1
)),
ConvStrideH
);
const
auto
IWTild
a
SliceBegin
=
math
::
integer_divide_floor
(
const
auto
IWTild
e
SliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTild
a
-
I1
)),
ConvStrideW
);
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTild
e
-
I1
)),
ConvStrideW
);
const
auto
IHTild
a
SliceEnd
=
math
::
min
(
const
auto
IHTild
e
SliceEnd
=
math
::
min
(
HTild
a
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
HTild
e
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
const
auto
IWTild
a
SliceEnd
=
math
::
min
(
const
auto
IWTild
e
SliceEnd
=
math
::
min
(
WTild
a
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
WTild
e
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
const
auto
HTild
a
Slice
=
IHTild
a
SliceEnd
-
IHTild
a
SliceBegin
;
const
auto
HTild
e
Slice
=
IHTild
e
SliceEnd
-
IHTild
e
SliceBegin
;
const
auto
WTild
a
Slice
=
IWTild
a
SliceEnd
-
IWTild
a
SliceBegin
;
const
auto
WTild
e
Slice
=
IWTild
e
SliceEnd
-
IWTild
e
SliceBegin
;
// GemmK is different for each GEMM
// GemmK is different for each GEMM
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytild
a
,
YTild
a
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytild
e
,
YTild
e
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtild
a
,
XTild
a
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtild
e
,
XTild
e
);
// A: output tensor
// A: output tensor
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
...
@@ -216,26 +216,26 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -216,26 +216,26 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htild
a
_xdot_wtild
a
_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_ydot_htild
e
_xdot_wtild
e
_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YDot
,
HTild
a
),
make_embed_transform
(
make_tuple
(
YDot
,
HTild
e
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTild
a
),
make_embed_transform
(
make_tuple
(
XDot
,
WTild
e
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
=
const
auto
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htild
a
_xdot_wtild
a
_k_grid_desc
,
out_n_ydot_htild
e
_xdot_wtild
e
_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
HTild
a
,
IHTild
a
SliceBegin
,
HTild
a
Slice
),
make_slice_transform
(
HTild
e
,
IHTild
e
SliceBegin
,
HTild
e
Slice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTild
a
,
IWTild
a
SliceBegin
,
WTild
a
Slice
),
make_slice_transform
(
WTild
e
,
IWTild
e
SliceBegin
,
WTild
e
Slice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -251,32 +251,32 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -251,32 +251,32 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
Sequence
<
5
,
6
>
{}));
Sequence
<
5
,
6
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htild
a
slice_xdotslice_wtild
a
slice_k0_k1_grid_desc
,
out_n_ydotslice_htild
e
slice_xdotslice_wtild
e
slice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
// B weight tensor
// B weight tensor
const
auto
wei_k_ydot_ytild
a
_xdot_xtild
a
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_k_ydot_ytild
e
_xdot_xtild
e
_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_y_x_c_grid_desc
,
wei_k_y_x_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
),
make_embed_transform
(
make_tuple
(
YDot
,
YTild
a
),
make_embed_transform
(
make_tuple
(
YDot
,
YTild
e
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTild
a
),
make_embed_transform
(
make_tuple
(
XDot
,
XTild
e
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytild
a
_xdot_xtild
a
_c_grid_desc
,
transform_tensor_descriptor
(
wei_k_ydot_ytild
e
_xdot_xtild
e
_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytild
a
),
make_freeze_transform
(
i_ytild
e
),
make_freeze_transform
(
i_xtild
a
),
make_freeze_transform
(
i_xtild
e
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -309,24 +309,24 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -309,24 +309,24 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytild
a
_htild
a
_xtild
a
_wtild
a
_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_ytild
e
_htild
e
_xtild
e
_wtild
e
_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
YTild
a
,
HTild
a
),
make_embed_transform
(
make_tuple
(
YTild
e
,
HTild
e
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTild
a
,
WTild
a
),
make_embed_transform
(
make_tuple
(
XTild
e
,
WTild
e
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htild
a
slice_wtild
a
slice_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_htild
e
slice_wtild
e
slice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytild
a
_htild
a
_xtild
a
_wtild
a
_c_grid_desc
,
in_n_ytild
e
_htild
e
_xtild
e
_wtild
e
_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ytild
a
),
make_freeze_transform
(
i_ytild
e
),
make_slice_transform
(
HTild
a
,
IHTild
a
SliceBegin
,
HTild
a
Slice
),
make_slice_transform
(
HTild
e
,
IHTild
e
SliceBegin
,
HTild
e
Slice
),
make_freeze_transform
(
i_xtild
a
),
make_freeze_transform
(
i_xtild
e
),
make_slice_transform
(
WTild
a
,
IWTild
a
SliceBegin
,
WTild
a
Slice
),
make_slice_transform
(
WTild
e
,
IWTild
e
SliceBegin
,
WTild
e
Slice
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
...
@@ -342,8 +342,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -342,8 +342,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
Sequence
<
3
>
{}));
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_n_htild
a
slice_wtild
a
slice_c_grid_desc
,
in_n_htild
e
slice_wtild
e
slice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTild
a
Slice
,
WTild
a
Slice
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTild
e
Slice
,
WTild
e
Slice
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -452,18 +452,18 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -452,18 +452,18 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTild
a
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
YTild
e
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTild
a
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XTild
e
=
ConvStrideW
/
GcdStrideDilationW
;
for
(
index_t
i_ytild
a
=
0
;
i_ytild
a
<
YTild
a
;
++
i_ytild
a
)
for
(
index_t
i_ytild
e
=
0
;
i_ytild
e
<
YTild
e
;
++
i_ytild
e
)
{
{
for
(
index_t
i_xtild
a
=
0
;
i_xtild
a
<
XTild
a
;
++
i_xtild
a
)
for
(
index_t
i_xtild
e
=
0
;
i_xtild
e
<
XTild
e
;
++
i_xtild
e
)
{
{
// check slice is valid
// check slice is valid
const
index_t
Y
=
filter_spatial_lengths_
[
0
];
const
index_t
Y
=
filter_spatial_lengths_
[
0
];
const
index_t
X
=
filter_spatial_lengths_
[
1
];
const
index_t
X
=
filter_spatial_lengths_
[
1
];
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytild
a
,
YTild
a
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytild
e
,
YTild
e
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtild
a
,
XTild
a
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtild
e
,
XTild
e
);
if
(
YDotSlice
*
XDotSlice
<=
0
)
if
(
YDotSlice
*
XDotSlice
<=
0
)
{
{
continue
;
continue
;
...
@@ -480,8 +480,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -480,8 +480,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
i_ytild
a
,
i_ytild
e
,
i_xtild
a
);
i_xtild
e
);
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
...
@@ -533,7 +533,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
...
@@ -533,7 +533,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
nrepeat
=
1
;
float
ave_time
=
0
;
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
{
...
...
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
0 → 100644
View file @
cc8df39e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
cc8df39e
...
@@ -100,7 +100,6 @@ struct NDHWK : public BaseTensorLayout
...
@@ -100,7 +100,6 @@ struct NDHWK : public BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"NDHWK"
;
static
constexpr
const
char
*
name
=
"NDHWK"
;
};
};
struct
NCDHW
:
public
BaseTensorLayout
struct
NCDHW
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"NCDHW"
;
static
constexpr
const
char
*
name
=
"NCDHW"
;
...
...
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
View file @
cc8df39e
...
@@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTild
a
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
YTild
e
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTild
a
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
XTild
e
=
ConvStrideW
/
GcdStrideDilationW
;
float
ave_time
=
0
;
float
ave_time
=
0
;
for
(
index_t
i_ytild
a
=
0
;
i_ytild
a
<
YTild
a
;
++
i_ytild
a
)
for
(
index_t
i_ytild
e
=
0
;
i_ytild
e
<
YTild
e
;
++
i_ytild
e
)
{
{
for
(
index_t
i_xtild
a
=
0
;
i_xtild
a
<
XTild
a
;
++
i_xtild
a
)
for
(
index_t
i_xtild
e
=
0
;
i_xtild
e
<
XTild
e
;
++
i_xtild
e
)
{
{
const
auto
descs
=
const
auto
descs
=
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk
(
...
@@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
...
@@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
conv_dilations
,
conv_dilations
,
in_left_pads
,
in_left_pads
,
in_right_pads
,
in_right_pads
,
i_ytild
a
,
i_ytild
e
,
i_xtild
a
,
i_xtild
e
,
Number
<
GemmK1
>
{});
Number
<
GemmK1
>
{});
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
cc8df39e
...
@@ -14,17 +14,20 @@ namespace host {
...
@@ -14,17 +14,20 @@ namespace host {
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
,
ck
::
index_t
NumDimSpatial
=
2
,
typename
std
::
enable_if
<
NumDimSpatial
>
=
1
&&
NumDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvBwdData
:
public
device
::
BaseOperator
struct
ReferenceConvBwdData
:
public
device
::
BaseOperator
{
{
// Argument
// Argument
struct
Argument
:
public
device
::
BaseArgument
struct
Argument
:
public
device
::
BaseArgument
{
{
Argument
(
Tensor
<
InDataType
>&
in
_n_c_hi_wi
,
Argument
(
Tensor
<
InDataType
>&
in
put
,
const
Tensor
<
WeiDataType
>&
wei
_k_c_y_x
,
const
Tensor
<
WeiDataType
>&
wei
ght
,
const
Tensor
<
OutDataType
>&
out
_n_k_ho_wo
,
const
Tensor
<
OutDataType
>&
out
put
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
...
@@ -32,9 +35,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -32,9 +35,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
:
in
_n_c_hi_wi_
{
in_n_c_hi_wi
},
:
in
put_
{
input
},
wei
_k_c_y_x_
{
wei_k_c_y_x
},
wei
ght_
{
weight
},
out
_n_k_ho_wo_
{
out_n_k_ho_wo
},
out
put_
{
output
},
conv_strides_
{
conv_filter_strides
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
in_left_pads_
{
input_left_pads
},
...
@@ -45,9 +48,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -45,9 +48,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
{
}
}
Tensor
<
InDataType
>&
in
_n_c_hi_wi
_
;
Tensor
<
InDataType
>&
in
put
_
;
const
Tensor
<
WeiDataType
>&
wei
_k_c_y_x
_
;
const
Tensor
<
WeiDataType
>&
wei
ght
_
;
const
Tensor
<
OutDataType
>&
out
_n_k_ho_wo
_
;
const
Tensor
<
OutDataType
>&
out
put
_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
conv_dilations_
;
...
@@ -65,16 +68,65 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -65,16 +68,65 @@ struct ReferenceConvBwdData : public device::BaseOperator
using
Argument
=
ReferenceConvBwdData
::
Argument
;
using
Argument
=
ReferenceConvBwdData
::
Argument
;
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
if
constexpr
(
NumDimSpatial
==
1
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
wi
)
{
std
::
size_t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
AccDataType
v_acc
=
0
;
for
(
int
x
=
0
;
x
<
X
;
++
x
)
{
int
w_tmp
=
wi
+
arg
.
in_left_pads_
[
0
]
-
x
*
arg
.
conv_dilations_
[
0
];
if
(
w_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
{
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
0
];
if
(
wo
>=
0
&&
wo
<
Wo
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
AccDataType
v_out
=
0
;
AccDataType
v_wei
=
0
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
output_
(
n
,
k
,
wo
)));
arg
.
wei_element_op_
(
v_wei
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
weight_
(
k
,
c
,
x
)));
v_acc
+=
v_out
*
v_wei
;
}
}
}
}
float
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
};
make_ParallelTensorFunctor
(
f_nchw
,
arg
.
input_
.
mDesc
.
GetLengths
()[
0
],
arg
.
input_
.
mDesc
.
GetLengths
()[
1
],
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
2
)
{
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
hi
,
auto
wi
)
{
std
::
size_t
K
=
arg
.
wei
_k_c_y_x
_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_t
K
=
arg
.
wei
ght
_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_t
Y
=
arg
.
wei
_k_c_y_x
_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
Y
=
arg
.
wei
ght
_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
X
=
arg
.
wei
_k_c_y_x
_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
X
=
arg
.
wei
ght
_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
Ho
=
arg
.
out
_n_k_ho_wo
_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
Ho
=
arg
.
out
put
_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
Wo
=
arg
.
out
_n_k_ho_wo
_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
out
put
_
.
mDesc
.
GetLengths
()[
3
];
float
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
int
y
=
0
;
y
<
Y
;
++
y
)
for
(
int
y
=
0
;
y
<
Y
;
++
y
)
{
{
...
@@ -86,7 +138,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -86,7 +138,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
{
for
(
int
x
=
0
;
x
<
X
;
++
x
)
for
(
int
x
=
0
;
x
<
X
;
++
x
)
{
{
int
w_tmp
=
wi
+
arg
.
in_left_pads_
[
1
]
-
x
*
arg
.
conv_dilations_
[
1
];
int
w_tmp
=
wi
+
arg
.
in_left_pads_
[
1
]
-
x
*
arg
.
conv_dilations_
[
1
];
if
(
w_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
if
(
w_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
{
{
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
1
];
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
1
];
...
@@ -94,16 +147,93 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -94,16 +147,93 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
float
v_out
=
0
;
AccDataType
v_out
=
0
;
float
v_wei
=
0
;
AccDataType
v_wei
=
0
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
output_
(
n
,
k
,
ho
,
wo
)));
arg
.
wei_element_op_
(
v_wei
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
weight_
(
k
,
c
,
y
,
x
)));
v_acc
+=
v_out
*
v_wei
;
}
}
}
}
}
}
}
AccDataType
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
};
make_ParallelTensorFunctor
(
f_nchw
,
arg
.
input_
.
mDesc
.
GetLengths
()[
0
],
arg
.
input_
.
mDesc
.
GetLengths
()[
1
],
arg
.
input_
.
mDesc
.
GetLengths
()[
2
],
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
3
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
di
,
auto
hi
,
auto
wi
)
{
std
::
size_t
K
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
];
std
::
size_t
Z
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
Y
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
X
=
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
];
std
::
size_t
Do
=
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
std
::
size_t
Ho
=
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
std
::
size_t
Wo
=
arg
.
output_
.
mDesc
.
GetLengths
()[
4
];
AccDataType
v_acc
=
0
;
for
(
int
z
=
0
;
z
<
Z
;
++
z
)
{
int
d_tmp
=
di
+
arg
.
in_left_pads_
[
0
]
-
z
*
arg
.
conv_dilations_
[
0
];
if
(
d_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
{
int
do_
=
d_tmp
/
arg
.
conv_strides_
[
0
];
if
(
do_
>=
0
&&
do_
<
Do
)
{
for
(
int
y
=
0
;
y
<
Y
;
++
y
)
{
int
h_tmp
=
hi
+
arg
.
in_left_pads_
[
1
]
-
y
*
arg
.
conv_dilations_
[
1
];
if
(
h_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
{
int
ho
=
h_tmp
/
arg
.
conv_strides_
[
1
];
if
(
ho
>=
0
&&
ho
<
Ho
)
{
for
(
int
x
=
0
;
x
<
X
;
++
x
)
{
int
w_tmp
=
wi
+
arg
.
in_left_pads_
[
2
]
-
x
*
arg
.
conv_dilations_
[
2
];
if
(
w_tmp
%
arg
.
conv_strides_
[
2
]
==
0
)
{
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
2
];
if
(
wo
>=
0
&&
wo
<
Wo
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
AccDataType
v_out
=
0
;
AccDataType
v_wei
=
0
;
arg
.
out_element_op_
(
arg
.
out_element_op_
(
v_out
,
v_out
,
ck
::
type_convert
<
float
>
(
ck
::
type_convert
<
AccDataType
>
(
arg
.
out_n_k_ho_wo_
(
n
,
k
,
ho
,
wo
)));
arg
.
output_
(
arg
.
wei_element_op_
(
v_wei
,
n
,
k
,
do_
,
ho
,
wo
)));
ck
::
type_convert
<
float
>
(
arg
.
wei_element_op_
(
arg
.
wei_k_c_y_x_
(
k
,
c
,
y
,
x
)));
v_wei
,
ck
::
type_convert
<
AccDataType
>
(
arg
.
weight_
(
k
,
c
,
z
,
y
,
x
)));
v_acc
+=
v_out
*
v_wei
;
v_acc
+=
v_out
*
v_wei
;
}
}
...
@@ -113,21 +243,26 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -113,21 +243,26 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
}
}
}
}
}
}
}
float
v_in
;
AccDataType
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
in
_n_c_hi_wi
_
(
n
,
c
,
hi
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
arg
.
in
put
_
(
n
,
c
,
di
,
hi
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
f_nchw
,
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
0
],
arg
.
input_
.
mDesc
.
GetLengths
()[
0
],
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
1
],
arg
.
input_
.
mDesc
.
GetLengths
()[
1
],
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
],
arg
.
input_
.
mDesc
.
GetLengths
()[
2
],
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])(
arg
.
input_
.
mDesc
.
GetLengths
()[
3
],
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
{
{
...
@@ -143,9 +278,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -143,9 +278,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
Tensor
<
InDataType
>&
in
_n_c_hi_wi
,
static
auto
MakeArgument
(
Tensor
<
InDataType
>&
in
put
,
const
Tensor
<
WeiDataType
>&
wei
_k_c_y_x
,
const
Tensor
<
WeiDataType
>&
wei
ght
,
const
Tensor
<
OutDataType
>&
out
_n_k_ho_wo
,
const
Tensor
<
OutDataType
>&
out
put
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
...
@@ -154,9 +289,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -154,9 +289,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
{
{
return
Argument
{
in
_n_c_hi_wi
,
return
Argument
{
in
put
,
wei
_k_c_y_x
,
wei
ght
,
out
_n_k_ho_wo
,
out
put
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
cc8df39e
...
@@ -37,4 +37,5 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
...
@@ -37,4 +37,5 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory
(
conv2d_fwd_bias_relu_atomic_add
)
add_subdirectory
(
conv2d_fwd_bias_relu_atomic_add
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
reduce
)
add_subdirectory
(
reduce
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
grouped_gemm
)
library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt
0 → 100644
View file @
cc8df39e
# device_convnd_bwd_data_instance
set
(
DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp;
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp;
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp;
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
)
add_library
(
device_convnd_bwd_data_instance SHARED
${
DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
}
)
target_compile_features
(
device_convnd_bwd_data_instance PUBLIC
)
set_target_properties
(
device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_convnd_bwd_data_instance
)
library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp
0 → 100644
View file @
cc8df39e
#include <stdlib.h>
#include "config.hpp"
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_data_instance
{
using
BF16
=
ushort
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvBwdDataDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Default
;
static
constexpr
auto
ConvBwdDataFilter1x1Stride1Pad0
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Filter1x1Stride1Pad0
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
=
std
::
tuple
<
// clang-format off
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
// clang-format on
>
;
using
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances
=
std
::
tuple
<
// clang-format off
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
BF16
,
BF16
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
// clang-format on
>
;
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
std
::
vector
<
DeviceConvBwdDataPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
{});
add_device_operation_instances
(
instances
,
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances
{});
}
}
// namespace device_conv2d_bwd_data_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp
0 → 100644
View file @
cc8df39e
#include <stdlib.h>
#include "config.hpp"
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_data_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvBwdDataDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Default
;
static
constexpr
auto
ConvBwdDataFilter1x1Stride1Pad0
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Filter1x1Stride1Pad0
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
=
std
::
tuple
<
// clang-format off
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
#if 1
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
#endif
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
// clang-format on
>
;
using
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances
=
std
::
tuple
<
// clang-format off
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
64
,
128
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
64
,
64
,
4
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
64
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
64
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
32
,
128
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
64
,
32
,
4
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
7
,
1
>
// clang-format on
>
;
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
std
::
vector
<
DeviceConvBwdDataPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
{});
add_device_operation_instances
(
instances
,
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances
{});
}
}
// namespace device_conv2d_bwd_data_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/convnd_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp
0 → 100644
View file @
cc8df39e
#include <stdlib.h>
#include "config.hpp"
#include "device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_data_instance
{
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
ConvBwdDataDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Default
;
static
constexpr
auto
ConvBwdDataFilter1x1Stride1Pad0
=
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization_t
::
Filter1x1Stride1Pad0
;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
=
std
::
tuple
<
// clang-format off
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataDefault
,
1
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
// clang-format on
>
;
using
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances
=
std
::
tuple
<
// clang-format off
//#############################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Num| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#############################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Dim| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#############################################################################| | | | | Operation| Operation| Operation| Specialization|Spatial| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#############################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
64
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
64
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
128
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
128
,
32
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
32
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
64
,
32
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
true
,
7
,
1
>
,
DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
<
F32
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvBwdDataFilter1x1Stride1Pad0
,
1
,
64
,
32
,
64
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
true
,
7
,
1
>
// clang-format on
>
;
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
std
::
vector
<
DeviceConvBwdDataPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
{});
add_device_operation_instances
(
instances
,
device_conv1d_bwd_data_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances
{});
}
}
// namespace device_conv2d_bwd_data_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment