gaoqiong / composable_kernel

Commit ecd3240b, authored Nov 30, 2019 by Chao Liu
Parent: bec35fbc

    refactor

Showing 4 changed files with 488 additions and 402 deletions
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp  +157 -0
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw_lds_double_buffer.hpp  +0 -401
composable_kernel/include/tensor_operation/gridwise_gemm.hpp  +330 -0
driver/src/conv_bwd_data_driver.cpp  +1 -1
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t BPerBlock,
          index_t EPerBlock,
          index_t KPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename OutBlockCopySubLengths_K_B,
          typename OutBlockCopyClusterLengths_K_B,
          index_t OutBlockCopyDataPerAccess_B,
          typename WeiBlockCopySubLengths_K_E,
          typename WeiBlockCopyClusterLengths_K_E,
          index_t WeiBlockCopyDataPerAccess_E,
          index_t InThreadCopyDataPerAccess_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
    __device__ void Run(Float* __restrict__ p_in_global,
                        const Float* __restrict__ p_wei_global,
                        const Float* __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");

        // output tensor
        constexpr auto out_n_k_howo_global_desc =
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);

        constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
            out_n_k_howo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // weight tensor
        constexpr auto wei_k_e_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM: atomic add
        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
                                                       BlockSize,
                                                       Float,
                                                       AccFloat,
                                                       decltype(wei_k_e_global_desc),
                                                       decltype(out_k_b_global_desc),
                                                       decltype(in_e_b_global_desc),
                                                       InMemoryDataOperation::atomic_add,
                                                       EPerBlock,
                                                       BPerBlock,
                                                       KPerBlock,
                                                       GemmMPerThreadSubC,
                                                       GemmNPerThreadSubC,
                                                       GemmMLevel0Cluster,
                                                       GemmNLevel0Cluster,
                                                       GemmMLevel1Cluster,
                                                       GemmNLevel1Cluster,
                                                       GemmKPerThreadLoop,
                                                       GemmDataPerReadA,
                                                       GemmDataPerReadB,
                                                       WeiBlockCopySubLengths_K_E,
                                                       WeiBlockCopyClusterLengths_K_E,
                                                       WeiBlockCopyDataPerAccess_E,
                                                       OutBlockCopySubLengths_K_B,
                                                       OutBlockCopyClusterLengths_K_B,
                                                       OutBlockCopyDataPerAccess_B,
                                                       InThreadCopyDataPerAccess_B>{};

        gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
    }
};

} // namespace ck
#endif
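
Aside (not part of the commit): the kernel above recasts backward data as the GEMM c_mtx += transpose(a_mtx) * b_mtx, with a_mtx the weight viewed as [K, E], b_mtx the output gradient viewed as [K, B], and c_mtx the input gradient viewed as [E, B], where E = C * Y * X and B = N * Ho * Wo; overlapping (y, x) contributions are folded in by the atomic_add store. Below is a minimal host-side sketch of that view, restricted to a 1x1 filter with unit stride and no padding so that E = C, the scatter degenerates to a reshape, and no atomics are needed; all sizes and names are hypothetical.

#include <cstdio>
#include <vector>

// Hypothetical host-side reference for the E x B implicit GEMM used by the
// kernel above, for a 1x1 filter, unit stride, no padding (Y = X = 1, so
// E = C and B = N*Ho*Wo). For illustration only; not part of the commit.
int main()
{
    const int N = 2, C = 3, H = 4, W = 4, K = 5; // toy sizes
    const int E = C, B = N * H * W;              // GEMM dims: M = E, "N" = B

    std::vector<float> wei(K * E, 0.5f);  // a_mtx[K, E]: weight
    std::vector<float> dout(K * B, 1.0f); // b_mtx[K, B]: output gradient
    std::vector<float> din(E * B, 0.0f);  // c_mtx[E, B]: input gradient

    // c_mtx += transpose(a_mtx) * b_mtx, mirroring the gridwise GEMM
    for(int e = 0; e < E; ++e)
        for(int b = 0; b < B; ++b)
            for(int k = 0; k < K; ++k)
                din[e * B + b] += wei[k * E + e] * dout[k * B + b];

    std::printf("din[0] = %f (expect K * 0.5 = %f)\n", din[0], K * 0.5f);
    return 0;
}

With a filter larger than 1x1, several (c, y, x) rows of c_mtx map to the same input element, which is exactly why the real kernel writes its result with InMemoryDataOperation::atomic_add.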
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw_lds_double_buffer.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccDataType,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t BPerBlock,
          index_t EPerBlock,
          index_t KPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename OutBlockCopySubLengths_K_B,
          typename OutBlockCopyClusterLengths_K_B,
          index_t OutBlockCopyDataPerAccess_B,
          typename WeiBlockCopySubLengths_K_E,
          typename WeiBlockCopyClusterLengths_K_E,
          index_t WeiBlockCopyDataPerAccess_E,
          index_t InThreadCopyDataPerAccess_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        const Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");

        // LDS max alignment
        constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDataPerAccess_E,
                                                    OutBlockCopyDataPerAccess_B,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        // divide block work by [K, B]
        static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t EBlockWork = E / EPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;

        // output tensor
        //   global tensor in global memory
        constexpr auto out_n_k_howo_global_desc =
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);

        //   global tensor in global memory, src of blockwise copy
        constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
            out_n_k_howo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        //   block tensor in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto out_k_b_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, BPerBlock>{}, Number<max_lds_align>{});

        //   output tensor blockwise copy
        auto blockwise_out_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(out_k_b_global_desc),
                                               decltype(out_k_b_block_desc),
                                               decltype(out_k_b_block_desc.GetLengths()),
                                               OutBlockCopySubLengths_K_B,
                                               OutBlockCopyClusterLengths_K_B,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               OutBlockCopyDataPerAccess_B,
                                               OutBlockCopyDataPerAccess_B,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, b_block_data_on_global}, {0, 0});

        // weight tensor
        //   global tensor in global memory, src of blockwise copy; its construction would
        //   differ between forward and backward-weight convolution
        constexpr auto wei_k_e_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        //   block tensor in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto wei_k_e_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, EPerBlock>{}, Number<max_lds_align>{});

        //   weight tensor blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_k_e_global_desc),
                                               decltype(wei_k_e_block_desc),
                                               decltype(wei_k_e_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_K_E,
                                               WeiBlockCopyClusterLengths_K_E,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               WeiBlockCopyDataPerAccess_E,
                                               WeiBlockCopyDataPerAccess_E,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, e_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, EPerBlock] is in LDS
        //   b_mtx[KPerBlock, BPerBlock] is in LDS
        //   c_mtx[EPerBlock, BPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_e_block_mtx_desc = make_ConstantMatrixDescriptor(wei_k_e_block_desc);
        constexpr auto b_k_b_block_mtx_desc = make_ConstantMatrixDescriptor(out_k_b_block_desc);

        // sanity check
        static_assert(
            EPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong! cannot distribute GEMM work evenly among threads");

        constexpr index_t GemmMRepeat =
            EPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
        constexpr index_t GemmNRepeat =
            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_e0e1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_e_block_mtx_desc),
            decltype(b_k_b_block_mtx_desc),
            decltype(c_e0e1_b0b1_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for output and weight: be careful of alignment
        constexpr index_t out_block_space =
            math::integer_least_multiple(out_k_b_block_desc.GetElementSpace(), max_lds_align);
        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_k_e_block_desc.GetElementSpace(), max_lds_align);

        __shared__ Float p_out_block_double[2 * out_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for the result (input tensor)
        AccDataType p_in_thread[c_e0e1_b0b1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise result
        threadwise_matrix_set_zero(c_e0e1_b0b1_thread_mtx_desc, p_in_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_out_copy.Run(p_out_global, p_out_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_out_block_now =
                    even_loop ? p_out_block_double : p_out_block_double + out_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_out_block_next =
                    even_loop ? p_out_block_double + out_block_space : p_out_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);

                // LDS double buffer: store next data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // 2 iterations left
            {
                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);

                // LDS double buffer: store last data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
                                                        p_out_block_double + out_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_out_block_double + out_block_space,
                                   p_in_thread);
            }
            else // 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
            }
        }

        // input: register to global memory, atomic add
        {
            constexpr index_t E1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t E0 = E / E1;

            constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t B0 = B / B1;

            // define input tensor descriptors for threadwise copy
            //   thread input tensor, src of threadwise copy
            constexpr auto in_e0_e1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});

            //   global input tensor, dst of threadwise copy
            constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
                in_n_c_hi_wi_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

            constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
                in_n_c_hip_wip_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                           Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

            constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
                in_n_c_y_ho_x_wo_global_desc,
                make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
                make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            constexpr auto in_e0_e1_b0_b1_global_desc = transform_tensor_descriptor(
                in_e_b_global_desc,
                make_tuple(UnMerge<Sequence<E0, E1>>{}, UnMerge<Sequence<B0, B1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread input tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t e_thread_data_on_global =
                e_block_data_on_global + c_thread_mtx_on_block.row;
            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(in_e0_e1_b0_b1_thread_desc),
                decltype(in_e0_e1_b0_b1_global_desc),
                decltype(in_e0_e1_b0_b1_thread_desc.GetLengths()),
                Sequence<0, 1, 2, 3>,
                3,
                InThreadCopyDataPerAccess_B,
                InThreadCopyDataPerAccess_B,
                AddressSpace::vgpr,
                AddressSpace::global,
                InMemoryDataOperation::atomic_add>({0, 0, 0, 0},
                                                   {e_thread_data_on_global / E1,
                                                    e_thread_data_on_global % E1,
                                                    b_thread_data_on_global / B1,
                                                    b_thread_data_on_global % B1})
                .Run(p_in_thread, p_in_global);
        }
    }
};

} // namespace ck
#endif
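
Aside (not part of the commit): the control flow above is a two-deep software pipeline over K-tiles: tile 0 is preloaded, the main loop runs the GEMM on one LDS buffer while filling the other, and the tail drains one or two tiles depending on whether K / KPerBlock is even. Below is a host-only sketch that mirrors the exact loop bounds used above, with prints standing in for the blockwise copies and the blockwise GEMM; K and KPerBlock are hypothetical.

#include <cstdio>

// Host-side sketch of the LDS double-buffer schedule used by the deleted
// kernel above: which buffer is computed on and which is filled, per K-tile.
int main()
{
    const int K = 40, KPerBlock = 8;     // K % KPerBlock == 0, as asserted
    const int num_tiles = K / KPerBlock; // 5 tiles -> odd, so 1-tile tail

    int now      = 0; // buffer holding the tile currently being computed on
    int computed = 0; // tiles consumed by the GEMM so far
    std::printf("preload tile 0 -> buf 0\n");

    // main body: same bounds as for(k = 0; k + 2*KPerBlock < K; k += 2*KPerBlock)
    // with an inner loop of two, alternating the two LDS buffers
    for(int k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
        for(int i = 0; i < 2; ++i)
        {
            const int next = 1 - now;
            std::printf("load tile %d -> buf %d | gemm tile %d on buf %d\n",
                        computed + 1, next, computed, now);
            now = next;
            ++computed;
        }

    // tail: two tiles left iff K % (2 * KPerBlock) == 0
    if(K % (2 * KPerBlock) == 0)
    {
        std::printf("load tile %d -> buf %d | gemm tile %d on buf %d\n",
                    computed + 1, 1 - now, computed, now);
        std::printf("gemm tile %d on buf %d\n", computed + 1, 1 - now);
    }
    else
    {
        std::printf("gemm tile %d on buf %d\n", computed, now);
    }
    return 0;
}

The __syncthreads() placement in the real kernel follows the same schedule: one barrier before overwriting the buffer being loaded, and (in the two-tile tail) one more before the GEMM reads the freshly stored tile.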
composable_kernel/include/tensor_operation/gridwise_gemm.hpp (new file, 0 → 100644)
#ifndef CK_GRIDWISE_GEMM_HPP
#define CK_GRIDWISE_GEMM_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename AGlobalDesc,
          typename BGlobalDesc,
          typename CGlobalDesc,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t ThreadGemmDataPerReadM,
          index_t ThreadGemmDataPerReadN,
          typename ABlockCopySubLengths_K_M,
          typename ABlockCopyClusterLengths_K_M,
          index_t ABlockCopyDataPerAccess_M,
          typename BBlockCopySubLengths_K_N,
          typename BBlockCopyClusterLengths_K_N,
          index_t BBlockCopyDataPerAccess_N,
          index_t CThreadCopyDataPerAccess_N>
struct GridwiseGemmTransposedANormalBNormalC_v1r1
{
    __device__ void Run(const Float* __restrict__ p_a_global,
                        const Float* __restrict__ p_b_global,
                        Float* __restrict__ p_c_global) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto a_k_m_global_desc = AGlobalDesc{};
        constexpr auto b_k_n_global_desc = BGlobalDesc{};
        constexpr auto c_m_n_global_desc = CGlobalDesc{};

        constexpr auto K = a_k_m_global_desc.GetLength(I0);
        constexpr auto M = a_k_m_global_desc.GetLength(I1);
        constexpr auto N = b_k_n_global_desc.GetLength(I1);

        // LDS max alignment
        constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
                                                    BBlockCopyDataPerAccess_N,
                                                    ThreadGemmDataPerReadM,
                                                    ThreadGemmDataPerReadN);

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
        const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;

        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k_m_global_desc),
                                               decltype(a_k_m_block_desc),
                                               decltype(a_k_m_block_desc.GetLengths()),
                                               ABlockCopySubLengths_K_M,
                                               ABlockCopyClusterLengths_K_M,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               ABlockCopyDataPerAccess_M,
                                               ABlockCopyDataPerAccess_M,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, m_block_data_on_global}, {0, 0});

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k_n_global_desc),
                                               decltype(b_k_n_block_desc),
                                               decltype(b_k_n_block_desc.GetLengths()),
                                               BBlockCopySubLengths_K_N,
                                               BBlockCopyClusterLengths_K_N,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               BBlockCopyDataPerAccess_N,
                                               BBlockCopyDataPerAccess_N,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, n_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, MPerBlock] is in LDS
        //   b_mtx[KPerBlock, NPerBlock] is in LDS
        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);

        // sanity check
        static_assert(
            MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
                NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
            "wrong! cannot distribute GEMM work evenly among threads");

        constexpr index_t GemmMRepeat =
            MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t GemmNRepeat =
            NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_k_m_block_mtx_desc),
            decltype(b_k_n_block_mtx_desc),
            decltype(c_m0m1_n0n1_thread_mtx_desc),
            MPerThreadSubC,
            NPerThreadSubC,
            MLevel0Cluster,
            NLevel0Cluster,
            MLevel1Cluster,
            NLevel1Cluster,
            KPerThreadLoop,
            ThreadGemmDataPerReadM,
            ThreadGemmDataPerReadN>{};

        // LDS allocation for A and B: be careful of alignment
        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);

        __shared__ Float p_a_block_double[2 * a_block_space];
        __shared__ Float p_b_block_double[2 * b_block_space];

        // register allocation for output
        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                Float* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                Float* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                Float* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // 2 iterations left
            {
                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double + a_block_space,
                                   p_b_block_double + b_block_space,
                                   p_c_thread);
            }
            else // 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
            }
        }

        // C: register to global memory
        {
            constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
            constexpr index_t M0 = M / M1;

            constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
            constexpr index_t N0 = N / N1;

            // define C tensor descriptors for threadwise copy
            //   thread C tensor, src of threadwise copy
            constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{});

            //   global C tensor, dst of threadwise copy
            constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread C tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t m_thread_data_on_global =
                m_block_data_on_global + c_thread_mtx_on_block.row;
            const index_t n_thread_data_on_global =
                n_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(c_m0_m1_n0_n1_thread_desc),
                decltype(c_m0_m1_n0_n1_global_desc),
                decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
                Sequence<0, 1, 2, 3>,
                3,
                CThreadCopyDataPerAccess_N,
                CThreadCopyDataPerAccess_N,
                AddressSpace::vgpr,
                AddressSpace::global,
                CGlobalMemoryDataOperation>({0, 0, 0, 0},
                                            {m_thread_data_on_global / M1,
                                             m_thread_data_on_global % M1,
                                             n_thread_data_on_global / N1,
                                             n_thread_data_on_global % N1})
                .Run(p_c_thread, p_c_global);
        }
    }
};

} // namespace ck
#endif
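
Aside (not part of the commit): in the C write-out above, the global [M, N] matrix is viewed as [M0, M1, N0, N1] with M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster, so a thread whose C origin is row m owns rows m + r * M1 for r < GemmMRepeat, in runs of MPerThreadSubC; the origin {m / M1, m % M1, n / N1, n % N1} handed to the threadwise copy encodes exactly that. Below is a small host check of the index math, with hypothetical sizes.

#include <cassert>
#include <cstdio>

// Host-side check of the UnMerge index math behind the C write-out above.
// All sizes are hypothetical.
int main()
{
    const int MPerThreadSubC = 4, MLevel0Cluster = 2, MLevel1Cluster = 4;
    const int MPerBlock = 64;
    const int M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; // 32
    const int GemmMRepeat = MPerBlock / M1;                          // 2

    // hypothetical thread origin in M; valid origins from
    // GetBeginOfThreadMatrixC have (m % M1) a multiple of MPerThreadSubC
    const int m_thread_data_on_global = 100;

    // 4-D origin used by ThreadwiseGenericTensorSliceCopy_v4r2 above
    const int m0 = m_thread_data_on_global / M1;
    const int m1 = m_thread_data_on_global % M1;

    // walking the thread slice [GemmMRepeat, MPerThreadSubC] recovers the
    // strided global rows m + r * M1 + s
    for(int r = 0; r < GemmMRepeat; ++r)
        for(int s = 0; s < MPerThreadSubC; ++s)
        {
            const int m_global = (m0 + r) * M1 + (m1 + s);
            assert(m_global == m_thread_data_on_global + r * M1 + s);
            std::printf("repeat %d, sub-row %d -> global row %d\n", r, s, m_global);
        }
    return 0;
}

The same decomposition applies to N via N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster.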
driver/src/conv_bwd_data_driver.cpp
...
@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 1x1 filter, 8x8 image
     // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
     constexpr index_t N = 64;
...