gaoqiong / composable_kernel / Commits

Commit 8c42225c
authored Dec 03, 2019 by Chao Liu
parent 157491ab

minor bug fix

Showing 2 changed files with 57 additions and 104 deletions (+57 -104).
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp (+12 -15)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp (+45 -89)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp (view file @ 8c42225c)

In this file, the pad template parameters are renamed from LeftPads/RightPads to InputLeftPads/InputRightPads, and a trailing `true` template argument is dropped from each Pad transform.
```diff
@@ -20,8 +20,8 @@ template <index_t GridSize,
           typename OutGlobalDesc,
           typename ConvStrides,
           typename ConvDilations,
-          typename LeftPads,
-          typename RightPads,
+          typename InputLeftPads,
+          typename InputRightPads,
           index_t GemmMPerBlock,
           index_t GemmNPerBlock,
           index_t GemmKPerBlock,
```
```diff
@@ -98,8 +98,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                        PassThrough<C>{},
                        Pad<Sequence<Y, X>,
                            Sequence<0, 0>,
-                           Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
-                           true>{}),
+                           Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
```
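For readers unfamiliar with these transforms: a compile-time Pad carries the lower lengths plus left and right pad sizes as template arguments, and the trailing bool removed above most likely toggles a validity/out-of-bound check (an assumption; this page does not show Pad's definition). A minimal self-contained sketch of that shape, with illustrative names only:

```cpp
// Minimal sketch of a Pad-style transform; NOT CK's actual definition.
// The trailing bool mirrors the `true` argument removed in this commit and is
// assumed to control whether validity checks are skipped.
#include <cstddef>

template <std::size_t... Is>
struct Sequence
{
};

template <typename LowerLengths,
          typename LeftPads,
          typename RightPads,
          bool SkipValidityCheck = false>
struct Pad
{
};

int main()
{
    // old form: flag passed explicitly
    Pad<Sequence<3, 3>, Sequence<0, 0>, Sequence<1, 2>, true> with_flag{};
    // new form: fall back to the default
    Pad<Sequence<3, 3>, Sequence<0, 0>, Sequence<1, 2>> without_flag{};
    (void)with_flag;
    (void)without_flag;
}
```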
```diff
@@ -121,14 +120,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             make_tuple(Sequence<0>{}, Sequence<1>{}));

         // output tensor
         constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
             out_n_k_ho_wo_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<K>{},
-                       Pad<Sequence<Ho, Wo>,
-                           Sequence<0, 0>,
-                           Sequence<right_pad_ho, right_pad_wo>,
-                           true>{}),
+                       Pad<Sequence<Ho, Wo>,
+                           Sequence<0, 0>,
+                           Sequence<right_pad_ho, right_pad_wo>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
```
```diff
@@ -154,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             in_n_c_hi_wi_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<C>{},
-                       Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
+                       Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
```
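The effect of the input-tensor Pad above is plain length arithmetic: each padded spatial length is left pad + original length + right pad. A small worked check with made-up sizes (the dimension names mirror the descriptor's; everything else is illustrative):

```cpp
// Worked example of the Pad length arithmetic: Hip = left + Hi + right,
// Wip = left + Wi + right. All sizes are made up.
#include <cstdio>

int main()
{
    const int Hi = 28, Wi = 28;        // unpadded input lengths
    const int left_h = 1, right_h = 1; // InputLeftPads / InputRightPads, H
    const int left_w = 1, right_w = 1; // InputLeftPads / InputRightPads, W

    const int Hip = left_h + Hi + right_h; // 30
    const int Wip = left_w + Wi + right_w; // 30

    std::printf("padded input: %d x %d\n", Hip, Wip);
}
```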
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp (view file @ 8c42225c)

In this file, the include guard is updated, the tile-size parameters are renamed to GEMM terms (BPerBlock to GemmNPerBlock, KPerBlock to GemmMPerBlock, EPerBlock to GemmKPerBlock, GemmDataPerReadA/B to GemmThreadGemmDataPerReadM/N), the InBlockCopy* parameters are moved after the WeiBlockCopy* group, and the block-work bookkeeping is reorganized around the renamed descriptors.
```diff
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
 #include "blockwise_gemm.hpp"
 #include "gridwise_gemm.hpp"
```
```diff
 namespace ck {

 // B = merge(N, Ho, Wo)
 // GEMM_M = K
 // GEMM_N = N * Ho * Wo
 // GEMM_K = C * Y * X
 template <index_t GridSize,
           index_t BlockSize,
           typename Float,
           typename AccFloat,
           typename InGlobalDesc,
           typename WeiGlobalDesc,
           typename OutGlobalDesc,
```
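The comment block above states the implicit-GEMM view this kernel uses: forward NCHW convolution becomes a GEMM with M = K, N = N * Ho * Wo, and K = C * Y * X. A numeric sketch of that mapping (all sizes made up; stride 1, no padding, no dilation assumed):

```cpp
// Implicit-GEMM dimension mapping from the header comments:
//   GEMM_M = K, GEMM_N = N * Ho * Wo, GEMM_K = C * Y * X.
// Sizes are illustrative only.
#include <cstdio>

int main()
{
    const int N = 2, C = 8, Hi = 6, Wi = 6; // input, NCHW
    const int K = 16, Y = 3, X = 3;         // weights, KCYX

    const int Ho = Hi - Y + 1; // 4, assuming stride 1, no pad, no dilation
    const int Wo = Wi - X + 1; // 4

    const int GemmM = K;           // 16: one row per output channel
    const int GemmN = N * Ho * Wo; // 32: one column per output element
    const int GemmK = C * Y * X;   // 72: reduction over channels and filter taps

    std::printf("GEMM_M=%d GEMM_N=%d GEMM_K=%d\n", GemmM, GemmN, GemmK);
}
```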
```diff
@@ -21,9 +22,9 @@ template <index_t GridSize,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads,
-          index_t BPerBlock,
-          index_t KPerBlock,
-          index_t EPerBlock,
+          index_t GemmNPerBlock,
+          index_t GemmMPerBlock,
+          index_t GemmKPerBlock,
           index_t GemmMPerThreadSubC,
           index_t GemmNPerThreadSubC,
           index_t GemmMLevel0Cluster,
```
```diff
@@ -31,14 +32,8 @@ template <index_t GridSize,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
           index_t GemmKPerThreadLoop,
-          index_t GemmDataPerReadA,
-          index_t GemmDataPerReadB,
-          typename InBlockCopySubLengths_E_B,
-          typename InBlockCopyClusterLengths_E_B,
-          typename InBlockCopyThreadClusterArrangeOrder,
-          typename InBlockCopySrcAccessOrder,
-          typename InBlockCopyDstAccessOrder,
-          index_t InBlockCopyDataPerAccess_B,
+          index_t GemmThreadGemmDataPerReadM,
+          index_t GemmThreadGemmDataPerReadN,
           typename WeiBlockCopySubLengths_E_K,
           typename WeiBlockCopyClusterLengths_E_K,
           typename WeiBlockCopyThreadClusterArrangeOrder,
```
```diff
@@ -46,6 +41,12 @@ template <index_t GridSize,
           typename WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K,
+          typename InBlockCopySubLengths_E_B,
+          typename InBlockCopyClusterLengths_E_B,
+          typename InBlockCopyThreadClusterArrangeOrder,
+          typename InBlockCopySrcAccessOrder,
+          typename InBlockCopyDstAccessOrder,
+          index_t InBlockCopyDataPerAccess_B,
           index_t OutThreadCopyDataPerAccess_B>
 struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
 {
```
```diff
@@ -58,8 +59,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

         constexpr auto True = integral_constant<bool, true>{};

         constexpr auto in_n_c_hi_wi_global_desc =
             make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());

         constexpr auto wei_k_c_y_x_global_desc =
```
```diff
@@ -94,23 +93,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                       "wrong! aligment requirement for vectorized global load of input tensor will "
                       "be violated");

-        // divide block work by [K, B]
-        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
-                      "wrong! cannot divide work evenly among block");
-
-        constexpr index_t KBlockWork = K / KPerBlock;
-        constexpr index_t BBlockWork = B / BPerBlock;
-
-        constexpr auto block_work_desc =
-            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
-
-        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
-
-        const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
-        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
-
+        // weight tensor
+        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
+            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
+
         // input tensor
-        // global mem
         constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
             in_n_c_hi_wi_global_desc,
             make_tuple(
```
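The deleted block-work lines above implement a standard 2-D tile assignment: a flat block id is decoded into a (K-tile, B-tile) coordinate and scaled by the tile sizes. A host-side sketch of that decomposition, assuming (the page does not show it) that the cluster descriptor is row-major, with made-up sizes:

```cpp
// Sketch of the block-work decomposition performed by the removed lines,
// under the assumption that block_id = k_work * BBlockWork + b_work.
#include <cassert>
#include <cstdio>

int main()
{
    const int K = 128, KPerBlock = 32; // GEMM_M tiling
    const int B = 256, BPerBlock = 64; // GEMM_N tiling
    assert(K % KPerBlock == 0 && B % BPerBlock == 0);

    const int KBlockWork = K / KPerBlock; // 4
    const int BBlockWork = B / BPerBlock; // 4

    for(int bid = 0; bid < KBlockWork * BBlockWork; ++bid)
    {
        const int k_block_data_on_global = (bid / BBlockWork) * KPerBlock;
        const int b_block_data_on_global = (bid % BBlockWork) * BPerBlock;
        std::printf("block %d -> tile origin (k=%d, b=%d)\n",
                    bid, k_block_data_on_global, b_block_data_on_global);
    }
}
```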
```diff
@@ -127,54 +114,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

-        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
+        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
             in_n_c_y_ho_x_wo_global_desc,
             make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
             make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

         // LDS mem
         // be careful of LDS alignment
         constexpr auto in_e_b_block_desc =
             make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});

         // input blockwise copy
         auto blockwise_in_copy =
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
                                                decltype(in_e_b_block_desc.GetLengths()),
                                                InBlockCopySubLengths_E_B,
                                                InBlockCopyClusterLengths_E_B,
                                                InBlockCopyThreadClusterArrangeOrder,
                                                InBlockCopySrcAccessOrder,
                                                InBlockCopyDstAccessOrder,
                                                1,
                                                1,
                                                InBlockCopyDataPerAccess_B,
                                                InBlockCopyDataPerAccess_B,
                                                AddressSpace::global,
                                                AddressSpace::vgpr,
                                                AddressSpace::lds,
                                                InMemoryDataOperation::none>(
                 {0, b_block_data_on_global}, {0, 0});

         // weight tensor
         // global mem
         constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
             unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

         // LDS
         // be careful of LDS alignment
         constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
-            Sequence<EPerBlock, KPerBlock>{},
-            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
+            Sequence<GemmKPerBlock, GemmMPerBlock>{},
+            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmThreadGemmDataPerReadM)>{});

         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with multiple alignment
         // requirements
-        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
-                      "GemmDataPerReadA alignment requirement is not satisfied");
+        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmThreadGemmDataPerReadM == 0,
+                      "GemmThreadGemmDataPerReadM alignment requirement is not satisfied");

         // weight blockwise copy
         auto blockwise_wei_copy =
```
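The lcm-based alignment above exists so that both the LDS store width (WeiBlockCopyDstDataPerWrite_K) and the GEMM read width (GemmThreadGemmDataPerReadM) divide the row stride of the LDS tile, which is what the static_assert re-checks. A standalone sketch of that arithmetic, using my own helpers since CK's math:: utilities are not shown on this page:

```cpp
// Sketch (not CK's code) of the stride/alignment arithmetic behind
// make_native_tensor_descriptor_aligned: the row stride of the
// [GemmKPerBlock, GemmMPerBlock] LDS tile is GemmMPerBlock rounded up to the
// alignment, so vectorized accesses of either width stay row-aligned.
#include <cstdio>

constexpr int lcm(int a, int b)
{
    int g = a;
    for(int x = b; x != 0;) // Euclidean gcd, then a/gcd*b
    {
        int t = g % x;
        g     = x;
        x     = t;
    }
    return a / g * b;
}

constexpr int integer_least_multiple(int n, int align)
{
    return (n + align - 1) / align * align;
}

int main()
{
    const int GemmMPerBlock = 30; // made-up tile width
    const int WeiBlockCopyDstDataPerWrite_K = 4, GemmThreadGemmDataPerReadM = 4;

    const int align  = lcm(WeiBlockCopyDstDataPerWrite_K, GemmThreadGemmDataPerReadM);
    const int stride = integer_least_multiple(GemmMPerBlock, align); // LDS row stride

    std::printf("align=%d stride=%d stride%%read=%d\n",
                align, stride, stride % GemmThreadGemmDataPerReadM);
}
```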
```diff
@@ -199,24 +155,24 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
-        //   a_mtx[EPerBlock, KPerBlock] is in LDS
-        //   b_mtx[EPerBlocl, BPerBlock] is in LDS
-        //   c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
+        //   a_mtx[GemmKPerBlock, GemmMPerBlock] is in LDS
+        //   b_mtx[EPerBlocl, GemmNPerBlock] is in LDS
+        //   c_mtx[GemmMPerBlock, GemmNPerBlock] is distributed among threads, and saved in
         //   register
         constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
         constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

         // sanity check
-        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
-                          0 &&
-                      BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
-                          0,
+        static_assert(GemmMPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
+                          0 &&
+                      GemmNPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
+                          0,
                       "wrong!");

         constexpr index_t GemmMRepeat =
-            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+            GemmMPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

         constexpr index_t GemmNRepeat =
-            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
+            GemmNPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
```
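A worked example of the repeat arithmetic above, with made-up numbers: each thread covers GemmMRepeat x GemmNRepeat sub-tiles of the block tile, and the level-0/level-1 clusters must tile the block exactly, which is what the static_assert enforces.

```cpp
// Repeat arithmetic from the hunk above; all values are illustrative.
#include <cstdio>

int main()
{
    const int GemmMPerBlock = 128, GemmNPerBlock = 128;
    const int GemmMPerThreadSubC = 4, GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4;
    const int GemmNPerThreadSubC = 4, GemmNLevel0Cluster = 4, GemmNLevel1Cluster = 4;

    const int m_tile = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; // 64
    const int n_tile = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster; // 64

    const int GemmMRepeat = GemmMPerBlock / m_tile; // 2
    const int GemmNRepeat = GemmNPerBlock / n_tile; // 2

    std::printf("GemmMRepeat=%d GemmNRepeat=%d\n", GemmMRepeat, GemmNRepeat);
}
```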
```diff
@@ -235,14 +191,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                                      GemmMLevel1Cluster,
                                      GemmNLevel1Cluster,
                                      GemmKPerThreadLoop,
-                                     GemmDataPerReadA,
-                                     GemmDataPerReadB>{};
+                                     GemmThreadGemmDataPerReadM,
+                                     GemmThreadGemmDataPerReadN>{};

         // LDS allocation for input and weight: be careful of alignment
         constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                 WeiBlockCopyDstDataPerWrite_K,
-                                                GemmDataPerReadA,
-                                                GemmDataPerReadB);
+                                                GemmThreadGemmDataPerReadM,
+                                                GemmThreadGemmDataPerReadN);

         constexpr index_t in_block_space =
             math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
```
```diff
@@ -266,8 +222,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         }

         // LDS double buffer: main body
-        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
-            e_block_data_begin += 2 * EPerBlock)
+        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * GemmKPerBlock < E;
+            e_block_data_begin += 2 * GemmKPerBlock)
         {
 #pragma unroll
             for(index_t iloop = 0; iloop < 2; ++iloop)
```
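A host-side sketch (an assumption for illustration, not CK's device code) of the LDS double-buffering schedule the loop above implements: two buffers alternate between being computed on and being filled for the next step, and the main body advances two GemmKPerBlock slices per trip so each buffer is touched once per trip.

```cpp
// Double-buffer schedule sketch; sizes are made up.
#include <cstdio>

int main()
{
    const int E = 64, GemmKPerBlock = 8;

    int e = 0;
    for(; e + 2 * GemmKPerBlock < E; e += 2 * GemmKPerBlock)
    {
        for(int iloop = 0; iloop < 2; ++iloop) // even: buffer 0, odd: buffer 1
        {
            std::printf("compute on buffer %d, prefetch slice at e=%d\n",
                        iloop, e + (iloop + 1) * GemmKPerBlock);
        }
    }
    // tail: one or two slices remain, handled without further prefetch
    std::printf("tail starts at e=%d (%d slice(s) left)\n", e, (E - e) / GemmKPerBlock);
}
```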
```diff
@@ -287,8 +243,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                 Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
-                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+                blockwise_in_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);
+                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);

                 __syncthreads();
```
```diff
@@ -307,15 +263,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // LDS double buffer: tail
         {
-            constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
+            constexpr bool has_two_iteration_left = (E % (2 * GemmKPerBlock) == 0);

             if(has_two_iteration_left) // if has 2 iteration left
             {
                 Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                 Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
-                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+                blockwise_in_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);
+                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);

                 __syncthreads();
```
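A small check (made-up sizes) of the tail condition above: the main loop stops once fewer than two 2 * GemmKPerBlock strides remain, so the tail has exactly two GemmKPerBlock slices left when E is a multiple of 2 * GemmKPerBlock, and one slice otherwise (E being divisible by GemmKPerBlock at all is assumed, as the earlier divisibility assert suggests).

```cpp
// Tail-condition check for the LDS double buffer; sizes are illustrative.
#include <cstdio>

int main()
{
    const int GemmKPerBlock = 8;
    for(const int E : {16, 24, 32, 40})
    {
        const bool has_two_iteration_left = (E % (2 * GemmKPerBlock) == 0);
        std::printf("E=%d -> tail has %s iteration(s)\n",
                    E, has_two_iteration_left ? "two" : "one");
    }
}
```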