Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b7992190
Commit
b7992190
authored
Dec 02, 2019
by
Chao Liu
Browse files
adding bwd data v2r1
parent
cfff66cd
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
445 additions
and
57 deletions
+445
-57
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+8
-8
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
...a_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
+1
-1
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+201
-0
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+40
-41
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
+1
-1
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+182
-0
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+12
-6
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
View file @
b7992190
...
...
@@ -19,8 +19,8 @@ template <index_t GridSize,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
,
index_t
BPerBlock
,
index_t
EPerBlock
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
...
...
@@ -29,14 +29,14 @@ template <index_t GridSize,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
typename
OutBlockCopySubLengths_K_B
,
typename
OutBlockCopyClusterLengths_K_B
,
index_t
OutBlockCopyDataPerAccess_B
,
index_t
GemmThreadGemmDataPerReadM
,
index_t
GemmThreadGemmDataPerReadN
,
typename
WeiBlockCopySubLengths_K_E
,
typename
WeiBlockCopyClusterLengths_K_E
,
index_t
WeiBlockCopyDataPerAccess_E
,
typename
OutBlockCopySubLengths_K_B
,
typename
OutBlockCopyClusterLengths_K_B
,
index_t
OutBlockCopyDataPerAccess_B
,
index_t
InThreadCopyDataPerAccess_B
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
...
...
@@ -139,8 +139,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerRead
A
,
GemmDataPerRead
B
,
GemmThread
GemmDataPerRead
M
,
GemmThread
GemmDataPerRead
N
,
WeiBlockCopySubLengths_K_E
,
WeiBlockCopyClusterLengths_K_E
,
WeiBlockCopyDataPerAccess_E
,
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
b7992190
...
...
@@ -22,8 +22,8 @@ template <index_t GridSize,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
,
index_t
BPerBlock
,
index_t
EPerBlock
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
b7992190
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmThreadGemmDataPerReadM
,
index_t
GemmThreadGemmDataPerReadN
,
typename
GemmABlockCopySubLengths
,
// Gemm-K, Gemm-M
typename
GemmABlockCopyClusterLengths
,
// Gemm-K, Gemm-M
index_t
GemmABlockCopyDataPerAccess
,
// Gemm-M
typename
GemmBBlockCopySubLengths
,
// Gemm-K, Gemm-N
typename
GemmBBlockCopyClusterLengths
,
// Gemm-K, Gemm-N
index_t
GemmBBlockCopyDataPerAccess
,
// Gemm-N
index_t
GemmCThreadCopyDataPerAccess
// Gemm-N
>
struct
GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
__device__
void
Run
(
Float
*
__restrict__
p_in_global
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
// sanity-check for vectorized memory load
static_assert
((
Wo
==
1
||
(
ConvStrideW
==
1
||
GemmCThreadCopyDataPerAccess
==
1
))
&&
(
X
==
1
||
ConvDilationW
%
GemmCThreadCopyDataPerAccess
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
// simplicity
static_assert
(
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
ConvDilationW
==
1
,
"wrong! not supported yet"
);
// TODO: these logic are only for stride = 1, dilation = 1
constexpr
index_t
Ydot
=
Y
;
constexpr
index_t
Ytilda
=
1
;
constexpr
index_t
Htilda
=
Ho
+
Y
-
1
;
constexpr
index_t
Xdot
=
X
;
constexpr
index_t
Xtilda
=
1
;
constexpr
index_t
Wtilda
=
Wo
+
X
-
1
;
constexpr
index_t
GemmK
=
K
*
Ydot
*
Xdot
;
constexpr
index_t
GemmM
=
C
*
Ytilda
*
Xtilda
;
constexpr
index_t
GemmN
=
N
*
Htilda
*
Wtilda
;
// weight tensor
constexpr
auto
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
=
transform_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
PassThrough
<
C
>
{},
Embed
<
Sequence
<
Ydot
,
Ytilda
>
,
Sequence
<
1
,
1
,
0
>>
{},
// coefficient may be wrong
Embed
<
Sequence
<
Xdot
,
Xtilda
>
,
Sequence
<
1
,
1
,
0
>>
{}),
// coefficient may be wrong
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
transform_tensor_descriptor
(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydot
,
Xdot
>>
{},
Merge
<
Sequence
<
C
,
Ytilda
,
Xtilda
>>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
constexpr
auto
out_n_k_hop_wop_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Pad
<
Sequence
<
Ho
,
Wo
>
,
Sequence
<
0
,
0
>
,
Sequence
<
Y
-
1
,
X
-
1
>>
{}),
// coefficient may
// be wrong
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
auto
out_n_k_ydot_htilda_xdot_wtilda_global_desc
=
transform_tensor_descriptor
(
out_n_k_hop_wop_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
K
>
{},
Embed
<
Sequence
<
Ydot
,
Htilda
>
,
Sequence
<
0
,
1
,
0
>>
{},
// coefficient may be wrong
Embed
<
Sequence
<
Xdot
,
Wtilda
>
,
Sequence
<
0
,
1
,
0
>>
{}),
// coefficient may be wrong
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
out_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ydot_htilda_xdot_wtilda_global_desc
,
make_tuple
(
Merge
<
Sequence
<
K
,
Ydot
,
Xdot
>>
{},
Merge
<
Sequence
<
N
,
Htilda
,
Wtilda
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// input tensor
constexpr
auto
eff_left_pads
=
LeftPads
{}
+
Sequence
<
Y
-
1
,
X
-
1
>
{};
constexpr
auto
eff_right_pads
=
RightPads
{}
+
Sequence
<
Y
-
1
,
X
-
1
>
{};
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
decltype
(
eff_left_pads
),
decltype
(
eff_right_pads
)
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
auto
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Sequence
<
Ytilda
,
Htilda
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Sequence
<
Xtilda
,
Wtilda
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Ytilda
,
Xtilda
>>
{},
Merge
<
Sequence
<
N
,
Htilda
,
Wtilda
>>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1r1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
out_gemmk_gemmn_global_desc
),
decltype
(
in_gemmm_gemmn_global_desc
),
InMemoryDataOperation
::
none
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopySubLengths
,
GemmABlockCopyClusterLengths
,
GemmABlockCopyDataPerAccess
,
GemmBBlockCopySubLengths
,
GemmBBlockCopyClusterLengths
,
GemmBBlockCopyDataPerAccess
,
GemmCThreadCopyDataPerAccess
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_out_global
,
p_in_global
);
}
};
}
// namespace ck
#endif
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
View file @
b7992190
...
...
@@ -49,10 +49,9 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
EPerBlock
=
128
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
8
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
...
...
@@ -60,27 +59,27 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerRead
A
=
4
;
constexpr
index_t
GemmDataPerRead
B
=
4
;
constexpr
index_t
GemmThread
GemmDataPerRead
M
=
4
;
constexpr
index_t
GemmThread
GemmDataPerRead
N
=
4
;
using
Out
BlockCopySubLengths
_K_B
=
Sequence
<
4
,
1
>
;
using
Out
BlockCopyClusterLengths
_K_B
=
Sequence
<
2
,
128
>
;
using
GemmA
BlockCopySubLengths
=
Sequence
<
1
,
4
>
;
// Gemm-K, Gemm-M
using
GemmA
BlockCopyClusterLengths
=
Sequence
<
8
,
32
>
;
// Gemm-K, Gemm-M
constexpr
index_t
Out
BlockCopyDataPerAccess
_B
=
1
;
constexpr
index_t
GemmA
BlockCopyDataPerAccess
=
4
;
// Gemm-M
using
Wei
BlockCopySubLengths
_K_E
=
Sequence
<
1
,
4
>
;
using
Wei
BlockCopyClusterLengths
_K_E
=
Sequence
<
8
,
32
>
;
using
GemmB
BlockCopySubLengths
=
Sequence
<
4
,
1
>
;
// Gemm-K, Gemm-N
using
GemmB
BlockCopyClusterLengths
=
Sequence
<
2
,
128
>
;
// Gemm-K, Gemm-N
constexpr
index_t
Wei
BlockCopyDataPerAccess
_E
=
4
;
constexpr
index_t
GemmB
BlockCopyDataPerAccess
=
1
;
// Gemm-N
constexpr
index_t
In
ThreadCopyDataPerAccess
_B
=
1
;
constexpr
index_t
GemmC
ThreadCopyDataPerAccess
=
1
;
// Gemm-N
#endif
constexpr
index_t
E
=
C
*
Y
*
X
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
;
constexpr
index_t
GemmM
=
C
*
Y
*
X
;
constexpr
index_t
GemmN
=
N
*
Ho
*
Wo
;
constexpr
index_t
GridSize
=
((
E
+
EPerBlock
-
1
)
/
EPerBlock
)
*
((
B
+
B
PerBlock
-
1
)
/
B
PerBlock
);
constexpr
index_t
GridSize
=
((
GemmM
+
GemmMPerBlock
-
1
)
/
GemmMPerBlock
)
*
((
GemmN
+
GemmN
PerBlock
-
1
)
/
GemmN
PerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
...
...
@@ -96,9 +95,9 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
ConvDilations
,
LeftPads
,
RightPads
,
B
PerBlock
,
E
PerBlock
,
KPerBlock
,
GemmM
PerBlock
,
GemmN
PerBlock
,
Gemm
KPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
...
...
@@ -106,15 +105,15 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerRead
A
,
GemmDataPerRead
B
,
Out
BlockCopySubLengths
_K_B
,
Out
BlockCopyClusterLengths
_K_B
,
Out
BlockCopyDataPerAccess
_B
,
Wei
BlockCopySubLengths
_K_E
,
Wei
BlockCopyClusterLengths
_K_E
,
Wei
BlockCopyDataPerAccess
_E
,
In
ThreadCopyDataPerAccess
_B
>
{};
GemmThread
GemmDataPerRead
M
,
GemmThread
GemmDataPerRead
N
,
GemmA
BlockCopySubLengths
,
GemmA
BlockCopyClusterLengths
,
GemmA
BlockCopyDataPerAccess
,
GemmB
BlockCopySubLengths
,
GemmB
BlockCopyClusterLengths
,
GemmB
BlockCopyDataPerAccess
,
GemmC
ThreadCopyDataPerAccess
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
View file @
b7992190
...
...
@@ -105,8 +105,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
ConvDilations
,
LeftPads
,
RightPads
,
BPerBlock
,
EPerBlock
,
BPerBlock
,
KPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
b7992190
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
typename
OutDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
>
void
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
(
InDesc
in_nchw_desc
,
Tensor
<
T
>&
in_nchw
,
WeiDesc
wei_kcyx_desc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
out_nkhw_desc
,
const
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
std
::
size_t
nrepeat
)
{
using
namespace
ck
;
constexpr
index_t
N
=
out_nkhw_desc
.
GetLengths
()[
0
];
constexpr
index_t
K
=
out_nkhw_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLengths
()[
3
];
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLengths
()[
1
];
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLengths
()[
3
];
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopySubLengths
=
Sequence
<
4
,
1
>
;
// Gemm-K, Gemm-M
using
GemmABlockCopyClusterLengths
=
Sequence
<
2
,
128
>
;
// Gemm-K, Gemm-M
constexpr
index_t
GemmABlockCopyDataPerAccess
=
1
;
// Gemm-M
using
GemmBBlockCopySubLengths
=
Sequence
<
4
,
1
>
;
// Gemm-K, Gemm-N
using
GemmBBlockCopyClusterLengths
=
Sequence
<
2
,
128
>
;
// Gemm-K, Gemm-N
constexpr
index_t
GemmBBlockCopyDataPerAccess
=
1
;
// Gemm-N
constexpr
index_t
GemmCThreadCopyDataPerAccess
=
1
;
// Gemm-N
#elif 0
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopySubLengths
=
Sequence
<
1
,
4
>
;
// Gemm-K, Gemm-M
using
GemmABlockCopyClusterLengths
=
Sequence
<
8
,
32
>
;
// Gemm-K, Gemm-M
constexpr
index_t
GemmABlockCopyDataPerAccess
=
4
;
// Gemm-M
using
GemmBBlockCopySubLengths
=
Sequence
<
4
,
1
>
;
// Gemm-K, Gemm-N
using
GemmBBlockCopyClusterLengths
=
Sequence
<
2
,
128
>
;
// Gemm-K, Gemm-N
constexpr
index_t
GemmBBlockCopyDataPerAccess
=
1
;
// Gemm-N
constexpr
index_t
GemmCThreadCopyDataPerAccess
=
1
;
// Gemm-N
#endif
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
// simplicity
constexpr
index_t
Ydot
=
1
;
constexpr
index_t
Ytilda
=
Y
;
constexpr
index_t
Htilda
=
Ho
+
Y
-
1
;
constexpr
index_t
Xdot
=
1
;
constexpr
index_t
Xtilda
=
X
;
constexpr
index_t
Wtilda
=
Wo
+
X
-
1
;
constexpr
index_t
GemmK
=
K
*
Ydot
*
Xdot
;
constexpr
index_t
GemmM
=
C
*
Ytilda
*
Xtilda
;
constexpr
index_t
GemmN
=
N
*
Htilda
*
Wtilda
;
constexpr
index_t
GridSize
=
((
GemmM
+
GemmMPerBlock
-
1
)
/
GemmMPerBlock
)
*
((
GemmN
+
GemmNPerBlock
-
1
)
/
GemmNPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
<
GridSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopySubLengths
,
GemmABlockCopyClusterLengths
,
GemmABlockCopyDataPerAccess
,
GemmBBlockCopySubLengths
,
GemmBBlockCopyClusterLengths
,
GemmBBlockCopyDataPerAccess
,
GemmCThreadCopyDataPerAccess
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
}
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
driver/src/conv_bwd_data_driver.cpp
View file @
b7992190
...
...
@@ -15,6 +15,7 @@
#include "host_conv_bwd_data.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -34,7 +35,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
#elif
1
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
...
...
@@ -49,7 +50,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 1x1 filter, 8x8 image
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
constexpr
index_t
N
=
64
;
...
...
@@ -337,18 +338,23 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
#if 0
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#elif
0
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
out_nkhw
.
GenerateTensorValue
(
GeneratorTensor_1
{
1
},
num_thread
);
#else
out_nkhw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
out_nkhw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
#endif
}
#if
1
#if
0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#el
se
#el
if
0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#endif
(
in_nchw_desc
,
in_nchw_device
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment