gaoqiong / composable_kernel / Commits / 1e3d69b9

Commit 1e3d69b9, authored Jun 28, 2019 by Chao Liu

small test case for hip compiler

parent f0716f5b
Changes: 32

Showing 20 changed files with 0 additions and 8180 deletions (+0 -8180)
composable_kernel/include/kernel_algorithm/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp   +0 -254
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp   +0 -399
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp   +0 -435
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp   +0 -425
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp   +0 -475
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp   +0 -451
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp   +0 -502
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp   +0 -284
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp   +0 -413
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp   +0 -377
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp   +0 -404
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp   +0 -354
composable_kernel/include/kernel_algorithm/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp   +0 -259
composable_kernel/include/kernel_algorithm/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp   +0 -298
composable_kernel/include/tensor_operation/blockwise_2d_tensor_op.hpp   +0 -806
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp   +0 -378
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp   +0 -779
composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp   +0 -529
composable_kernel/include/tensor_operation/blockwise_tensor_slice_copy.hpp   +0 -298
composable_kernel/include/tensor_operation/threadwise_4d_tensor_op.hpp   +0 -60
composable_kernel/include/kernel_algorithm/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_direct_convolution.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t CPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t InBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead>
struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_nchw_global_desc  = InGlobalDesc{};
        constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
        constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_nchw_global_desc.GetLength(I0);
        constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
        constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
        constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
        constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);

        // 2d view of wei for blockwise copy
        constexpr auto wei_ke_global_desc =
            make_ConstantTensorDescriptor_packed(Sequence<K, C * Y * X>{});

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});

        // 2d view of wei for blockwise copy
        constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<KPerBlock, CPerBlock * Y * X>{}, Number<WeiBlockCopyDataPerRead>{});

        constexpr auto wei_kcyx_block_desc = make_ConstantTensorDescriptor(
            Sequence<KPerBlock, CPerBlock, Y, X>{},
            Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});

        // shared mem
        constexpr index_t in_block_element_size =
            in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
        constexpr index_t wei_block_element_size =
            wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

        constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                          ? InBlockCopyDataPerRead
                                          : WeiBlockCopyDataPerRead;

        __shared__ Float p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
        __shared__ Float p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

        // threadwise tensors
        constexpr index_t HiPerThread = HoPerThread + Y - 1;
        constexpr index_t WiPerThread = WoPerThread + X - 1;

        constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor(
            Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, in_nchw_block_desc.GetStrides());

        constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());

        constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
            in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);

        // register
        Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

        // divide block work
        constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
        constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

        const index_t block_id = blockIdx.x;

        index_t itmp                  = block_id;
        const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
        itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
        const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
        itmp -= k_block_work_id * (HBlockWork * WBlockWork);
        const index_t h_block_work_id = itmp / WBlockWork;
        const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
        const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

        // divide thread work
        constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
        constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
        constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
        constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

        const index_t thread_id = get_thread_local_1d_id();

        itmp                           = thread_id;
        const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
        itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
        const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
        itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
        const index_t h_thread_work_id = itmp / WThreadWork;
        const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

        const index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
        const index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
        const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
        const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

        const index_t hi_thread_data_begin = ho_thread_data_begin;
        const index_t wi_thread_data_begin = wo_thread_data_begin;

        constexpr auto blockwise_in_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_nchw_global_desc),
                                   decltype(in_nchw_block_desc),
                                   decltype(in_nchw_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead>{};

#if 0
        constexpr auto blockwise_wei_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_kcyx_global_desc),
                                   decltype(wei_kcyx_block_desc),
                                   decltype(wei_kcyx_block_desc.GetLengths()),
                                   1>{};
#elif 1
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ke_global_desc),
                                   decltype(wei_ke_block_desc),
                                   decltype(wei_ke_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>({0, 0}, {0, 0});
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock, __syncthreads())
        {
            // copy input tensor to LDS
            blockwise_in_copy.Run(p_in_global +
                                      in_nchw_global_desc.GetOffsetFromMultiIndex(
                                          n_block_data_begin, c_block_data_begin,
                                          hi_block_data_begin, wi_block_data_begin),
                                  p_in_block);

            // copy weight tensor to LDS
            blockwise_wei_copy.Run(p_wei_global +
                                       wei_kcyx_global_desc.GetOffsetFromMultiIndex(
                                           k_block_data_begin, c_block_data_begin, 0, 0),
                                   p_wei_block);

            __syncthreads();

            for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
            {
                // threadwise convolution
#if 1
                threadwise_direct_convolution_2(
                    in_nchw_thread_block_desc,
                    p_in_block + in_nchw_block_desc.GetOffsetFromMultiIndex(
                                     n_thread_data_begin, c_thread_data,
                                     hi_thread_data_begin, wi_thread_data_begin),
                    wei_kcyx_thread_block_desc,
                    p_wei_block + wei_kcyx_block_desc.GetOffsetFromMultiIndex(
                                      k_thread_data_begin, c_thread_data, 0, 0),
                    out_nkhw_thread_desc,
                    p_out_thread);
#elif 0
                threadwise_direct_convolution_3(
                    in_nchw_thread_block_desc,
                    p_in_block + in_nchw_block_desc.GetOffsetFromMultiIndex(
                                     n_thread_data_begin, c_thread_data,
                                     hi_thread_data_begin, wi_thread_data_begin),
                    wei_kcyx_thread_block_desc,
                    p_wei_block + wei_kcyx_block_desc.GetOffsetFromMultiIndex(
                                      k_thread_data_begin, c_thread_data, 0, 0),
                    out_nkhw_thread_desc,
                    p_out_thread);
#endif
            }
        }

        // copy output tensor from register to global mem
        threadwise_tensor_slice_copy(out_nkhw_thread_desc,
                                     p_out_thread,
                                     out_nkhw_global_desc,
                                     p_out_global +
                                         out_nkhw_global_desc.GetOffsetFromMultiIndex(
                                             n_block_data_begin + n_thread_data_begin,
                                             k_block_data_begin + k_thread_data_begin,
                                             ho_block_data_begin + ho_thread_data_begin,
                                             wo_block_data_begin + wo_thread_data_begin),
                                     out_nkhw_thread_desc.GetLengths(),
                                     Number<1>{});
    }
};

} // namespace ck
#endif
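Aside (not part of this commit): a gridwise kernel struct such as the one above does not launch itself; it is normally driven by a thin __global__ wrapper whose launch dimensions must agree with the GridSize and BlockSize template arguments. The sketch below is a hypothetical minimal wrapper for illustration only; the wrapper name and the commented-out launch values are assumptions, not code from this repository.

// Hypothetical sketch (illustration only): a generic __global__ wrapper that instantiates a
// gridwise operation struct and forwards its pointer arguments to Run().
#include "hip/hip_runtime.h"

template <class GridwiseOp, class... Args>
__global__ void run_gridwise_operation_sketch(Args... args)
{
    GridwiseOp{}.Run(args...);
}

// Host side (schematic): the dim3 grid/block passed to hipLaunchKernelGGL must match the
// GridSize/BlockSize template parameters baked into the gridwise convolution type.
//
//   hipLaunchKernelGGL((run_gridwise_operation_sketch<SomeGridwiseConvolution,
//                                                     const float*, const float*, float*>),
//                      dim3(GridSize), dim3(BlockSize), 0, nullptr,
//                      p_in_device, p_wei_device, p_out_device);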
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN, index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K, index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_c_h_w_n_global_desc.GetLength(I0);
        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);
        constexpr index_t Y  = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X  = wei_c_y_x_k_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // flattened (2d) tensor view of gridwise weight
        constexpr auto wei_cyx_k_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

        // tensor view of blockwise input and weight in LDS
        //   be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock * Y * X, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, Y, X, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc =
            make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
        const auto blockwise_in_copy =
#if 0
            Blockwise4dTensorCopy1<BlockSize, Float, decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead_N>{};
#else
            Blockwise4dTensorCopy3<BlockSize, Float, decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN, InBlockCopyDataPerRead_N>{};
#endif

        // blockwise wei copy
        //   format is [CPerBlock*Y*X,KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize, Float, decltype(wei_cyx_k_global_desc),
                                   decltype(wei_cyx_k_block_desc),
                                   decltype(wei_cyx_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_y_x_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
            Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerThread>{}, Number<WoPerThread * NPerThread>{},
            Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster,
                GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA, GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_y_x_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

        const Float* p_in_global_block_offset =
            p_in_global + in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                              0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0),
                    __syncthreads())
        {
#if 1
            blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
            blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset, p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard);

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
#endif

            __syncthreads();

#pragma unroll
            for(index_t y = 0; y < Y; ++y)
            {
#pragma unroll
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_batch_gemm.Run
#else
                    blockwise_batch_gemm.Run_amd_asm
#endif
                        (p_wei_block + wei_c_y_x_k_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_out_thread);
                }
            }
        }

        // output: register to global mem,
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            // f_dummy does nothing but perfect forwarding.
            // Using this trick to make this lambda a generic lambda, so it won't be compiled
            // until instantiated
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
                         N / f_dummy(N1 * N2), N1, N2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                for(index_t i = 0; i < 64; ++i)
                {
                    printf("out %f, ", p_out_thread[i]);
                }
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        });
    }
};

} // namespace ck
#endif
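The comment block in the kernel above describes the implicit-GEMM mapping: for each filter tap (y, x), a CPerBlock x KPerBlock weight tile and a CPerBlock x (WoPerBlock*NPerBlock) input tile in LDS are combined as C_matrix += transpose(A_matrix) * B_matrix, batched over the HoPerBlock rows. The scalar reference below restates just that accumulation for a single tap and a single batch row; the row-major layouts and the function name are assumptions made for illustration, not the library's implementation.

// Scalar reference (illustration only) of the per-tap accumulation C += A^T * B,
// with A stored [C][K], B stored [C][WoN], and C stored [K][WoN], all row-major (assumed).
void implicit_gemm_tap_reference(const float* a, const float* b, float* c, int C, int K, int WoN)
{
    for(int ic = 0; ic < C; ++ic)              // reduction over the input-channel slice
        for(int ik = 0; ik < K; ++ik)          // output channels
            for(int iwn = 0; iwn < WoN; ++iwn) // fused output-width x batch dimension
                c[ik * WoN + iwn] += a[ic * K + ik] * b[ic * WoN + iwn];
}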
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_3d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN, index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K, index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_c_h_w_n_global_desc.GetLength(I0);
        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);
        constexpr index_t Y  = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X  = wei_c_y_x_k_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_x_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});

        // LDS tensor view
        //   be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, X, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc =
            make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
#if 1
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy1<BlockSize, Float, decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead_N>{};
#else
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize, Float, decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN, InBlockCopyDataPerRead_N>{};
#endif

        // blockwise wei copy
        //   format is [CPerBlock, X * KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise3dTensorCopy1<BlockSize, Float, decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
            Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerThread>{}, Number<WoPerThread * NPerThread>{},
            Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster,
                GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA, GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 1
        const Float* p_in_global_block_offset =
            p_in_global + in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                              0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
                blockwise_in_copy.Run(p_in_global_block_offset +
                                          in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
                                      p_in_block);

                blockwise_wei_copy.Run(p_wei_global_block_offset +
                                           wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
                                       p_wei_block);

                __syncthreads();

                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
                        p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
                        p_out_thread);
                }

                __syncthreads();
            }
        }
#else
        // this uses much more register; haven't figured out why
        for(index_t y = 0; y < Y; ++y)
        {
            const Float* p_in_global_block_offset =
                p_in_global + in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                  0, hi_block_data_begin + y, wi_block_data_begin, n_block_data_begin);

            const Float* p_wei_global_block_offset =
                p_wei_global +
                wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, k_block_data_begin);

            for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                c_block_data_begin += CPerBlock,
                        p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                        p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
            {
                blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
                blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

                __syncthreads();

                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
                        p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
                        p_out_thread);
                }

                __syncthreads();
            }
        }
#endif

        // output: register to global mem,
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            // f_dummy does nothing but perfect forwarding.
            // Using this trick to make this lambda a generic lambda, so it won't be compiled
            // until instantiated
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
                         N / f_dummy(N1 * N2), N1, N2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                for(index_t i = 0; i < 64; ++i)
                {
                    printf("out %f, ", p_out_thread[i]);
                }
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        });
    }
};

} // namespace ck
#endif
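Both implicit-GEMM kernels above size their LDS tiles with max_align = math::lcm(...) so that every vectorized access width divides the leading stride of the shared-memory descriptors. The standalone snippet below only works that rule through with made-up widths; the four constants are assumptions, and std::lcm stands in for the repository's math::lcm.

// Worked example (assumed values) of the LDS alignment rule used above.
#include <numeric> // std::lcm (C++17)

constexpr int InBlockCopyDataPerRead_N  = 4; // assumed vector width
constexpr int WeiBlockCopyDataPerRead_K = 2; // assumed vector width
constexpr int GemmDataPerReadA          = 4; // assumed vector width
constexpr int GemmDataPerReadB          = 4; // assumed vector width

constexpr int max_align =
    std::lcm(std::lcm(InBlockCopyDataPerRead_N, WeiBlockCopyDataPerRead_K),
             std::lcm(GemmDataPerReadA, GemmDataPerReadB));

static_assert(max_align == 4, "with these sample widths, LDS rows are padded to 4 elements");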
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock, index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN, index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K, index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_c_h_w_n_global_desc.GetLength(I0);
        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);
        constexpr index_t Y  = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X  = wei_c_y_x_k_global_desc.GetLength(I2);

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor(Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // LDS tensor view
        //   be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not meet");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc =
            make_ConstantTensorDescriptor(Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize, Float, decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN, InBlockCopyDataPerRead_N>{};

        // blockwise wei copy
        //   format is [CPerBlock, X * KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize, Float, decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
            Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerThread>{}, Number<WoPerThread * NPerThread>{},
            Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster,
                GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA, GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 0
            return blockwise_batch_gemm.Run(Xs...);
#elif 0
            return blockwise_batch_gemm.Run_amd_asm(Xs...);
#else
            return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
        };

        // LDS: be careful of alignment
        // TODO: need to properly implement tensor descriptor with alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 1
        const Float* p_in_global_block_offset =
            p_in_global + in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                              0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
#pragma unroll
                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_in_copy.Run(p_in_global_block_offset +
                                              in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                                          p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset +
                                               wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                                           p_wei_block);

                    __syncthreads();

                    run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global + in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                      0, hi_block_data_begin + y, wi_block_data_begin + x,
                                      n_block_data_begin);

                const Float* p_wei_global_block_offset =
                    p_wei_global +
                    wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);

                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                    c_block_data_begin += CPerBlock,
                            p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
                {
                    blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
                    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

                    __syncthreads();

                    run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#endif

        // output: register to global mem,
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            // f_dummy does nothing but perfect forwarding.
            // Using this trick to make this lambda a generic lambda, so it won't be compiled
            // until instantiated
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
                         N / f_dummy(N1 * N2), N1, N2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                for(index_t i = 0; i < 64; ++i)
                {
                    printf("out %f, ", p_out_thread[i]);
                }
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        });
    }
};

} // namespace ck
#endif
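In this v1r3 kernel the manual div/mod chain of the earlier variants is replaced by block_work_desc.GetMultiIndexFrom1dIndex, which unpacks the linear workgroup id into an (n, k, h, w) work coordinate with the last dimension varying fastest. The helper below reproduces that arithmetic on the host with assumed extents; it is an illustration of the decomposition, not the library routine.

// Sketch (illustration only) of unpacking a 1-d block id into (n, k, h, w) block-work ids.
#include <array>
#include <cstdio>

std::array<int, 4> unpack_block_id(int block_id, int KBlockWork, int HBlockWork, int WBlockWork)
{
    const int w = block_id % WBlockWork;  block_id /= WBlockWork; // fastest-varying
    const int h = block_id % HBlockWork;  block_id /= HBlockWork;
    const int k = block_id % KBlockWork;  block_id /= KBlockWork;
    const int n = block_id;                                       // slowest-varying
    return {n, k, h, w};
}

int main()
{
    // e.g. a 2 x 4 x 3 x 3 grid of (N, K, H, W) block work: id 29 maps to (0, 3, 0, 2)
    const auto id = unpack_block_id(29, 4, 3, 3);
    std::printf("n=%d k=%d h=%d w=%d\n", id[0], id[1], id[2], id[3]);
    return 0;
}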
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp (deleted, 100644 → 0)
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock,
          index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN,
          index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_c_h_w_n_global_desc.GetLength(I0);
        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        // assert for LDS double buffer
        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);

        constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_begin  = block_work_multi_id[0] * KPerBlock;
        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
        const index_t n_block_data_begin  = block_work_multi_id[3] * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = math::lcm(
            InBlockCopyDataPerRead_N, WeiBlockCopyDataPerRead_K, GemmDataPerReadA, GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
#if 0
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead_N>{};
#else
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN,
                                   InBlockCopyDataPerRead_N>{};
#endif

        // blockwise wei copy
        // format is [CPerBlock, X * KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
            Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerThread>{}, Number<WoPerThread * NPerThread>{},
            Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster,
                GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA, GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
            return blockwise_batch_gemm.Run(Xs...);
#elif 0
            return blockwise_batch_gemm.Run_amd_asm(Xs...);
#else
            return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
        };

        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        // LDS double buffer
        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output to 0
        threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);

        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global + in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                      0, hi_block_data_begin + y, wi_block_data_begin + x,
                                      n_block_data_begin);

                const Float* p_wei_global_block_offset =
                    p_wei_global +
                    wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);

                // LDS double buffer: preload data into LDS
                {
                    Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                               p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                p_in_block_double);
                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                 p_wei_block_double);
                }

                // LDS double buffer: main body
                for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
                    c_block_data_begin += 2 * CPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_in_block_now =
                            even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                        Float* p_wei_block_now =
                            even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                        Float* p_in_block_next =
                            even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                        Float* p_wei_block_next =
                            even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                        Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                        Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                        p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
                        p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                   p_in_register_clipboard);
                        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                    p_wei_register_clipboard);

                        run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);

                        // LDS double buffer: store next data to LDS
                        blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                    p_in_block_next);
                        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                     p_wei_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    // even iteration
                    p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                    __syncthreads();

                    // LDS double buffer: load next data from device mem
                    blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                               p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    // LDS double buffer: GEMM on current data
                    run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);

                    // LDS double buffer: store next data to LDS
                    blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                p_in_block_double + in_block_space);
                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                 p_wei_block_double + wei_block_space);

                    // odd iteration
                    __syncthreads();

                    // LDS double buffer: GEMM on current data
                    run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
                                             p_in_block_double + in_block_space,
                                             p_out_thread);
                }
            }
        }

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
            // fwd does nothing but perfect forwarding.
            // Using this trick to make this lambda a generic lambda, so it won't be compiled until
            // being instantiated here
            static_assert(
                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;
            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
                                                     .Fold(I3, Number<N1>{}, Number<N2>{})
                                                     .Fold(I2, Number<W1>{}, Number<W2>{})
                                                     .Fold(I0, Number<K1>{}, Number<K2>{});

            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
                                                     .Fold(I3, Number<1>{}, Number<N2>{})
                                                     .Fold(I2, Number<W1>{}, Number<1>{})
                                                     .Fold(I0, Number<1>{}, Number<K2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "a: out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "a: out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(
                out_10d_thread_desc, p_out_thread, out_10d_global_desc,
                p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                   k_block_data_begin + k_thread_data_begin,
                                   ho_block_data_begin + ho_thread_data_begin,
                                   wo_block_data_begin + wo_thread_data_begin,
                                   n_block_data_begin + n_thread_data_begin),
                out_10d_thread_desc.GetLengths(), Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto fwd) {
            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;
            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc =
                fwd(out_k_h_w_n_global_desc)
                    .Fold(I3, Number<N1>{})
                    .Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
                    .Fold(I0, Number<K1>{}, Number<K2>{});

            constexpr auto out_10d_thread_desc =
                fwd(out_k_h_w_n_thread_desc)
                    .Fold(I3, Number<N1>{})
                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
                    .Fold(I0, Number<1>{}, Number<K2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "b: out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, "b: out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(
                out_10d_thread_desc, p_out_thread, out_10d_global_desc,
                p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                   k_block_data_begin + k_thread_data_begin,
                                   ho_block_data_begin + ho_thread_data_begin,
                                   wo_block_data_begin + wo_thread_data_begin,
                                   n_block_data_begin + n_thread_data_begin),
                out_10d_thread_desc.GetLengths(), Number<OutThreadCopyDataPerWrite_N>{});
        });
    }
};

} // namespace ck
#endif
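Not part of the commit: a stripped-down sketch of the double-buffer schedule the kernel above implements in LDS (preload one tile, then overlap "load next tile into the other buffer" with "compute on the current buffer", then a tail step on the last buffer). It is written as plain host C++ so the ping-pong is easy to follow; load_tile/compute_tile are hypothetical stand-ins for the blockwise copy and the blockwise batched GEMM, and the real kernel additionally separates the phases with __syncthreads().

#include <vector>

// hypothetical stand-ins for the blockwise copy and GEMM
void load_tile(int tile_id, std::vector<float>& dst) { dst.assign(16, float(tile_id)); }
void compute_tile(const std::vector<float>& src, float& acc) { for(float v : src) acc += v; }

float run(int num_tiles) // num_tiles plays the role of C / CPerBlock
{
    std::vector<float> buf[2]; // the two halves of the double buffer
    float acc = 0;

    load_tile(0, buf[0]); // preload first tile into buffer 0

    for(int i = 0; i + 1 < num_tiles; ++i)
    {
        load_tile(i + 1, buf[(i + 1) % 2]); // fetch the next tile into the other buffer
        compute_tile(buf[i % 2], acc);      // compute on the previously loaded tile
    }

    compute_tile(buf[(num_tiles - 1) % 2], acc); // tail: last preloaded tile
    return acc;
}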
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock,
          index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockReorderSrcSubLengths_NCHW,
          class InBlockReorderSrcClusterLengths_NCHW,
          class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_CK, // not used
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_W>
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t N  = out_n_k_h_w_global_desc.GetLength(I0);
        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        // divide block work: [N, K, Ho, Wo]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);

        constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
                                                WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockReorderDataPerWrite_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
        constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

        const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
            BlockSize,
            Float,
            decltype(in_n_c_h_w_global_desc),
            decltype(in_c_h_w_n_block_desc),
            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
            InBlockReorderSrcSubLengths_NCHW,
            InBlockReorderSrcClusterLengths_NCHW,
            decltype(map_chwn2nchw),
            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
            InBlockReorderDataPerRead_W,
            InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});

        // blockwise wei copy
        // format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
            Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerThread>{}, Number<WoPerThread * NPerThread>{},
            Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster,
                GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA, GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
            return blockwise_batch_gemm.Run(Xs...);
#elif 0
            return blockwise_batch_gemm.Run_amd_asm(Xs...);
#else
            return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
        };

        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 0
        const Float* p_in_global_block_offset =
            p_in_global +
            in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
                n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
            p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_in_copy_reorder.Run(p_in_global_block_offset +
                                                      in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
                                                  p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset +
                                               wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                                           p_wei_block);

                    __syncthreads();

                    run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
                                      n_block_data_begin, 0, hi_block_data_begin + y,
                                      wi_block_data_begin + x);

                const Float* p_wei_global_block_offset =
                    p_wei_global +
                    wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);

                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                    c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
                {
                    blockwise_in_copy_reorder.Run(p_in_global_block_offset, p_in_block);
                    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

                    __syncthreads();

                    run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#endif

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
            // fwd does nothing but perfect forwarding.
            // Using this trick to make this lambda a generic lambda, so it won't be compiled until
            // being instantiated here
            static_assert(
                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;
            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
                                                     .Fold(I3, Number<W1>{}, Number<W2>{})
                                                     .Fold(I1, Number<K1>{}, Number<K2>{})
                                                     .Fold(I0, Number<N1>{}, Number<N2>{});

            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
                                                     .Fold(I3, Number<1>{}, Number<N2>{})
                                                     .Fold(I2, Number<W1>{}, Number<1>{})
                                                     .Fold(I0, Number<1>{}, Number<K2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "a: out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, "a: out_n_k_h_w_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
            }
#endif

            constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};

            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
                out_10d_thread_desc, p_out_thread, out_10d_global_desc,
                p_out_global + out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
                                   n_block_data_begin + n_thread_data_begin,
                                   k_block_data_begin + k_thread_data_begin,
                                   ho_block_data_begin + ho_thread_data_begin,
                                   wo_block_data_begin + wo_thread_data_begin),
                out_10d_thread_desc.GetLengths(), map_out_global2thread);
            // Number<OutThreadCopyDataPerWrite_W>{});
        }).Else([&](auto fwd) {
            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;
            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc =
                fwd(out_n_k_h_w_global_desc)
                    .Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
                    .Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{});

            constexpr auto out_10d_thread_desc =
                fwd(out_k_h_w_n_thread_desc)
                    .Fold(I3, Number<N1>{})
                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
                    .Fold(I0, Number<1>{}, Number<K2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "b: out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, "b: out_n_k_h_w_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
            }
#endif

            constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

#if 0
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
                out_10d_thread_desc,
                p_out_thread,
                out_10d_global_desc,
                p_out_global +
                    out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
                        n_block_data_begin + n_thread_data_begin,
                        k_block_data_begin + k_thread_data_begin,
                        ho_block_data_begin + ho_thread_data_begin,
                        wo_block_data_begin + wo_thread_data_begin),
                out_10d_thread_desc.GetLengths(),
                map_out_global2thread,
                Number<OutThreadCopyDataPerWrite_W>{});
#else
            threadwise_generic_tensor_slice_copy_v1(
                out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
                p_out_thread,
                make_zero_array<index_t, 10>(),
                out_10d_global_desc,
                p_out_global + out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
                                   n_block_data_begin + n_thread_data_begin,
                                   k_block_data_begin + k_thread_data_begin,
                                   ho_block_data_begin + ho_thread_data_begin,
                                   wo_block_data_begin + wo_thread_data_begin),
                make_zero_array<index_t, 10>(),
                out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
                arithmetic_sequence_gen<0, 10, 1>::type{},
                Number<1>{});
#endif
        });
    }
};

} // namespace ck
#endif
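Not part of the commit: a small host-side sketch of what the Sequence<1, 2, 3, 0> map used by the blockwise reorder copy above expresses, under the reading that destination dimension d of the [C, Hi, Wi, N] block takes its index from source dimension map[d] of the [N, C, Hi, Wi] tensor. The example index values are made up.

#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 4> nchw_idx       = {7, 3, 5, 2};    // example (n, c, h, w)
    const std::array<int, 4> map_chwn2nchw  = {1, 2, 3, 0};    // dst dim d reads src dim map[d]

    std::array<int, 4> chwn_idx{};
    for(int d = 0; d < 4; ++d)
        chwn_idx[d] = nchw_idx[map_chwn2nchw[d]];

    // prints 3 5 2 7, i.e. (c, h, w, n)
    std::printf("%d %d %d %d\n", chwn_idx[0], chwn_idx[1], chwn_idx[2], chwn_idx[3]);
    return 0;
}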
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_copy.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          index_t NPerBlock, index_t KPerBlock, index_t CPerBlock,
          index_t HoPerBlock, index_t WoPerBlock,
          index_t NPerThread, index_t KPerThread, index_t HoPerThread, index_t WoPerThread,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockReorderSrcSubLengths_NCHW,
          class InBlockReorderSrcClusterLengths_NCHW,
          class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_CK, // not used
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_W>
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t N  = out_n_k_h_w_global_desc.GetLength(I0);
        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        // assert for LDS double buffer
        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

        // divide block work: [N, K, Ho, Wo]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);

        constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t n_block_data_begin  = block_work_multi_id[0] * NPerBlock;
        const index_t k_block_data_begin  = block_work_multi_id[1] * KPerBlock;
        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
                                                WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockReorderDataPerWrite_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment requirements
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
        constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

        const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
            BlockSize,
            Float,
            decltype(in_n_c_h_w_global_desc),
            decltype(in_c_h_w_n_block_desc),
            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
            InBlockReorderSrcSubLengths_NCHW,
            InBlockReorderSrcClusterLengths_NCHW,
            decltype(map_chwn2nchw),
            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
            InBlockReorderDataPerRead_W,
            InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});

        // blockwise wei copy
        // format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
            Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerThread>{}, Number<WoPerThread * NPerThread>{},
            Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC, GemmNPerThreadSubC,
                GemmMLevel0Cluster, GemmNLevel0Cluster,
                GemmMLevel1Cluster, GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA, GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 1
            return blockwise_batch_gemm.Run(Xs...);
#elif 0
            return blockwise_batch_gemm.Run_amd_asm(Xs...);
#else
            return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
        };

        // LDS: be careful of alignment
        constexpr index_t in_block_space  = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        // LDS double buffer
        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register
        // C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
                                      n_block_data_begin, 0, hi_block_data_begin + y,
                                      wi_block_data_begin + x);

                const Float* p_wei_global_block_offset =
                    p_wei_global +
                    wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);

                // LDS double buffer: preload data into LDS
                {
                    Float p_in_register_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                        p_in_block_double);
                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                 p_wei_block_double);
                }

                // LDS double buffer: main body
                for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
                    c_block_data_begin += 2 * CPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_in_block_now =
                            even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                        Float* p_wei_block_now =
                            even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                        Float* p_in_block_next =
                            even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                        Float* p_wei_block_next =
                            even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                        Float p_in_register_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                        Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                        p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                        p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                           p_in_register_clipboard);
                        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                    p_wei_register_clipboard);

                        // LDS double buffer: GEMM on current data
                        run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);

                        // LDS double buffer: store next data to LDS
                        blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                            p_in_block_next);
                        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                     p_wei_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    Float p_in_register_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    // even iteration
                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                    __syncthreads();

                    // LDS double buffer: load next data from device mem
                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    // LDS double buffer: GEMM on current data
                    run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);

                    // LDS double buffer: store next data to LDS
                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(
                        p_in_register_clipboard, p_in_block_double + in_block_space);
                    blockwise_wei_copy.RunStoreRegisterClipboard(
                        p_wei_register_clipboard, p_wei_block_double + wei_block_space);

                    // odd iteration
                    __syncthreads();

                    // LDS double buffer: GEMM on current data
                    run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
                                             p_in_block_double + in_block_space,
                                             p_out_thread);
                }
            }
        }

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
            // fwd does nothing but perfect forwarding.
            // Using this trick to make this lambda a generic lambda, so it won't be compiled until
            // being instantiated here
            static_assert(
                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;
            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
                                                     .Fold(I3, Number<W1>{}, Number<W2>{})
                                                     .Fold(I1, Number<K1>{}, Number<K2>{})
                                                     .Fold(I0, Number<N1>{}, Number<N2>{});

            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
                                                     .Fold(I3, Number<1>{}, Number<N2>{})
                                                     .Fold(I2, Number<W1>{}, Number<1>{})
                                                     .Fold(I0, Number<1>{}, Number<K2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "a: out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, "a: out_n_k_h_w_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
            }
#endif

            constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};

            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
                out_10d_thread_desc, p_out_thread, out_10d_global_desc,
                p_out_global + out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
                                   n_block_data_begin + n_thread_data_begin,
                                   k_block_data_begin + k_thread_data_begin,
                                   ho_block_data_begin + ho_thread_data_begin,
                                   wo_block_data_begin + wo_thread_data_begin),
                out_10d_thread_desc.GetLengths(), map_out_global2thread);
            // Number<OutThreadCopyDataPerWrite_W>{});
        }).Else([&](auto fwd) {
            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;
            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc =
                fwd(out_n_k_h_w_global_desc)
                    .Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
                    .Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{});

            constexpr auto out_10d_thread_desc =
                fwd(out_k_h_w_n_thread_desc)
                    .Fold(I3, Number<N1>{})
                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
                    .Fold(I0, Number<1>{}, Number<K2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, "b: out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, "b: out_n_k_h_w_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
            }
#endif

            constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

#if 0
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
                out_10d_thread_desc,
                p_out_thread,
                out_10d_global_desc,
                p_out_global +
                    out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
                        n_block_data_begin + n_thread_data_begin,
                        k_block_data_begin + k_thread_data_begin,
                        ho_block_data_begin + ho_thread_data_begin,
                        wo_block_data_begin + wo_thread_data_begin),
                out_10d_thread_desc.GetLengths(),
                map_out_global2thread,
                Number<OutThreadCopyDataPerWrite_W>{});
#else
            threadwise_generic_tensor_slice_copy_v1(
                out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
                p_out_thread,
                make_zero_array<index_t, 10>(),
                out_10d_global_desc,
                p_out_global + out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
                                   n_block_data_begin + n_thread_data_begin,
                                   k_block_data_begin + k_thread_data_begin,
                                   ho_block_data_begin + ho_thread_data_begin,
                                   wo_block_data_begin + wo_thread_data_begin),
                make_zero_array<index_t, 10>(),
                out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
                arithmetic_sequence_gen<0, 10, 1>::type{},
                Number<1>{});
#endif
        });
    }
};

} // namespace ck
#endif
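Not part of the commit: a tiny sketch of what the Fold used in the epilogues above appears to do. Folding one dimension of length L and stride s with factors (f1, f2) seems to yield three dimensions of lengths (L / (f1 * f2), f1, f2) and strides (f1 * f2 * s, f2 * s, s); the element layout is unchanged, only the view becomes finer-grained. The numbers below are made up for illustration.

#include <cstdio>

int main()
{
    constexpr int L = 32, s = 4;   // hypothetical dimension length and stride
    constexpr int f1 = 16, f2 = 2; // fold factors

    constexpr int len[3]    = {L / (f1 * f2), f1, f2};
    constexpr int stride[3] = {f1 * f2 * s, f2 * s, s};

    // offset of multi-index (i0, i1, i2) in the folded view is (i0*f1*f2 + i1*f2 + i2) * s,
    // i.e. the same offset the unfolded index i = i0*f1*f2 + i1*f2 + i2 had before folding
    std::printf("lengths %d %d %d, strides %d %d %d\n",
                len[0], len[1], len[2], stride[0], stride[1], stride[2]);
    return 0;
}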
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
// define B = flatten(N, Hi, Wi)
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
CPerBlock
,
index_t
BPerThread
,
index_t
KPerThread
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
index_t
InBlockCopyThreadPerDim0
,
index_t
InBlockCopyThreadPerDim1
,
index_t
WeiBlockCopyThreadPerDim0
,
index_t
WeiBlockCopyThreadPerDim1
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
,
index_t
OutThreadCopyDataPerWrite
>
struct
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
C
=
in_chwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Hi
=
in_chwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wi
=
in_chwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
N
=
in_chwn_global_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
out_khwn_global_desc
.
GetLength
(
I0
);
constexpr
index_t
Ho
=
out_khwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wo
=
out_khwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_cyxk_global_desc
.
GetLength
(
I1
);
constexpr
index_t
X
=
wei_cyxk_global_desc
.
GetLength
(
I2
);
constexpr
index_t
B
=
N
*
Hi
*
Wi
;
constexpr
index_t
BGhostRead
=
(
Y
-
1
)
*
Wi
+
(
X
-
1
);
// divide block work by 2d: [K, B]
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
BBlockWork
=
(
B
+
BPerBlock
-
1
)
/
BPerBlock
;
const
index_t
k_block_work_id
=
get_block_1d_id
()
/
BBlockWork
;
const
index_t
b_block_work_id
=
get_block_1d_id
()
-
k_block_work_id
*
BBlockWork
;
const
index_t
k_block_data_begin
=
k_block_work_id
*
KPerBlock
;
const
index_t
b_block_data_begin
=
b_block_work_id
*
BPerBlock
;
// flattend (2d) tensor view of gridwise input
constexpr
auto
in_cb_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
B
>
{});
constexpr
auto
wei_ek_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
            * Y * X, K>{});

        // tensor view of blockwise input and weight
        // be careful of alignment
        constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});

        constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

        constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

        // tensor view of threadwise output in register
        constexpr auto out_kb_thread_desc =
            make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});

        // blockwise in copy
        // format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDim0,
                                   InBlockCopyThreadPerDim1>{};
#elif 1
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead>{};
#endif

        // blockwise wei copy
        // format is [CPerBlock*Y*X, KPerBlock]
#if 0
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyThreadPerDim0,
                                   WeiBlockCopyThreadPerDim1>{};
#elif 1
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};
#endif

        // a series of blockwise GEMM
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx and b_mtx saved in LDS, c_mtx saved in register
        // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
        // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
        // c_mtx[K,B] is out_block[K,B]
        constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

        constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});

        constexpr auto c_kxb_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_cxk_block_mtx_desc),
            decltype(b_cxb_block_mtx_desc),
            decltype(c_kxb_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t max_align =
            math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);

        constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space =
            wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        const Float* p_in_global_block_offset =
            p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global +
            wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        // register
        Float p_out_thread[out_kb_thread_desc.GetElementSpace()];

        // set threadwise output to 0
        threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
                    __syncthreads())
        {
            // load data
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);

            __syncthreads();

            // compute on current data
            // a series of GEMM
            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_gemm.Run
#elif 0
                    blockwise_gemm.Run_RegisterDoubleBuffer
#elif 1
                    blockwise_gemm.Run_amd_asm
#endif
                        (p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_in_block + y * Wi + x,
                         p_out_thread);
                }
            }
        }

        // output: register to global mem,
        const auto c_thread_mtx_begin =
            blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
        const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

        for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
        {
            for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
            {
                const auto c_thread_mtx_distance =
                    blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

                index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
                index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;

                index_t h_data = b_data / (Wi * N);
                index_t itmp   = b_data - h_data * (Wi * N);
                index_t w_data = itmp / N;
                index_t n_data = itmp - w_data * N;

                if(n_data < N && h_data < Ho && w_data < Wo)
                {
                    p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
                        k_data, h_data, w_data, n_data)] =
                        p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
                }
            }
        }
    }
};

} // namespace ck
#endif
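A minimal host-side sketch (not part of this commit; the struct and function names here are hypothetical) of the index math used in the output writeback above: the flattened coordinate b over (H, W, N) is decomposed back into (h, w, n) with the same divide/subtract steps the kernel uses, and the round trip is checked.

#include <cassert>

struct HWN { int h, w, n; };

HWN decompose_b(int b, int Wi, int N)
{
    // mirrors: h = b / (Wi * N); itmp = b - h * (Wi * N); w = itmp / N; n = itmp - w * N;
    const int h    = b / (Wi * N);
    const int itmp = b - h * (Wi * N);
    const int w    = itmp / N;
    const int n    = itmp - w * N;
    return {h, w, n};
}

int main()
{
    const int Hi = 4, Wi = 5, N = 3;
    for(int b = 0; b < Hi * Wi * N; ++b)
    {
        const HWN c = decompose_b(b, Wi, N);
        // re-flatten in (H, W, N) order and check the round trip
        assert((c.h * Wi + c.w) * N + c.n == b);
    }
    return 0;
}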
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp
deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

// define B = flatten(N, Hi, Wi)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t BPerThread,
          index_t KPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          index_t InBlockCopyThreadPerDim0,
          index_t InBlockCopyThreadPerDim1,
          index_t WeiBlockCopyThreadPerDim0,
          index_t WeiBlockCopyThreadPerDim1,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead,
          index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_chwn_global_desc  = InGlobalDesc{};
        constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
        constexpr auto out_khwn_global_desc = OutGlobalDesc{};

        constexpr index_t C  = in_chwn_global_desc.GetLength(I0);
        constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
        constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
        constexpr index_t N  = in_chwn_global_desc.GetLength(I3);

        constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
        constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
        constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);

        constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
        constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

        constexpr index_t B = N * Hi * Wi;

        constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);

        // assert for LDS double buffer
        static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

        // divide block work by 2d: [K, B]
        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
        const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;

        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
        const index_t b_block_data_begin = b_block_work_id * BPerBlock;

        // flattened (2d) tensor view of gridwise input
        constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
        constexpr auto wei_ek_global_desc =
            make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

        // tensor view of blockwise input and weight
        // be careful of alignment
        constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});

        constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

        constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});

        // tensor view of threadwise output in register
        constexpr auto out_kb_thread_desc =
            make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});

        // blockwise in copy
        // format is [CPerBlock, BPerBlock + BGhostRead]
#if 0
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths())>{};
#elif 0
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyThreadPerDim0,
                                   InBlockCopyThreadPerDim1>{};
#elif 1
        const auto blockwise_in_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_cb_global_desc),
                                   decltype(in_cb_block_desc),
                                   decltype(in_cb_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead>{};
#endif

        // blockwise wei copy
        // format is [CPerBlock*Y*X, KPerBlock]
#if 0
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths())>{};
#elif 0
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy2<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyThreadPerDim0,
                                   WeiBlockCopyThreadPerDim1>{};
#elif 1
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ek_global_desc),
                                   decltype(wei_ek_block_desc),
                                   decltype(wei_ek_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>{};
#endif

        // a series of blockwise GEMM
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx and b_mtx saved in LDS, c_mtx saved in register
        // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
        // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
        // c_mtx[K,B] is out_block[K,B]
        constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

        constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});

        constexpr auto c_kxb_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_cxk_block_mtx_desc),
            decltype(b_cxb_block_mtx_desc),
            decltype(c_kxb_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t max_align =
            math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);

        constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space =
            wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});

        // LDS double buffer
        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        const Float* p_in_global_block_offset =
            p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global +
            wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        // preload data into LDS
        {
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                        p_in_block_double);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_double);
        }

        // register
        Float p_out_thread[out_kb_thread_desc.GetElementSpace()];

        // set threadwise output to 0
        threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);

        for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
            c_block_data_begin += 2 * CPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                // load next data
                Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
                p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

                __syncthreads();

                blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                           p_in_register_clipboard);
                blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                            p_wei_register_clipboard);

                // compute on current data
                // a series of GEMM
                for(index_t y = 0; y < Y; ++y)
                {
                    for(index_t x = 0; x < X; ++x)
                    {
#if 1
                        blockwise_gemm.Run
#elif 0
                        blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
                        blockwise_gemm.Run_amd_asm
#endif
                            (p_wei_block_now +
                                 wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                             p_in_block_now + y * Wi + x,
                             p_out_thread);
                    }
                }

                blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                            p_in_block_next);
                blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                             p_wei_block_next);
            }
        }

        // tail
        {
            // even
            p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
            p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

            __syncthreads();

            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);

            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_gemm.Run
#elif 0
                    blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
                    blockwise_gemm.Run_amd_asm
#endif
                        (p_wei_block_double +
                             wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_in_block_double + y * Wi + x,
                         p_out_thread);
                }
            }

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                        p_in_block_double + in_block_space);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                         p_wei_block_double + wei_block_space);

            // odd
            __syncthreads();

            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_gemm.Run
#elif 0
                    blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0
                    blockwise_gemm.Run_amd_asm
#endif
                        (p_wei_block_double + wei_block_space +
                             wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_in_block_double + in_block_space + y * Wi + x,
                         p_out_thread);
                }
            }
        }

        // output: register to global mem,
        const auto c_thread_mtx_begin =
            blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
        const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;

        if(Y == 1 && X == 1)
        {
            // pure 1x1 conv (no padding, 1x1 stride)
            constexpr index_t K2_ = GemmMPerThreadSubC;
            constexpr index_t K1_ = KPerBlock / KPerThread;

            constexpr index_t B2_ = GemmNPerThreadSubC;
            constexpr index_t B1_ = BPerBlock / BPerThread;

            constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1_ * K2_), K1_, K2_, B / (B1_ * B2_), B1_, B2_>{});

            constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerBlock / (K1_ * K2_), 1, K2_, BPerBlock / (B1_ * B2_), 1, B2_>{});

            constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});

            threadwise_6d_tensor_copy(
                out_6d_thread_desc,
                p_out_thread,
                out_6d_global_desc,
                p_out_global + out_kb_global_desc.GetOffsetFromMultiIndex(k_thread_data_begin,
                                                                          b_thread_data_begin),
                out_6d_thread_desc.GetLengths(),
                Number<OutThreadCopyDataPerWrite>{});
        }
        else
        {
            for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
            {
                for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
                {
                    const auto c_thread_mtx_distance =
                        blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);

                    index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
                    index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;

                    index_t h_data = b_data / (Wi * N);
                    index_t itmp   = b_data - h_data * (Wi * N);
                    index_t w_data = itmp / N;
                    index_t n_data = itmp - w_data * N;

                    if(n_data < N && h_data < Ho && w_data < Wo)
                    {
                        p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
                            k_data, h_data, w_data, n_data)] =
                            p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
                    }
                }
            }
        }
    }
};

} // namespace ck
#endif
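A minimal standalone sketch (plain host C++, hypothetical names, not part of this commit) of the ping-pong selection used in the LDS double-buffer loop above: one allocation holds two halves, and each iteration computes on one half ("now") while the next tile is staged into the other half ("next").

#include <cstdio>

int main()
{
    const int block_space = 8;           // elements per buffer half (stand-in for in_block_space)
    float storage[2 * block_space] = {}; // stand-in for p_in_block_double in LDS

    for(int iloop = 0; iloop < 4; ++iloop)
    {
        const bool even_loop = (iloop % 2 == 0);

        float* p_now  = even_loop ? storage : storage + block_space;
        float* p_next = even_loop ? storage + block_space : storage;

        // "compute" on p_now, "preload" into p_next; here we just print the offsets
        std::printf("iloop %d: compute at offset %td, preload at offset %td\n",
                    iloop, p_now - storage, p_next - storage);
    }
    return 0;
}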
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t N1,
          index_t N2,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_C_N1_B_N2,
          class InBlockCopyClusterLengths_C_N1_B_N2,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_C_K,
          class WeiBlockCopyClusterLengths_C_K,
          index_t WeiBlockCopyDataPerAccess_K>
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // memory layout descriptor in device memory [N0, N1, N2, C, H, W]
        constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
            in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});

        // merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
        constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
            Sequence<3>{},
            Sequence<1>{},
            Sequence<0, 4, 5>{},
            Sequence<2>{});

        // memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(in_c_n1_b_n2_global_merged_desc),
            decltype(in_c_n1_b_n2_block_mem_desc),
            decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
            InBlockCopySubLengths_C_N1_B_N2,
            InBlockCopyClusterLengths_C_N1_B_N2,
            Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
            Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
            Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(wei_c_k_global_desc),
            decltype(wei_c_k_block_desc),
            decltype(wei_c_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_C_K,
            WeiBlockCopyClusterLengths_C_K,
            Sequence<0, 1>, // thread_arrange_order [C, K]
            Sequence<0, 1>, // src_access_order [C, K]
            Sequence<0, 1>, // dst_access_order [C, K]
            WeiBlockCopyDataPerAccess_K,
            WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[CPerBlock, KPerBlock] is in LDS
        // b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<N1 * BPerBlock * N2>{},
                                          Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});

        // sanity check
        static_assert(KPerBlock %
                              (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_c_k_block_mtx_desc),
            decltype(b_c_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1
            return blockwise_gemm.Run(Xs...);
#else
            return blockwise_gemm.Run_amd_asm(Xs...);
#endif
        };

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDataPerAccess_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});

        constexpr index_t wei_block_space =
            wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

#if 0
        // do work
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                // calculate origin of block input and weight tensor on global memory
                const Float* p_in_block_on_global =
                    p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);

                const Float* p_wei_block_on_global =
                    p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);

                for(index_t
                        c_block_data_on_global = 0;
                    c_block_data_on_global < C;
                    c_block_data_on_global += CPerBlock,
                        p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
                        p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
                {
                    blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
                    blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);

                    __syncthreads();

                    run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                // calculate origin of block input and weight tensor on global memory
                const Float* p_in_block_on_global =
                    p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);

                const Float* p_wei_block_on_global =
                    p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);

                for(index_t c_block_data_on_global = 0; c_block_data_on_global < C;
                    c_block_data_on_global += CPerBlock)
                {
                    blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
                    blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);

                    __syncthreads();

                    blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();

                    blockwise_in_copy.MoveSlicingWindowOnSourceTensor(
                        I0, Number<CPerBlock>{}, True);
                    blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(
                        I0, Number<CPerBlock>{}, True);
                }

                // reset C
                blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
                blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
            }
        }
#endif

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / (K1 * K2);

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                                   k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy_v1(
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 8, 1>::type{},
                Number<1>{});
        }
    }
};

} // namespace ck
#endif
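A minimal sketch (plain C++17, hypothetical names, not part of this commit) of the alignment rule behind max_align above: the LDS element space is padded so that its size is a multiple of the least common multiple of every vector access width that touches it.

#include <cstdio>
#include <numeric> // std::lcm

constexpr int lcm_all(int a) { return a; }
template <class... Ts>
constexpr int lcm_all(int a, Ts... rest) { return std::lcm(a, lcm_all(rest...)); }

constexpr int pad_to_alignment(int element_count, int align)
{
    // round the element count up to the next multiple of align
    return ((element_count + align - 1) / align) * align;
}

int main()
{
    // e.g. dst-write width 2, per-access width 4, GEMM read widths 4 and 2
    constexpr int max_align      = lcm_all(2, 4, 4, 2); // = 4
    constexpr int in_block_space = pad_to_alignment(130, max_align);

    std::printf("max_align = %d, padded element space = %d\n", max_align, in_block_space);
    return 0;
}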
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t N1,
          index_t N2,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopySubLengths_C_N1_B_N2,
          class InBlockCopyClusterLengths_C_N1_B_N2,
          index_t InBlockCopySrcDataPerRead_B,
          index_t InBlockCopyDstDataPerWrite_N2,
          class WeiBlockCopySubLengths_C_K,
          class WeiBlockCopyClusterLengths_C_K,
          index_t WeiBlockCopyDataPerAccess_K>
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegant way of specifying (or calculating) performance parameters
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % (2 * CPerBlock) == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // memory layout descriptor in device memory [N0, N1, N2, C, H, W]
        constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
            in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});

        // merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
        constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
            Sequence<3>{},
            Sequence<1>{},
            Sequence<0, 4, 5>{},
            Sequence<2>{});

        // memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(in_c_n1_b_n2_global_merged_desc),
            decltype(in_c_n1_b_n2_block_mem_desc),
            decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
            InBlockCopySubLengths_C_N1_B_N2,
            InBlockCopyClusterLengths_C_N1_B_N2,
            Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
            Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
            Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
            InBlockCopySrcDataPerRead_B,
            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already has blockwise offset built-in
        const auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1<
            BlockSize,
            Float,
            decltype(wei_c_k_global_desc),
            decltype(wei_c_k_block_desc),
            decltype(wei_c_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_C_K,
            WeiBlockCopyClusterLengths_C_K,
            Sequence<0, 1>, // thread_arrange_order [C, K]
            Sequence<0, 1>, // src_access_order [C, K]
            Sequence<0, 1>, // dst_access_order [C, K]
            WeiBlockCopyDataPerAccess_K,
            WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[CPerBlock, KPerBlock] is in LDS
        // b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
        // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        // register
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<N1 * BPerBlock * N2>{},
                                          Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});

        // sanity check
        static_assert(KPerBlock %
                              (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
                          0,
                      "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_c_k_block_mtx_desc),
            decltype(b_c_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1
            return blockwise_gemm.Run(Xs...);
#else
            return blockwise_gemm.Run_amd_asm(Xs...);
#endif
        };

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDataPerAccess_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});

        constexpr index_t wei_block_space =
            wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        // do work
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                // calculate origin of block input and weight tensor on global memory
                const Float* p_in_block_on_global =
                    p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);

                const Float* p_wei_block_on_global =
                    p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);

                // LDS double buffer: preload data into LDS
                {
                    blockwise_in_copy.Run(p_in_block_on_global, p_in_block_double);
                    blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block_double);
                }

                // LDS double buffer: main body
                for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
                    c_block_data_begin += 2 * CPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_in_block_now =
                            even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                        Float* p_wei_block_now =
                            even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                        Float* p_in_block_next =
                            even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                        Float* p_wei_block_next =
                            even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                        Float p_in_register_clipboard[blockwise_in_copy
                                                          .GetRegisterClipboardSize()];
                        Float p_wei_register_clipboard[blockwise_wei_copy
                                                           .GetRegisterClipboardSize()];

                        p_in_block_on_global +=
                            CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                        p_wei_block_on_global +=
                            CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
                                                                   p_in_register_clipboard);
                        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                                    p_wei_register_clipboard);

                        // LDS double buffer: GEMM on current data
                        run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);

                        // LDS double buffer: store next data to LDS
                        blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                    p_in_block_next);
                        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                     p_wei_block_next);
                    }
                }

                // LDS double buffer: tail
                {
                    Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    // even iteration
                    p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                    p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                    __syncthreads();

                    // LDS double buffer: load next data from device mem
                    blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
                                                               p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
                                                                p_wei_register_clipboard);

                    // LDS double buffer: GEMM on current data
                    run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);

                    // LDS double buffer: store next data to LDS
                    blockwise_in_copy.RunStoreRegisterClipboard(
                        p_in_register_clipboard, p_in_block_double + in_block_space);
                    blockwise_wei_copy.RunStoreRegisterClipboard(
                        p_wei_register_clipboard, p_wei_block_double + wei_block_space);

                    // odd iteration
                    __syncthreads();

                    // LDS double buffer: GEMM on current data
                    run_blockwise_gemm(p_wei_block_double + wei_block_space,
                                       p_in_block_double + in_block_space,
                                       p_out_thread);
                }
            }
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / (K1 * K2);

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global + out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                                   k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy_v1(
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 8, 1>::type{},
                Number<1>{});
        }
    }
};

} // namespace ck
#endif
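A minimal host-side sketch (hypothetical names, not part of this commit) of the block-work decomposition used above: a 1-D block id is unflattened over the [KBlockWork, BBlockWork] grid, and each coordinate is scaled by the per-block tile size to get the starting K/B offsets.

#include <cstdio>

int main()
{
    const int K = 128, B = 256;
    const int KPerBlock = 32, BPerBlock = 64;

    const int KBlockWork = K / KPerBlock; // 4
    const int BBlockWork = B / BPerBlock; // 4

    for(int block_id = 0; block_id < KBlockWork * BBlockWork; ++block_id)
    {
        // unflatten the 1-D id over [KBlockWork, BBlockWork]
        const int k_block_work_id = block_id / BBlockWork;
        const int b_block_work_id = block_id % BBlockWork;

        const int k_block_data_on_global = k_block_work_id * KPerBlock;
        const int b_block_data_on_global = b_block_work_id * BPerBlock;

        std::printf("block %2d -> K offset %3d, B offset %3d\n",
                    block_id, k_block_data_on_global, b_block_data_on_global);
    }
    return 0;
}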
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
deleted 100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace
ck
{
// define B = merge(N0, Ho, Wo)
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
N1
,
index_t
N2
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
class
InBlockCopySubLengths_E_N1_B_N2
,
class
InBlockCopyClusterLengths_E_N1_B_N2
,
class
InBlockCopyThreadClusterArrangeOrder
,
class
InBlockCopySrcAccessOrder
,
class
InBlockCopyDstAccessOrder
,
index_t
InBlockCopySrcDataPerRead_B
,
index_t
InBlockCopyDstDataPerWrite_N2
,
class
WeiBlockCopySubLengths_E_K
,
class
WeiBlockCopyClusterLengths_E_K
,
class
WeiBlockCopyThreadClusterArrangeOrder
,
class
WeiBlockCopySrcAccessOrder
,
class
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
>
struct
GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
static_assert
((
N1
*
N2
*
BPerBlock
)
%
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
"wrong!"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
I6
=
Number
<
6
>
{};
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
in_n_c_h_w_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_h_w_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_h_w_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_h_w_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_n_c_h_w_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_n_c_h_w_global_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
out_n_k_h_w_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_n_k_h_w_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_n_k_h_w_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
static_assert
(
N
%
(
N1
*
N2
)
==
0
,
"wrong! cannot divice N evenly among thread"
);
constexpr
index_t
N0
=
N
/
(
N1
*
N2
);
constexpr
index_t
B
=
N0
*
Ho
*
Wo
;
constexpr
index_t
E
=
C
*
Y
*
X
;
// divide block work by [K, B]
static_assert
(
K
%
KPerBlock
==
0
&&
B
%
BPerBlock
==
0
&&
E
%
EPerBlock
==
0
,
"wrong! cannot divide work evenly among block"
);
constexpr
index_t
KBlockWork
=
K
/
KPerBlock
;
constexpr
index_t
BBlockWork
=
B
/
BPerBlock
;
constexpr
auto
block_work_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
KBlockWork
,
BBlockWork
>
{});
const
auto
block_work_multi_id
=
block_work_desc
.
GetMultiIndexFrom1dIndex
(
get_block_1d_id
());
const
index_t
k_block_data_on_global
=
block_work_multi_id
[
0
]
*
KPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_multi_id
[
1
]
*
BPerBlock
;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Ho
>
{})
.
Slice
(
I3
,
Number
<
Wo
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
// batch descritpor for device memory
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Y
>
{})
.
Slice
(
I3
,
Number
<
X
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
in_c_y_x_global_desc
.
Embed
(
in_n0_n1_n2_h_w_global_desc
),
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
4
>
{},
Sequence
<
3
,
6
,
7
>
{},
Sequence
<
5
>
{});
#if 0
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc,
"in_n0_n1_n2_h_w_global_desc: ");
print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: ");
print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc,
"in_e_n1_b_n2_global_merged_desc: ");
}
#endif
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
in_e_n1_b_n2_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
EPerBlock
,
N1
,
BPerBlock
,
N2
>
{},
Number
<
InBlockCopyDstDataPerWrite_N2
>
{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert
(
in_e_n1_b_n2_block_desc
.
GetStride
(
I1
)
%
GemmDataPerReadB
==
0
,
"GemmDataPerReadB alignment requirement is not satisfied"
);
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
Float
,
decltype
(
in_e_n1_b_n2_global_merged_desc
),
decltype
(
in_e_n1_b_n2_block_desc
),
decltype
(
in_e_n1_b_n2_block_desc
.
GetLengths
()),
InBlockCopySubLengths_E_N1_B_N2
,
InBlockCopyClusterLengths_E_N1_B_N2
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopyDstAccessOrder
,
InBlockCopySrcDataPerRead_B
,
InBlockCopyDstDataPerWrite_N2
>
(
{
0
,
0
,
b_block_data_on_global
,
0
},
{
0
,
0
,
0
,
0
});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr
auto
wei_e_k_global_desc
=
wei_k_c_y_x_global_desc
.
Unfold
(
I1
,
I3
).
ReorderGivenNew2Old
(
Sequence
<
1
,
0
>
{});
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr
auto
wei_e_k_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Number
<
math
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already have blockwise offset built-in
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v1
<
BlockSize
,
Float
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
>
(
{
0
,
k_block_data_on_global
},
{
0
,
0
});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_e_k_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_e_n1bn2_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
EPerBlock
>
{},
Number
<
N1
*
BPerBlock
*
N2
>
{},
Number
<
in_e_n1_b_n2_block_desc
.
GetStride
(
I0
)
>
{});
// sanity check
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
0
,
"wrong!"
);
constexpr
index_t
GemmMRepeat
=
KPerBlock
/
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
);
// c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1
            return blockwise_gemm.Run(Xs...);
#else
            return blockwise_gemm.Run_amd_asm(Xs...);
#endif
        };

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});

        constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        // do work
        for(index_t e = 0; e < E; e += EPerBlock)
        {
            // marching slicing window
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
            blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t K0 = K / (K1 * K2);

            // define tensor descriptor for threadwise copy
            //   output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            threadwise_generic_tensor_slice_copy_v1(
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                p_out_thread,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
                p_out_thread_on_global,
                {0, 0, 0, 0, 0, 0, 0, 0},
                out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
                arithmetic_sequence_gen<0, 8, 1>::type{},
                Number<1>{});
        }
    }
};

} // namespace ck
#endif
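The LDS buffers above are sized to a multiple of the largest vector access (the lcm of the copy and GEMM read widths) so that every vectorized read and write stays aligned. Below is a minimal standalone sketch of that arithmetic in plain host C++; the numeric values are hypothetical stand-ins for the kernel's template parameters, and std::lcm is used in place of the library's math::lcm.

#include <cstdio>
#include <numeric> // std::lcm (C++17)

int main()
{
    // hypothetical values standing in for the kernel's template parameters
    const int InBlockCopyDstDataPerWrite_N2 = 4;
    const int WeiBlockCopyDstDataPerWrite_K = 2;
    const int GemmDataPerReadA = 4;
    const int GemmDataPerReadB = 4;

    // the largest alignment any of the accesses needs
    const int max_align = std::lcm(std::lcm(InBlockCopyDstDataPerWrite_N2,
                                            WeiBlockCopyDstDataPerWrite_K),
                                   std::lcm(GemmDataPerReadA, GemmDataPerReadB));

    // round a raw element count up to that alignment, the same padding pattern
    // used when the kernels size their __shared__ buffers
    const int raw_elements = 1000; // hypothetical unpadded LDS element count
    const int padded = max_align * ((raw_elements + max_align - 1) / max_align);

    std::printf("max_align = %d, padded element space = %d\n", max_align, padded);
    return 0;
}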
composable_kernel/include/kernel_algorithm/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
deleted 100644 → 0
#pragma once
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_direct_convolution.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_direct_convolution.hpp"

namespace ck {

template <class TInWei,
          class TOut,
          class TAccum,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t ScalarPerVector,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead,
          index_t BlockSize,
          index_t GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
    const typename vector_type<TInWei, ScalarPerVector>::MemoryType* const __restrict__
        p_in_vec_global,
    const typename vector_type<TInWei, ScalarPerVector>::MemoryType* const __restrict__
        p_wei_vec_global,
    TOut* const __restrict__ p_out_global)
{
    using in_scalar_t     = TInWei;
    using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
    using out_scalar_t    = TOut;
    using accum_t         = TAccum;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_vec_global_desc  = InGlobalDesc{};
    constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
    constexpr auto out_nkhw_global_desc     = OutGlobalDesc{};

    constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
    constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);

    constexpr auto wei_ke_vec_global_desc =
        make_ConstantTensorDescriptor(Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{},
        Number<InBlockCopyDataPerRead>{});

    constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<KPerBlock, CPerBlock * Y * X>{}, Number<WeiBlockCopyDataPerRead>{});

    // 2d view of wei for blockwise copy
    constexpr auto wei_kcyx_vec_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlock, Y, X>{},
        Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});

    // shared mem
    constexpr index_t in_block_element_size =
        in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr index_t wei_block_element_size =
        wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

    constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                      ? InBlockCopyDataPerRead
                                      : WeiBlockCopyDataPerRead;

    __shared__ in_vector_mem_t
        p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
    __shared__ in_vector_mem_t
        p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

    // threadwise tensors
    constexpr index_t HiPerThread = HoPerThread + Y - 1;
    constexpr index_t WiPerThread = WoPerThread + X - 1;

    constexpr auto in_nchw_vec_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
        in_nchw_vec_block_desc.GetStrides());

    constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());

    constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);

    // register
    out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

    // divide block work
    constexpr index_t NBlockWork =
        (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr index_t KBlockWork =
        (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork =
        (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork =
        (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    const index_t block_id = blockIdx.x;

    index_t itmp                  = block_id;
    const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
    const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
    const index_t h_block_work_id = itmp / WBlockWork;
    const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

    const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
    const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

    // divide thread work
    constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
    constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
    constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
    constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

    const index_t thread_id = get_thread_local_1d_id();

    itmp                           = thread_id;
    const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
    itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
    const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
    itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
    const index_t h_thread_work_id = itmp / WThreadWork;
    const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

    const index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
    const index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
    const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
    const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

    const index_t hi_thread_data_begin = ho_thread_data_begin;
    const index_t wi_thread_data_begin = wo_thread_data_begin;

    constexpr auto blockwise_in_copy =
        Blockwise4dTensorCopy1<BlockSize,
                               in_vector_mem_t,
                               decltype(in_nchw_vec_global_desc),
                               decltype(in_nchw_vec_block_desc),
                               decltype(in_nchw_vec_block_desc.GetLengths()),
                               InBlockCopyDataPerRead>{};

#if 0
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize,
                               in_vector_mem_t,
                               decltype(wei_kcyx_vec_global_desc),
                               decltype(wei_kcyx_vec_block_desc),
                               decltype(wei_kcyx_vec_block_desc.GetLengths()),
                               1>{};
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize,
                               in_vector_mem_t,
                               decltype(wei_ke_vec_global_desc),
                               decltype(wei_ke_vec_block_desc),
                               decltype(wei_ke_vec_block_desc.GetLengths()),
                               WeiBlockCopyDataPerRead>{};
#endif

#if 1 // debug
    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
#endif

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // copy input tensor to LDS
        blockwise_in_copy.Run(p_in_vec_global +
                                  in_nchw_vec_global_desc.GetOffsetFromMultiIndex(
                                      n_block_data_begin,
                                      c_block_data_begin,
                                      hi_block_data_begin,
                                      wi_block_data_begin),
                              p_in_vec_block);

        // copy weight tensor to LDS
        blockwise_wei_copy.Run(p_wei_vec_global +
                                   wei_kcyx_vec_global_desc.GetOffsetFromMultiIndex(
                                       k_block_data_begin, c_block_data_begin, 0, 0),
                               p_wei_vec_block);

        __syncthreads();

        for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
        {
            // threadwise convolution
#if 1
            threadwise_direct_convolution_2(
                in_nchw_vec_thread_block_desc,
                p_in_vec_block +
                    in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
                                                                   c_thread_data,
                                                                   hi_thread_data_begin,
                                                                   wi_thread_data_begin),
                wei_kcyx_vec_thread_block_desc,
                p_wei_vec_block +
                    wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
                        k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#elif 0
            threadwise_direct_convolution_3(
                in_nchw_vec_thread_block_desc,
                p_in_vec_block +
                    in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
                                                                   c_thread_data,
                                                                   hi_thread_data_begin,
                                                                   wi_thread_data_begin),
                wei_kcyx_vec_thread_block_desc,
                p_wei_vec_block +
                    wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
                        k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#endif
        }
    }

    // copy output tensor from register to global mem
    threadwise_4d_tensor_copy(
        out_nkhw_thread_desc,
        p_out_thread,
        out_nkhw_global_desc,
        p_out_global +
            out_nkhw_global_desc.GetOffsetFromMultiIndex(
                n_block_data_begin + n_thread_data_begin,
                k_block_data_begin + k_thread_data_begin,
                ho_block_data_begin + ho_thread_data_begin,
                wo_block_data_begin + wo_thread_data_begin),
        out_nkhw_thread_desc.GetLengths());
}

} // namespace ck
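The kernel above linearizes the 4-D grid of block work (N, K, Ho, Wo) into a single blockIdx.x and recovers the four work IDs with successive divisions. The small self-contained sketch below reproduces the same decode and checks it round-trips; the work counts are hypothetical examples, not values from the source.

#include <cassert>
#include <cstdio>

int main()
{
    // hypothetical work counts standing in for NBlockWork, KBlockWork, HBlockWork, WBlockWork
    const int NBlockWork = 2, KBlockWork = 4, HBlockWork = 3, WBlockWork = 5;

    for(int block_id = 0; block_id < NBlockWork * KBlockWork * HBlockWork * WBlockWork; ++block_id)
    {
        // same successive-division decode the kernel applies to blockIdx.x
        int itmp       = block_id;
        const int n_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
        itmp -= n_id * (KBlockWork * HBlockWork * WBlockWork);
        const int k_id = itmp / (HBlockWork * WBlockWork);
        itmp -= k_id * (HBlockWork * WBlockWork);
        const int h_id = itmp / WBlockWork;
        const int w_id = itmp - h_id * WBlockWork;

        // re-encode and verify the round trip
        const int rebuilt = ((n_id * KBlockWork + k_id) * HBlockWork + h_id) * WBlockWork + w_id;
        assert(rebuilt == block_id);
    }

    std::printf("4-D block-work decode/encode round-trips correctly\n");
    return 0;
}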
composable_kernel/include/kernel_algorithm/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
deleted 100644 → 0
#pragma once
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class LowerPads,
          class UpperPads,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t WeiBlockCopyThreadPerDim0,
          index_t WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
    const Float* const __restrict__ p_in_global,
    const Float* const __restrict__ p_wei_global,
    Float* const __restrict__ p_out_global)
{
    // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
    //   for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
    //   if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
    static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread != 0");
    static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
                  "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_chwn_global_desc  = InGlobalDesc{};
    constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
    constexpr auto out_khwn_global_desc = OutGlobalDesc{};

    constexpr index_t C  = in_chwn_global_desc.GetLength(I0);
    constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
    constexpr index_t N  = out_khwn_global_desc.GetLength(I3);

    constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
    constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

    constexpr index_t HPadLow = LowerPads{}.Get(I0);
    constexpr index_t WPadLow = LowerPads{}.Get(I1);
    constexpr index_t HPadUp  = UpperPads{}.Get(I0);
    constexpr index_t WPadUp  = UpperPads{}.Get(I1);

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    // divide block work: [K, Ho, Wo, N]
    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
    constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

    const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
    index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
    const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
    itmp -= h_block_work_id * (WBlockWork * NBlockWork);
    const index_t w_block_work_id = itmp / NBlockWork;
    const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

    // flattened (2d) tensor view of wei in global mem
    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

    // tensor view of blockwise input and weight in LDS
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});

    constexpr auto wei_cyxk_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});

    // flattened (2d) tensor view of wei in LDS
    constexpr auto wei_ek_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});

    // tensor view of threadwise output in register
    constexpr auto out_hkwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
        print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
        print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
    }
#endif

    // blockwise copy
    // input: format is [C, Hi, Wi, N]
    const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
    const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;

    const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
    const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;

#if 0
    if(get_thread_local_1d_id() == 0)
        ;
    {
        printf(
            "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
            get_block_1d_id(),
            get_thread_local_1d_id(),
            h_block_pad_low,
            w_block_pad_low,
            h_block_pad_up,
            w_block_pad_up);
    }
#endif

    constexpr auto blockwise_in_copy =
        BlockwiseChwnTensorCopyPadded<BlockSize,
                                      Float,
                                      decltype(in_chwn_global_desc),
                                      decltype(in_chwn_block_desc),
                                      decltype(in_chwn_block_desc.GetLengths()),
                                      LowerPads>{};

#if 0
    // weight: format is [C,Y,X,K]
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize,
                               Float,
                               decltype(wei_cyxk_global_desc),
                               decltype(wei_cyxk_block_desc),
                               decltype(wei_cyxk_block_desc.GetLengths())>{};
#elif 0
    // weight: format is [C*Y*X,K]
    constexpr auto blockwise_wei_copy =
        Blockwise2dTensorCopy1<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths())>{};
#elif 1
    // weight: format is [C*Y*X,K]
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy2<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyThreadPerDim0,
                               WeiBlockCopyThreadPerDim1>{};
#endif

    // a series of blockwise batched GEMM
    // C_matrix += transpose(A_matrix) * B_matrix
    //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
    //   A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
    //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
    //   C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
    constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

    constexpr auto b_cxwn_block_mtx_desc =
        make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                      Number<WoPerBlock * NPerBlock>{},
                                      Number<in_chwn_block_desc.GetStride(I0)>{});

    constexpr auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor(
        Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});

    const auto blockwise_batch_gemm =
        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                                         decltype(a_cxk_block_mtx_desc),
                                                         decltype(b_cxwn_block_mtx_desc),
                                                         decltype(c_kxwn_thread_mtx_desc),
                                                         true,
                                                         false,
                                                         false,
                                                         0,
                                                         in_chwn_block_desc.GetStride(I1),
                                                         out_hkwn_thread_desc.GetStride(I0),
                                                         HoPerBlock,
                                                         HoPerThread,
                                                         CPerThread,
                                                         true>{};

    // LDS
    constexpr index_t in_block_element_size  = in_chwn_block_desc.GetElementSpace();
    constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_element_size];
    __shared__ Float p_wei_block[wei_block_element_size];

    // register
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    const Float* p_wei_global_block_begin =
        p_wei_global + wei_ek_global_desc.GetOffsetFromMultiIndex(0, k_block_data_begin);

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock,
                p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
                __syncthreads())
    {
#if 1
        // input: global mem to LDS,
        blockwise_in_copy.Run(p_in_global,
                              c_block_data_begin,
                              ho_block_data_begin,
                              wo_block_data_begin,
                              n_block_data_begin,
                              p_in_block,
                              h_block_pad_low,
                              w_block_pad_low,
                              h_block_pad_up,
                              w_block_pad_up);
#endif

#if 1
        // weight: global mem to LDS,
        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
#endif

        __syncthreads();

        // a series of batched GEMM
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

                blockwise_batch_gemm.Run(
                    p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                    p_in_block + in_chwn_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                    p_out_thread,
                    f_accum);
            }
        }
    }

    const auto matrix_c_index =
        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const index_t ho_thread_data_begin = matrix_c_index.batch;
    const index_t k_thread_data_begin  = matrix_c_index.row;
    const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
    const index_t n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

#if 0
    printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
           get_block_1d_id(), get_thread_local_1d_id(),
           ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
           ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
           p_out_thread[0]);
#endif

    // output: register to global mem,
    //   convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
    constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        out_hkwn_thread_desc,
        p_out_thread,
        out_khwn_global_desc,
        p_out_global +
            out_khwn_global_desc.GetOffsetFromMultiIndex(
                k_block_data_begin + k_thread_data_begin,
                ho_block_data_begin + ho_thread_data_begin,
                wo_block_data_begin + wo_thread_data_begin,
                n_block_data_begin + n_thread_data_begin),
        out_hkwn_thread_desc.GetLengths(),
        reorder_khwn_from_hkwn);
}

} // namespace ck
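In the padded kernel above, only the blocks that sit on the tensor boundary apply the convolution padding: the first block along a spatial dimension uses the low pad, the last uses the high pad, and interior blocks pad nothing. A small standalone sketch of that per-block pad selection, using hypothetical pad sizes and work counts in place of the kernel's LowerPads/UpperPads and HBlockWork:

#include <cstdio>

int main()
{
    // hypothetical values standing in for HPadLow, HPadUp, and HBlockWork
    const int HPadLow = 1, HPadUp = 1, HBlockWork = 4;

    for(int h_block_work_id = 0; h_block_work_id < HBlockWork; ++h_block_work_id)
    {
        // same rule as the kernel: only the first block pads low, only the last pads high
        const int h_block_pad_low = (h_block_work_id == 0) ? HPadLow : 0;
        const int h_block_pad_up  = (h_block_work_id == HBlockWork - 1) ? HPadUp : 0;

        std::printf("h block %d: pad_low=%d pad_up=%d\n",
                    h_block_work_id, h_block_pad_low, h_block_pad_up);
    }
    return 0;
}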
composable_kernel/include/tensor_operation/blockwise_2d_tensor_op.hpp
deleted 100644 → 0
#ifndef CK_BLOCKWISE_2D_TENSOR_OP_HPP
#define CK_BLOCKWISE_2D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto dst_desc = DstDesc{};

    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
    }
#endif

    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        const index_t did0 = is / desc.GetStride(I0);
        is -= did0 * desc.GetStride(I0);
        const index_t did1 = is / desc.GetStride(I1);

        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

        f(p_dst[dindex]);
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < desc.GetElementSize())
        {
            const index_t did0 = is / desc.GetStride(I0);
            is -= did0 * desc.GetStride(I0);
            const index_t did1 = is / desc.GetStride(I1);

            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

            f(p_dst[dindex]);
        }
    }
}

// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0,i1]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    const Float* __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        index_t did[2];

        did[0] = is / ref_desc.GetStride(I0);
        is -= did[0] * ref_desc.GetStride(I0);
        did[1] = is / ref_desc.GetStride(I1);

        const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
        const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);

        f(p_src[aindex], p_dst[bindex]);
    }

    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < ref_desc.GetElementSize())
        {
            index_t did[2];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);
            did[1] = is / ref_desc.GetStride(I1);

            const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
            const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);

            f(p_src[aindex], p_dst[bindex]);
        }
    }
}

template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
    auto f_set_zero = [](Float& v) { v = Float(0); };

    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     MapDst2Src)
{
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };

    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise2dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1! \n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4! \n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        // but we need to make sure dst stride0 is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line! \n");
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);

        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);
            did[1] = is / ref_desc.GetStride(I1);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
    index_t mThreadId0;
    index_t mThreadId1;

    __device__ Blockwise2dTensorCopy2()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1! \n");

        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float! \n");

        using Float4 = float4;
        using Float2 = float2;

        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;

        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);

        constexpr index_t Dim0Loop   = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail   = (L0 > ThreadPerDim0 * Dim0Loop);

        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;

        constexpr index_t Dim1V2Loop =
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

        constexpr index_t Dim1V1Loop =
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;

        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
        {
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;

            // v4
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
            {
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
            }

            // v2
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
            {
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                               2 * mThreadId1;

                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
            }

            // v1
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
            {
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;

                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;

                if(did1 < L1)
                {
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

            if(did0 < L0)
            {
                // v4
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
                {
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
                }

                // v2
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
                {
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;

                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
                }

                // v1
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
                {
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
                                   Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
                                   mThreadId1;

                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
                                   Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
                                   mThreadId1;

                    if(did1 < L1)
                    {
                        const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};

// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise2dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
                                      Array<index_t, 2> dst_block_data_multi_id_begin)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1! \n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4! \n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // we allow out-of-bound read from src in D1 dimension,
        // but we need to make sure dst stride is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line! \n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line \n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
            src_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
            dst_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                    iloop * src_loop_stride));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float! \n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(
                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

#if CK_USE_AMD_INLINE_ASM
    __device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
                                                 Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
                                                           iloop * src_loop_stride]));
#else
            static_assert(is_same<float, Float>{} && DataPerRead == 4,
                          "global_load is only for float4");

            global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
                        reinterpret_cast<const vector_t*>(
                            &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    __device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
                                                  Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
#if 0
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]);
#else
            static_assert(is_same<float, Float>{} && DataPerRead == 4,
                          "ds_write_b128 is only for float4");

            ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
                          &p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }
#endif
};

} // namespace ck
#endif
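Blockwise2dTensorCopy3 maps the flat thread id onto a (d0, d1) tile, with each thread moving DataPerRead contiguous elements along d1 and the block stepping through d0 in full passes plus a tail. The sketch below reproduces that thread-layout and register-clipboard arithmetic on the host; the lengths, block size, and read width are hypothetical example values, not taken from the source.

#include <cstdio>

int main()
{
    // hypothetical stand-ins for CopyLengths, BlockSize, and DataPerRead
    const int L0 = 8, L1 = 130, BlockSize = 256, DataPerRead = 4;

    // one vector read per thread along d1, rounded up; remaining threads stack along d0
    const int thread_per_d1     = (L1 + DataPerRead - 1) / DataPerRead;
    const int thread_per_d0     = BlockSize / thread_per_d1;
    const int num_active_thread = thread_per_d0 * thread_per_d1;

    // full passes over d0, plus a tail handled by the first tail_d0 * thread_per_d1 threads
    const int nloop_d0 = L0 / thread_per_d0;
    const int tail_d0  = L0 - nloop_d0 * thread_per_d0;

    // per-thread register clipboard size, mirroring GetRegisterClipboardSize()
    const int clipboard = DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;

    std::printf("thread_per_d1=%d thread_per_d0=%d active=%d nloop_d0=%d tail_d0=%d clipboard=%d\n",
                thread_per_d1, thread_per_d0, num_active_thread, nloop_d0, tail_d0, clipboard);
    return 0;
}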
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp
deleted 100644 → 0
#ifndef CK_BLOCKWISE_3D_TENSOR_OP_HPP
#define CK_BLOCKWISE_3D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise3dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise3dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
                      "wrong! only support stride2 == 1 if DataPerRead > 1! \n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4! \n");

        static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I1) % DataPerRead == 0,
                      "src and dst stride1 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D3 dimension,
        // but we need to make sure dst stride2 is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);

        static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
                      "wrong! out-of-bound write will contaminate next line! \n");
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);

        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[3];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);
            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);
            did[2] = is / ref_desc.GetStride(I2);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          class ThreadPerDims,
          index_t DataPerRead>
struct Blockwise3dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise3dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1! \n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4! \n");

        static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I1) % DataPerRead == 0,
                      "wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        // we allow out-of-bound read from src in D2 dimension,
        // but we need to make sure dst stride is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
                      "wrong! out-of-bound write will contaminate next line! \n");

        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
                      "wrong! L0, L1, L2 should be divided evenly! \n");

        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        constexpr index_t num_active_thread = accumulate_on_sequence(
            ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
        const auto thread_multi_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
            thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
            thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
                    const index_t src_offset =
                        SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    const index_t dst_offset =
                        DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                        *(reinterpret_cast<const vector_t*>(
                            &p_src[src_offset + mSrcMyThreadOffset]));
                }
            }
        }
    }

    __device__ static constexpr index_t GetRegisterClipboardSize()
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float! \n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
            Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
                    const index_t src_offset =
                        SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                        iloop_d0, iloop_d1, iloop_d2 * DataPerRead);

                    *(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
                        *(reinterpret_cast<const vector_t*>(
                            &p_src[src_offset + mSrcMyThreadOffset]));
                }
            }
        }
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
            Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
                    const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                        iloop_d0, iloop_d1, iloop_d2 * DataPerRead);

                    const index_t dst_offset =
                        DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                        *(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
                }
            }
        }
    }
};

} // namespace ck
#endif
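Blockwise3dTensorCopy3 gives each thread a fixed 3-D offset inside the tile and then loops over nloop_d0 x nloop_d1 x nloop_d2 vector transfers, so the per-thread register clipboard must hold exactly that many reads. A small host-side sketch of the loop-count and clipboard-size arithmetic, with hypothetical lengths and thread layout standing in for CopyLengths and ThreadPerDims:

#include <cstdio>

int main()
{
    // hypothetical stand-ins for CopyLengths and ThreadPerDims
    const int L0 = 4, L1 = 8, L2 = 64;
    const int thread_per_d0 = 2, thread_per_d1 = 4, thread_per_d2 = 8;
    const int DataPerRead = 4;

    // loop counts per thread; d2 is rounded up because reads are vectorized
    const int nloop_d0 = L0 / thread_per_d0;
    const int nloop_d1 = L1 / thread_per_d1;
    const int nloop_d2 = (L2 + thread_per_d2 * DataPerRead - 1) / (thread_per_d2 * DataPerRead);

    // register clipboard elements per thread, mirroring GetRegisterClipboardSize()
    const int clipboard = DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;

    std::printf("nloop = %d x %d x %d, clipboard elements per thread = %d\n",
                nloop_d0, nloop_d1, nloop_d2, clipboard);
    return 0;
}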
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp
deleted 100644 → 0
#ifndef CK_BLOCKWISE_4D_TENSOR_OP_HPP
#define CK_BLOCKWISE_4D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"

namespace ck {

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto dst_desc = DstDesc{};

    constexpr auto desc = make_ConstantTensorDescriptor_packed(dst_desc.GetLengths());

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
    }
#endif

    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        const index_t did0 = is / desc.GetStride(I0);
        is -= did0 * desc.GetStride(I0);

        const index_t did1 = is / desc.GetStride(I1);
        is -= did1 * desc.GetStride(I1);

        const index_t did2 = is / desc.GetStride(I2);
        is -= did2 * desc.GetStride(I2);

        const index_t did3 = is / desc.GetStride(I3);

        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

        f(p_dst[dindex]);
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < desc.GetElementSize())
        {
            const index_t did0 = is / desc.GetStride(I0);
            is -= did0 * desc.GetStride(I0);

            const index_t did1 = is / desc.GetStride(I1);
            is -= did1 * desc.GetStride(I1);

            const index_t did2 = is / desc.GetStride(I2);
            is -= did2 * desc.GetStride(I2);

            const index_t did3 = is / desc.GetStride(I3);

            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

            f(p_dst[dindex]);
        }
    }
}

// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    const Float* __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});

    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        index_t did[4];

        did[0] = is / ref_desc.GetStride(I0);
        is -= did[0] * ref_desc.GetStride(I0);

        did[1] = is / ref_desc.GetStride(I1);
        is -= did[1] * ref_desc.GetStride(I1);

        did[2] = is / ref_desc.GetStride(I2);
        is -= did[2] * ref_desc.GetStride(I2);

        did[3] = is / ref_desc.GetStride(I3);

        const index_t src_index = src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
        const index_t dst_index =
            dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

        f(p_src[src_index], p_dst[dst_index]);
    }

    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < ref_desc.GetElementSize())
        {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

            f(p_src[src_index], p_dst[dst_index]);
        }
    }
}

template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
    auto f_set_zero = [](Float& v) { v = Float(0); };

    blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     const Float* __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     MapDst2Src)
{
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };

    blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise4dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise4dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I2) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D3 dimension,
        // but we need to make sure dst stride2 is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);

        static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);

        constexpr auto ref_desc =
            make_ConstantTensorDescriptor_packed(Sequence<L0, L1, L2, read_per_d3>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class DstOpLengths,
          class GlobalLowerPads>
struct BlockwiseChwnTensorCopyPadded
{
    __device__ void Run(const Float* __restrict__ p_src,
                        index_t c_block_data_begin,
                        index_t ho_block_data_begin,
                        index_t wo_block_data_begin,
                        index_t n_block_data_begin,
                        Float* __restrict__ p_dst,
                        index_t h_block_pad_low,
                        index_t w_block_pad_low,
                        index_t h_block_pad_up,
                        index_t w_block_pad_up) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};
        constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(DstOpLengths{});

        constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
        constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        const Float* p_src_tmp =
            p_src + src_desc.GetOffsetFromMultiIndex(
                        c_block_data_begin,
                        (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
                        (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
                        n_block_data_begin);

#if 0
        if(get_thread_local_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(src_desc, "src_desc: ");
            print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
            print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");

            printf("%u %u, \t"
                   "h_global_pad_low %u w_global_pad_low %u \t"
                   "h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
                   "\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   h_global_pad_low,
                   w_global_pad_low,
                   h_block_pad_low,
                   w_block_pad_low,
                   h_block_pad_up,
                   w_block_pad_up);
        }
#endif

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);

            p_dst[bindex] =
                (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                 did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
                    ? Float(0)
                    : p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                index_t did[4];

                did[0] = is / ref_desc.GetStride(I0);
                is -= did[0] * ref_desc.GetStride(I0);

                did[1] = is / ref_desc.GetStride(I1);
                is -= did[1] * ref_desc.GetStride(I1);

                did[2] = is / ref_desc.GetStride(I2);
                is -= did[2] * ref_desc.GetStride(I2);

                did[3] = is / ref_desc.GetStride(I3);

                const index_t bindex =
                    dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);

                p_dst[bindex] =
                    (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                     did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
                        ? Float(0)
                        : p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
            }
        }
    }
};

// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          class ThreadPerDims,
          index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise4dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I2) % DataPerRead == 0,
                      "wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        // we allow out-of-bound read from src in D3 dimension,
        // but we need to make sure dst stride is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 && L2 % thread_per_d2 == 0,
                      "wrong! L0, L1, L2 should be divided evenly!\n");

        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        constexpr index_t num_active_thread =
            accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(ThreadPerDims{});

        const auto thread_multi_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
                                                               thread_multi_id[1],
                                                               thread_multi_id[2],
                                                               thread_multi_id[3] * DataPerRead);

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
                                                               thread_multi_id[1],
                                                               thread_multi_id[2],
                                                               thread_multi_id[3] * DataPerRead);
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                            *(reinterpret_cast<const vector_t*>(
                                &p_src[src_offset + mSrcMyThreadOffset]));
                    }
                }
            }
        }
    }

    __device__ constexpr index_t GetRegisterClipboardSize() const
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
            Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                            iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
                            *(reinterpret_cast<const vector_t*>(
                                &p_src[src_offset + mSrcMyThreadOffset]));
                    }
                }
            }
        }
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
            Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                            iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);

                        const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                            *(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
                    }
                }
            }
        }
    }
};

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
struct Blockwise4dTensorCopyReorder1
{
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        auto f_copy = [](const Float& src, Float& dst) { dst = src; };

        blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
            SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
    }
};

} // namespace ck
#endif
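For reference, a standalone sketch of the work-division arithmetic behind Blockwise4dTensorCopy3::GetRegisterClipboardSize() above: each thread owns L[d] / ThreadPerDims[d] tiles in the first three dimensions, ceil(L3 / (thread_per_d3 * DataPerRead)) vector reads in the last one, and therefore stages DataPerRead * prod(nloop_d) floats in its register clipboard. The helper names and the example lengths below are hypothetical, not taken from this commit.

    // sketch only: host-side mirror of the per-thread clipboard sizing used above
    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    constexpr int clipboard_size(int L0, int L1, int L2, int L3,
                                 int t0, int t1, int t2, int t3, int data_per_read)
    {
        // nloop_d0..d2 must divide evenly; nloop_d3 may over-read, as the static_asserts above require
        return data_per_read * (L0 / t0) * (L1 / t1) * (L2 / t2) * ceil_div(L3, t3 * data_per_read);
    }

    int main()
    {
        // CopyLengths = <2, 8, 2, 64>, ThreadPerDims = <1, 8, 2, 16> (256 active threads), DataPerRead = 4
        static_assert(clipboard_size(2, 8, 2, 64, 1, 8, 2, 16, 4) == 8,
                      "each thread stages 8 floats (2 float4 reads)");
        std::printf("clipboard floats per thread: %d\n", clipboard_size(2, 8, 2, 64, 1, 8, 2, 16, 4));
    }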
composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp
deleted 100644 → 0
#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
#define CK_BLOCKWISE_BATCHED_GEMM_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t BlockMatrixStrideA,
          index_t BlockMatrixStrideB,
          index_t ThreadMatrixStrideC,
          index_t BatchSize,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t BatchPerThread,
          index_t DataPerReadA,
          index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
    index_t mMyThreadOffsetA = 0;
    index_t mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        index_t batch;
        index_t row;
        index_t col;
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not dividable by BatchPerThread");

        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat\n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch,
                   c_thread_mtx_index.row,
                   c_thread_mtx_index.col,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }

    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away because input will be known at compile time
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
            // loop over batch
#pragma unroll
            for(index_t ib = 0; ib < BatchPerThread; ++ib)
            {
                // read next batch of a, b
                if(BlockMatrixStrideA != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
                                a_block_mtx.GetOffsetFromMultiIndex(k_begin,
                                                                    m_repeat * MPerLevel1Cluster) +
                                ib * BlockMatrixStrideA + mMyThreadOffsetA,
                            a_thread_mtx,
                            p_a_thread +
                                a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
                            a_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadA>{});
                    }
                }

                if(BlockMatrixStrideB != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
                                b_block_mtx.GetOffsetFromMultiIndex(k_begin,
                                                                    n_repeat * NPerLevel1Cluster) +
                                ib * BlockMatrixStrideB + mMyThreadOffsetB,
                            b_thread_mtx,
                            p_b_thread +
                                b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
                            b_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadB>{});
                    }
                }

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
                           p_a_thread[0], p_a_thread[1], p_a_thread[2], p_a_thread[3],
                           p_a_thread[4], p_a_thread[5], p_a_thread[6], p_a_thread[7],
                           p_b_thread[0], p_b_thread[1], p_b_thread[2], p_b_thread[3],
                           p_b_thread[4], p_b_thread[5], p_b_thread[6], p_b_thread[7]);
                }
#endif

                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC);
            }
        }
    }

#if CK_USE_AMD_INLINE_ASM
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
                                const FloatB* __restrict__ p_b_block,
                                FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
                          is_same<FloatC, float>{},
                      "Run_amd_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_amd_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read\n");

        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
                      "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
                      "1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] = *reinterpret_cast<const Float4*>(
            &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) + mMyThreadOffsetB]);
        reg_a[1] = *reinterpret_cast<const Float4*>(
            &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) + mMyThreadOffsetA]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            reg_a[0] = *reinterpret_cast<const Float4*>(
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] = *reinterpret_cast<const Float4*>(
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
                           mMyThreadOffsetB]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
                           mMyThreadOffsetA]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
                               const FloatB* __restrict__ p_b_block,
                               FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        // A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
                          is_same<FloatC, float>{},
                      "Run_amd_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_amd_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read\n");

        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
                      "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
                      "1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);

        constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
        constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();

        constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
        constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;

        ds_read_b128(reg_a[0], a_lds_loc, 0);
        ds_read_b128(reg_b[0], b_lds_loc, 0);
        ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
        ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);

        lgkmcnt(2);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
            lgkmcnt(1);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
            ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
            lgkmcnt(2);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }

        lgkmcnt(0);
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif

    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t c_thread_offset =
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
            c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        {
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
                    p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex(
                                     m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster),
                    c_block_mtx,
                    p_c_block + c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
                                                                    n_repeat * NPerLevel1Cluster) +
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};

} // namespace ck
#endif
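For reference, a host-side sketch of the two-level thread-cluster decomposition performed by GetBeginOfThreadMatrixC above. The parameter values (4x4 sub-tile, 4x4 level-0 cluster, 4x4 level-1 cluster, one batch per thread) are one plausible 256-thread configuration chosen for illustration, not values taken from this commit.

    // sketch only: mirrors the batch / level-1 / level-0 index decomposition used above
    #include <cstdio>

    struct MatrixIndex { int batch, row, col; };

    constexpr int MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr int MLevel0Cluster = 4, NLevel0Cluster = 4;
    constexpr int MLevel1Cluster = 4, NLevel1Cluster = 4;
    constexpr int BatchPerThread = 1;

    constexpr MatrixIndex begin_of_thread_matrix_c(int thread_id)
    {
        constexpr int ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
        constexpr int ThreadPerLevel1Cluster =
            ThreadPerLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        const int batch_work_id = thread_id / ThreadPerLevel1Cluster;
        const int cluster_id    = thread_id % ThreadPerLevel1Cluster;

        const int level1_id   = cluster_id / ThreadPerLevel0Cluster;
        const int level1_m_id = level1_id / NLevel1Cluster;
        const int level1_n_id = level1_id % NLevel1Cluster;

        const int level0_id   = cluster_id % ThreadPerLevel0Cluster;
        const int level0_m_id = level0_id / NLevel0Cluster;
        const int level0_n_id = level0_id % NLevel0Cluster;

        constexpr int MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr int NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return {batch_work_id * BatchPerThread,
                level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    int main()
    {
        // with this configuration thread 17 lands at batch 0, row 0, col 20
        const MatrixIndex i = begin_of_thread_matrix_c(17);
        std::printf("batch %d row %d col %d\n", i.batch, i.row, i.col);
    }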
composable_kernel/include/tensor_operation/blockwise_tensor_slice_copy.hpp
deleted 100644 → 0
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"

namespace ck {

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcLengths,
          class SrcSubLengths,
          class SrcClusterLengths,
          class MapDst2Src,
          class MapThreadCluster2SrcCluster,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite>
struct BlockwiseTensorSliceReorderCopy_v3
{
    static constexpr index_t nDim = SrcLengths::GetSize();

    index_t mThreadSrcOffset;
    index_t mThreadDstOffset;

    __device__
    BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
    {
        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr auto src_lengths = SrcLengths{};
        constexpr auto map_dst2src = MapDst2Src{};

        constexpr auto src_sub_lengths = SrcSubLengths{};
        constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);

        constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};

        constexpr auto src_cluster_lengths = SrcClusterLengths{};
        constexpr auto thread_cluster_lengths =
            src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);

        constexpr auto thread_cluster_desc =
            make_ConstantTensorDescriptor_packed(thread_cluster_lengths);

        // sanity check: data type
        static_assert(is_same<Float, float>{}, "wrong! only support float for now!\n");

        // sanity check: nDim
        static_assert(SrcDesc::GetNumOfDimension() == nDim &&
                          DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
                          SrcSubLengths::GetSize() == nDim && SrcClusterLengths::GetSize() == nDim &&
                          MapDst2Src::GetSize() == nDim &&
                          MapThreadCluster2SrcCluster::GetSize() == nDim,
                      "wrong! nDim is not consistent\n");

        // sanity check: BlockSize
        constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();

        static_assert(BlockSize >= num_active_thread,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        // sanity check: work division
        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto I = decltype(IDim){};
            constexpr index_t src_len         = src_lengths.Get(I);
            constexpr index_t src_sub_len     = src_sub_lengths.Get(I);
            constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
            static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
                          "wrong! cannot evenly divide Src tensor lengths");
        });

        // sanity check: src read
        static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
                      "wrong! only support SrcDataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");

        static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
                      "wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");

        static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
                      "wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
                      "keep alignment");

        // sanity check: dst write
        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                      "wrong! only support DstDataPerWrite == 1, 2 or 4!\n");

        static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");

        static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
                      "wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");

        static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
                      "wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
                      "keep alignment");

        // start dividing work
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const auto thread_multi_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        // compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
        // registers, or only one copy???
        auto src_data_multi_id =
            reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr index_t idim = IDim;
            // compiler: will it really compute index here, or be merged with
            // GetOffsetFromMultiIndex and optimized away???
            src_data_multi_id(idim) *= src_sub_lengths.Get(IDim);
        });

        // compiler: will it really compute index here, or be merged with GetOffsetFromMultiIndex
        // and optimized away???
        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);

        mThreadSrcOffset =
            src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);

        mThreadDstOffset =
            dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);

#if 0
        if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
        }

        if(get_block_1d_id() == 0)
        {
            printf("id %5u %5u: "
                   "thread_multi_id: %u %u, "
                   "src_block_data_multi_id_begin: %u %u, "
                   "src_data_multi_id: %u %u, "
                   "mThreadSrcOffset %u, mThreadDstOffset %u \n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   thread_multi_id[0],
                   thread_multi_id[1],
                   src_block_data_multi_id_begin[0],
                   src_block_data_multi_id_begin[1],
                   src_data_multi_id[0],
                   src_data_multi_id[1],
                   mThreadSrcOffset,
                   mThreadDstOffset);
        }
#endif
    }

    __device__ static constexpr index_t GetRegisterClipboardSize()
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        return thread_tensor_desc.GetElementSpace();
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;

            constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;

            constexpr index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(src_data_multi_id);

            constexpr index_t clipboard_offset =
                thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);

            threadwise_tensor_slice_copy(SrcDesc{},
                                         p_src + src_offset + mThreadSrcOffset,
                                         thread_tensor_desc,
                                         p_clipboard + clipboard_offset,
                                         thread_sub_tensor_lengths,
                                         Number<SrcDataPerRead>{});
        });
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;

            constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;

            // reorder src_data_multi_id to get dst_data_multi_id
            constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});

            constexpr index_t clipboard_offset =
                thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);

            constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);

            // write in the order of dst
#if 1
            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mThreadDstOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{});
#else
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mThreadDstOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{},
                                                                  Number<DstDataPerWrite>{});
#endif
        });
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        Float p_clipboard[GetRegisterClipboardSize()];

        RunLoadRegisterClipboard(p_src, p_clipboard);
        RunStoreRegisterClipboard(p_clipboard, p_dst);
    }

    // this function doesn't do sanity check on whether the slicing window is out of the boundary
    // of the tensor being sliced
    template <index_t IDim_, index_t StepSize, bool PositiveDirection>
    __device__ void MoveSlicingWindowOnSourceTensor(
        Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
    {
        constexpr auto IDim = Number<IDim_>{};

        static_if<PositiveDirection>{}([&](auto fwd) {
            mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
        }).Else([&](auto fwd) {
            mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
        });
    }
};

} // namespace ck
#endif
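For reference, a host-side sketch of the per-thread tile arithmetic in GetRegisterClipboardSize() above: in every dimension each thread repeats its sub-tile ceil(len / (sub_len * cluster_len)) times, so the register clipboard holds prod(sub_len * repeat) elements. The 2-D example lengths below are hypothetical, not configurations taken from this commit.

    // sketch only: mirrors the repeat_lengths / thread_tensor_lengths computation used above
    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    constexpr int clipboard_elements(int len0, int sub0, int cluster0,
                                     int len1, int sub1, int cluster1)
    {
        const int repeat0 = ceil_div(len0, sub0 * cluster0);
        const int repeat1 = ceil_div(len1, sub1 * cluster1);
        return (sub0 * repeat0) * (sub1 * repeat1);
    }

    int main()
    {
        // e.g. an 8x128 slice, 1x4 sub-tile per thread, 8x32 thread cluster -> 1x1 repeats, 4 floats per thread
        static_assert(clipboard_elements(8, 1, 8, 128, 4, 32) == 4, "4 floats per thread");
        std::printf("%d\n", clipboard_elements(8, 1, 8, 128, 4, 32));
    }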
composable_kernel/include/tensor_operation/threadwise_4d_tensor_op.hpp
deleted 100644 → 0
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
#define CK_THREADWISE_4D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    constexpr index_t nshift = NShift::mValue;

    constexpr index_t did0_end =
        is_same<decltype(I0), IDim>{} ? desc.GetLength(I0) - nshift : desc.GetLength(I0);

    constexpr index_t did1_end =
        is_same<decltype(I1), IDim>{} ? desc.GetLength(I1) - nshift : desc.GetLength(I1);

    constexpr index_t did2_end =
        is_same<decltype(I2), IDim>{} ? desc.GetLength(I2) - nshift : desc.GetLength(I2);

    constexpr index_t did3_end =
        is_same<decltype(I3), IDim>{} ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    for(index_t did0 = 0; did0 < did0_end; ++did0)
    {
        for(index_t did1 = 0; did1 < did1_end; ++did1)
        {
            for(index_t did2 = 0; did2 < did2_end; ++did2)
            {
                for(index_t did3 = 0; did3 < did3_end; ++did3)
                {
                    const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});

                    p[dindex] = p[sindex];
                }
            }
        }
    }
}

} // namespace ck
#endif
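For reference, a 1-D host analogue of threadwise_4d_tensor_shift_down above: along the chosen dimension, element i takes the value of element i + nshift, so the data slides toward index 0 and the last nshift positions keep stale values for the caller to refill. The function name shift_down_1d is hypothetical.

    // sketch only: 1-D version of the shift-down performed by the 4-D loop nest above
    #include <cstdio>

    // shift a length-n array down by nshift positions along its only dimension
    void shift_down_1d(float* p, int n, int nshift)
    {
        for(int i = 0; i < n - nshift; ++i)
        {
            p[i] = p[i + nshift];
        }
    }

    int main()
    {
        float v[6] = {0, 1, 2, 3, 4, 5};
        shift_down_1d(v, 6, 2);
        std::printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]); // prints: 2 3 4 5
    }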