Commit 9b280cc5, authored Sep 27, 2019 by Chao Liu
remove dead code
Parent: 98a2cfcc
Changes: 26
Showing 20 changed files with 18 additions and 5365 deletions (+18, -5365)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp          +0   -401
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp   +0   -530
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp          +0   -331
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp   +0   -461
composable_kernel/include/kernel_algorithm/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp         +0   -259
composable_kernel/include/kernel_algorithm/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp      +0   -298
composable_kernel/include/tensor_description/dimension.hpp                                                     +0   -6
composable_kernel/include/tensor_description/tensor_coordinate.hpp                                             +2   -2
composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp                                  +2   -2
composable_kernel/include/tensor_description/tensor_view.hpp                                                   +0   -100
composable_kernel/include/tensor_description/tensor_visit.hpp                                                  +0   -124
composable_kernel/include/tensor_operation/blockwise_2d_tensor_op.hpp                                          +0   -806
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp                                          +0   -378
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp                                          +0   -779
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp                  +4   -141
composable_kernel/include/tensor_operation/blockwise_tensor_slice_copy.hpp                                     +0   -298
composable_kernel/include/tensor_operation/threadwise_4d_tensor_op.hpp                                         +0   -60
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp                 +10  -158
composable_kernel/include/tensor_operation/threadwise_tensor_slice_copy.hpp                                    +0   -201
composable_kernel/include/utility/config_nvidia.hpp.in                                                         +0   -30
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp   deleted  100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
template <index_t GridSize, index_t BlockSize, typename Float,
          typename InGlobalDesc, typename WeiGlobalDesc, typename OutGlobalDesc,
          typename ConvStrides, typename ConvDilations,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock,
          index_t GemmNRepeat,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          typename InBlockCopySubLengths_E_N1_B_N2,
          typename InBlockCopyClusterLengths_E_N1_B_N2,
          typename InBlockCopyThreadClusterArrangeOrder,
          typename InBlockCopySrcAccessOrder, typename InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B, index_t InBlockCopyDstDataPerWrite_N2,
          typename WeiBlockCopySubLengths_E_K,
          typename WeiBlockCopyClusterLengths_E_K,
          typename WeiBlockCopyThreadClusterArrangeOrder,
          typename WeiBlockCopySrcAccessOrder, typename WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThreadSubC;

        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I5 = Number<5>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
        constexpr auto in_n0_n1_n2_h_w_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
                .Fold(I0, Number<N1>{}, Number<N2>{})
                .Extract(Sequence<0, 1, 2, 4, 5>{});

        // batch descritpor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                .Extract(Sequence<1, 2, 3>{});

        // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
        constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
            Sequence<0, 1, 2>{}, Sequence<4>{}, Sequence<3, 6, 7>{}, Sequence<5>{});

        // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize, decltype(in_e_n1_b_n2_global_merged_desc), decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
            2, 3, InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2>(
            {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize, decltype(wei_e_k_global_desc), decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K, WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
            0, 1, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>(
            {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[EPerBlock, KPerBlock] is in LDS
        //   b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
        //   c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //   register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

        constexpr auto b_e_n1bn2_block_mtx_desc =
            make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize, decltype(a_e_k_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC, GemmNPerThreadSubC,
            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
            GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        // do work
        for(index_t e = 0; e < E; e += EPerBlock)
        {
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0, 0, 0), True);
            blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
        }

        // copy output: register to global memory
        {
#if 0
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                    .Fold(I0, Number<N1>{}, Number<N2>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
                Sequence<3>{},
                Sequence<1>{},
                Sequence<0, 4, 5>{},
                Sequence<2>{});

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global +
                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);

            ThreadwiseGenericTensorSliceCopy_v2r1<
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 8, 1>::type,
                arithmetic_sequence_gen<0, 8, 1>::type,
                7,
                7,
                1,
                1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
                .Run(p_out_thread, p_out_thread_on_global);
#else
            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register, src of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc =
                make_ConstantTensorDescriptor_packed(
                    Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});

            // output memory layout descriptor in device memory
            constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});

            // output merged global tensor descriptor, dst of threadwise copy
            constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
                make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
                                                    Sequence<3>{},
                                                    Sequence<4>{},
                                                    Sequence<1>{},
                                                    Sequence<0, 5, 6>{},
                                                    Sequence<2>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            ThreadwiseGenericTensorSliceCopy_v2r1<
                decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
                decltype(out_k0_k1_n1_b_n2_global_merged_desc),
                decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
                arithmetic_sequence_gen<0, 5, 1>::type,
                arithmetic_sequence_gen<0, 5, 1>::type,
                3, 3, 1, 1>({0, 0, 0, 0, 0},
                            {k_thread_data_on_global / K1,
                             k_thread_data_on_global % K1,
                             0,
                             b_thread_data_on_global,
                             0})
                .template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
#endif
        }
    }
};

} // namespace ck
#endif
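
A minimal host-side sketch (not part of the commit) of the dimension bookkeeping the deleted v4r1 kernel performs: the convolution is recast as an implicit GEMM with E = C * Y * X and B = N0 * Ho * Wo, where N0 = N / (N1 * N2). The sample problem sizes and the N1/N2 values below are hypothetical, chosen only to satisfy the same divisibility static_asserts the kernel enforces.

#include <cassert>
#include <cstdio>

int main()
{
    // hypothetical NCHW problem: N x C x Hi x Wi input, K x C x Y x X filter, stride/dilation 1
    const int N = 16, C = 8, K = 32, Ho = 28, Wo = 28, Y = 3, X = 3;
    const int N1 = 2, N2 = 4; // play the role of GemmNRepeat and GemmNPerThreadSubC

    assert(N % (N1 * N2) == 0); // mirrors "cannot divice N evenly among thread"

    const int N0 = N / (N1 * N2);
    const int E  = C * Y * X;   // reduction (K) dimension of the implicit GEMM
    const int B  = N0 * Ho * Wo;

    // c_mtx[K, N1 * B * N2] += transpose(a_mtx[E, K]) * b_mtx[E, N1 * B * N2]
    std::printf("E = %d, B = %d, GEMM: (%d x %d) += (%d x %d)^T * (%d x %d)\n",
                E, B, K, N1 * B * N2, E, K, E, N1 * B * N2);
    return 0;
}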
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp   deleted  100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
template <index_t GridSize, index_t BlockSize, typename Float,
          typename InGlobalDesc, typename WeiGlobalDesc, typename OutGlobalDesc,
          typename ConvStrides, typename ConvDilations, typename LeftPads, typename RightPads,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock,
          index_t GemmNRepeat,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          typename InBlockCopySubLengths_E_N1_B_N2,
          typename InBlockCopyClusterLengths_E_N1_B_N2,
          typename InBlockCopyThreadClusterArrangeOrder,
          typename InBlockCopySrcAccessOrder, typename InBlockCopyDstAccessOrder,
          index_t InBlockCopySrcDataPerRead_B, index_t InBlockCopyDstDataPerWrite_N2,
          typename WeiBlockCopySubLengths_E_K,
          typename WeiBlockCopyClusterLengths_E_K,
          typename WeiBlockCopyThreadClusterArrangeOrder,
          typename WeiBlockCopySrcAccessOrder, typename WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
{
#if 1
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThreadSubC;

        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc =
            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
        constexpr auto wei_k_c_y_x_global_desc =
            make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
        constexpr auto out_n_k_ho_wo_global_desc =
            make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // global memory
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

        constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
            in_n0_n1_n2_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{},
                       PassThrough<N1>{},
                       Merge<Sequence<N0, Ho, Wo>>{},
                       PassThrough<N2>{}),
            make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not satisfied");

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize, decltype(in_e_n1_b_n2_global_desc), decltype(in_e_n1_b_n2_block_desc),
            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
            InBlockCopySubLengths_E_N1_B_N2, InBlockCopyClusterLengths_E_N1_B_N2,
            InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
            2, 3, InBlockCopySrcDataPerRead_B, InBlockCopyDstDataPerWrite_N2>(
            {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

#if 0
        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            transform_tensor_descriptor(wei_k_c_y_x_global_desc,
                                        make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
                                        make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
        // hack
        constexpr auto wei_e_k_global_desc_old =
            WeiGlobalDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

        constexpr auto wei_e_k_global_desc = make_native_tensor_descriptor(
            wei_e_k_global_desc_old.GetLengths(), wei_e_k_global_desc_old.GetStrides());
#endif

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v4<
            BlockSize, decltype(wei_e_k_global_desc), decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K, WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
            0, 1, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>(
            {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[EPerBlock, KPerBlock] is in LDS
        //   b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
        //   c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
        //   register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

        constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
            in_e_n1_b_n2_block_desc.GetLength(I0),
            in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
                in_e_n1_b_n2_block_desc.GetLength(I3),
            in_e_n1_b_n2_block_desc.GetStride(I0));

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize, decltype(a_e_k_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc),
            decltype(c_k0k2_n1n2_thread_mtx_desc),
            GemmMPerThreadSubC, GemmNPerThreadSubC,
            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
            GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

        // do work
        for(index_t e = 0; e < E; e += EPerBlock)
        {
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0, 0, 0), True);
            blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
        }

        // copy output: register to global memory
        {
            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;

            static_assert(K % (K1 * K2) == 0, "wrong!");

            // define tensor descriptor for threadwise copy
            // output memory layout descriptor in register
            constexpr auto out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc =
                make_native_tensor_descriptor_packed(
                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});

            // output tensor descriptor in register, src of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc =
                reorder_tensor_descriptor_given_upper2lower(
                    out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc, Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});

            // output memory layout descriptor in device memory, dst of threadwise copy
            constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc = transform_tensor_descriptor(
                out_n_k_ho_wo_global_desc,
                make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
                           Unmerge<Sequence<K / (K1 * K2), K1, K2>>{},
                           PassThrough<Ho>{},
                           PassThrough<Wo>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}, Sequence<6>{}, Sequence<7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col / N2;

            // output merged global tensor descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_n0_n1_n2_k_ho_wo_global_desc = transform_tensor_descriptor(
                out_n_k_ho_wo_global_desc,
                make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
                           PassThrough<K>{},
                           PassThrough<Ho>{},
                           PassThrough<Wo>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto out_k_n1_b_n2_global_desc = transform_tensor_descriptor(
                out_n0_n1_n2_k_ho_wo_global_desc,
                make_tuple(PassThrough<K>{},
                           PassThrough<N1>{},
                           Merge<Sequence<N0, Ho, Wo>>{},
                           PassThrough<N2>{}),
                make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            // origin of dst in device memory
            Float* p_out_thread_on_global =
                p_out_global + out_k_n1_b_n2_global_desc.CalculateOffset(
                                   {k_thread_data_on_global, 0, b_thread_data_on_global, 0});

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc),
                decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 8, 1>::type,
                7, 1, 1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
                .Run(p_out_thread, p_out_thread_on_global);
        }
    }
#else
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: find more elegent way of specifying (or calculating) performance parameters
        constexpr index_t N1 = GemmNRepeat;
        constexpr index_t N2 = GemmNPerThreadSubC;

        static_assert((N1 * N2 * BPerBlock) %
                              (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I5 = Number<5>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc =
            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
        constexpr auto wei_k_c_y_x_global_desc =
            make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
        constexpr auto out_n_k_h_w_global_desc =
            make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        constexpr index_t E = C * Y * X;

        // sanity-check for vectorized memory load
        static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
                      "wrong! global vector load of input tensor is wrong");

        static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // input
        constexpr auto in_n_c_hi_wi_global_desc =
            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());

        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

        constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
            in_n0_n1_n2_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{},
                       PassThrough<N1>{},
                       Merge<Sequence<N0, Ho, Wo>>{},
                       PassThrough<N2>{}),
            make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // weight
        constexpr auto wei_e_k_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
            make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_tensor_descriptor("in_n_c_hi_wi_global_desc: ", in_n_c_hi_wi_global_desc);
            print_tensor_descriptor("in_n_c_hip_wip_global_desc: ", in_n_c_hip_wip_global_desc);
            print_tensor_descriptor("in_n0_n1_n2_c_y_ho_x_wo_global_desc: ",
                                    in_n0_n1_n2_c_y_ho_x_wo_global_desc);
            print_tensor_descriptor("in_e_n1_b_n2_global_desc: ", in_e_n1_b_n2_global_desc);

            auto coord3 = make_tensor_coordinate_v2(in_e_n1_b_n2_global_desc, {1, 1, 1, 1});

            auto idx3 = coord3.GetIndex();
            auto idx2 = coord3.GetLowerCoordinate().GetIndex();
            auto idx1 = coord3.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
            auto idx0 =
                coord3.GetLowerCoordinate().GetLowerCoordinate().GetLowerCoordinate().GetIndex();

            print_array("idx3: ", idx3);
            print_array("idx2: ", idx2);
            print_array("idx1: ", idx1);
            print_array("idx0: ", idx0);
        }
#else
        index_t itmp = get_block_1d_id() + get_thread_local_1d_id();

        auto wei_coord1 = make_tensor_coordinate_v2(wei_e_k_global_desc, {itmp, itmp + 1});

        auto step_sizes = make_multi_index(EPerBlock, 0);

        wei_coord1 += step_sizes;

        p_out_global[0] = wei_coord1.GetLowerCoordinate().GetIndex()[0];
        p_out_global[1] = wei_coord1.GetLowerCoordinate().GetIndex()[1];
        p_out_global[2] = wei_coord1.GetLowerCoordinate().GetIndex()[2];
        p_out_global[3] = wei_coord1.GetLowerCoordinate().GetIndex()[3];
#endif
    }
#endif
};

} // namespace ck
#endif
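
A minimal sketch (not part of the commit) of the LDS sizing idiom that appears in all of these kernels: the element space of each per-block LDS buffer is rounded up to a common alignment, and that alignment is the lcm of every vector read/write width touching the buffer. The stand-in helpers below only illustrate what the kernels' math::lcm and math::integer_least_multiple calls are assumed to compute; the widths and element counts are hypothetical.

#include <cstdio>

constexpr int gcd_i(int a, int b) { return b == 0 ? a : gcd_i(b, a % b); }
constexpr int lcm_i(int a, int b) { return a / gcd_i(a, b) * b; }
constexpr int integer_least_multiple(int x, int align) { return ((x + align - 1) / align) * align; }

int main()
{
    // hypothetical widths standing in for InBlockCopyDstDataPerWrite_N2,
    // WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA and GemmDataPerReadB
    constexpr int max_align = lcm_i(lcm_i(4, 2), lcm_i(4, 4));

    constexpr int in_elements  = 2 * 2 * 16 * 4 + 3; // made-up GetElementSpace() results
    constexpr int wei_elements = 8 * 128 + 5;

    constexpr int in_block_space  = integer_least_multiple(in_elements, max_align);
    constexpr int wei_block_space = integer_least_multiple(wei_elements, max_align);

    std::printf("max_align=%d in_block_space=%d wei_block_space=%d\n",
                max_align, in_block_space, wei_block_space);
    return 0;
}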
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp   deleted  100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
template <index_t GridSize, index_t BlockSize, class Float,
          class InGlobalDesc, class WeiGlobalDesc, class OutGlobalDesc,
          class ConvStrides, class ConvDilations,
          index_t BPerBlock, index_t KPerBlock, index_t EPerBlock,
          index_t GemmMPerThreadSubC, index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster, index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster, index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop, index_t GemmDataPerReadA, index_t GemmDataPerReadB,
          class InBlockCopySubLengths_E_B,
          class InBlockCopyClusterLengths_E_B,
          class InBlockCopyThreadClusterArrangeOrder,
          class InBlockCopySrcAccessOrder, class InBlockCopyDstAccessOrder,
          index_t InBlockCopyDataPerAccess_B,
          class WeiBlockCopySubLengths_E_K,
          class WeiBlockCopyClusterLengths_E_K,
          class WeiBlockCopyThreadClusterArrangeOrder,
          class WeiBlockCopySrcAccessOrder, class WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopyDstDataPerWrite_K,
          index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I5 = Number<5>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
        constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        // tensor descriptor in device memory [N, Ho, Wo]
        constexpr auto in_n_ho_wo_global_desc =
            in_n_c_h_w_global_desc.Extract(I0, I2, I3)
                .StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
                .StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});

        // batch descritpor for device memory
        constexpr auto in_c_y_x_global_desc =
            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
                .StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
                .Extract(Sequence<1, 2, 3>{});

        // merged tensor descriptor in device memory [E, B], src of blockwise copy
        constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
            in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
            Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{});

        // memory layout descriptor in LDS [E, B], dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto in_e_b_block_desc =
            make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});

        // input blockwise copy
        // slice a merged tensor, reorder and copy to a normal tensor
        // this copy operator already has blockwise offset built-in
        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize, decltype(in_e_b_global_desc), decltype(in_e_b_block_desc),
            decltype(in_e_b_block_desc.GetLengths()),
            InBlockCopySubLengths_E_B, InBlockCopyClusterLengths_E_B,
            InBlockCopyThreadClusterArrangeOrder, InBlockCopySrcAccessOrder, InBlockCopyDstAccessOrder,
            1, 1, InBlockCopyDataPerAccess_B, InBlockCopyDataPerAccess_B>(
            {0, b_block_data_on_global}, {0, 0});

        // weight tensor
        // tensor descriptor in device memory, src of blockwise copy
        constexpr auto wei_e_k_global_desc =
            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

        // tensor descriptor in LDS, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // operator for blockwise copy of weight into LDS
        // slice a tensor, and copy it into another tensor
        // this copy operator already have blockwise offset built-in
        auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
            BlockSize, decltype(wei_e_k_global_desc), decltype(wei_e_k_block_desc),
            decltype(wei_e_k_block_desc.GetLengths()),
            WeiBlockCopySubLengths_E_K, WeiBlockCopyClusterLengths_E_K,
            WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopySrcAccessOrder, WeiBlockCopyDstAccessOrder,
            0, 1, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_K>(
            {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[EPerBlock, KPerBlock] is in LDS
        //   b_mtx[EPerBlocl, BPerBlock] is in LDS
        //   c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
        //   register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
        constexpr index_t GemmNRepeat =
            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize, decltype(a_e_k_block_mtx_desc), decltype(b_e_b_block_mtx_desc),
            decltype(c_k0k1_b0b1_thread_mtx_desc),
            GemmMPerThreadSubC, GemmNPerThreadSubC,
            GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster,
            GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);

        for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
        {
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
            blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
        }

        // copy output: register to global memory
        {
            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

            // define tensor descriptor for threadwise copy
            // output global descriptor, for calculating origin of thread tensor
            // in global memory
            constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
                out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col;

            // This is a hack, because slicing a merged dimension is not supported yet.
            // This should be replaced with logic above, once slicing a merged dimension support
            // become available
            // dst descriptor
            constexpr auto out_k0_k1_b_global_desc = make_ConstantMergedTensorDescriptor(
                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
                Sequence<1>{}, Sequence<2>{}, Sequence<0, 3, 4>{});

            // src descriptor
            constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});

            using OutThreadCopySliceLengths =
                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;

            auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
                decltype(out_k0_k1_b_thread_desc), decltype(out_k0_k1_b_global_desc),
                OutThreadCopySliceLengths,
                arithmetic_sequence_gen<0, 3, 1>::type, arithmetic_sequence_gen<0, 3, 1>::type,
                2, 2, OutThreadCopyDataPerAccess_B, OutThreadCopyDataPerAccess_B>(
                {0, 0, 0},
                {k_thread_data_on_global / K1, k_thread_data_on_global % K1,
                 b_thread_data_on_global});

            for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
            {
                threadwise_out_copy.Run(p_out_thread, p_out_global);

                threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
                threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
            }
        }
    }
};

} // namespace ck
#endif
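
A minimal sketch (not part of the commit) of how these kernels map a 1-D block id onto the [K, B] output tile grid: a packed 2-D descriptor of shape [KBlockWork, BBlockWork] is decoded, and each coordinate is scaled by the per-block tile size. The problem sizes below are hypothetical, and GetMultiIndexFrom1dIndex in the kernel is assumed to perform the same row-major decode as the division and modulo here.

#include <cstdio>

int main()
{
    const int K = 128, B = 1024;          // hypothetical GEMM M and N extents
    const int KPerBlock = 32, BPerBlock = 64;

    const int KBlockWork = K / KPerBlock; // 4
    const int BBlockWork = B / BPerBlock; // 16

    for(int block_id = 0; block_id < KBlockWork * BBlockWork; block_id += 23)
    {
        const int k_work = block_id / BBlockWork; // block_work_multi_id[0]
        const int b_work = block_id % BBlockWork; // block_work_multi_id[1]

        const int k_block_data_on_global = k_work * KPerBlock;
        const int b_block_data_on_global = b_work * BPerBlock;

        std::printf("block %2d -> k offset %3d, b offset %4d\n",
                    block_id, k_block_data_on_global, b_block_data_on_global);
    }
    return 0;
}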
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded.hpp   deleted  100644 → 0
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace
ck
{
// B = merge(N, Ho, Wo)
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
typename
InBlockCopySubLengths_E_B
,
typename
InBlockCopyClusterLengths_E_B
,
typename
InBlockCopyThreadClusterArrangeOrder
,
typename
InBlockCopySrcAccessOrder
,
typename
InBlockCopyDstAccessOrder
,
index_t
InBlockCopyDataPerAccess_B
,
typename
WeiBlockCopySubLengths_E_K
,
typename
WeiBlockCopyClusterLengths_E_K
,
typename
WeiBlockCopyThreadClusterArrangeOrder
,
typename
WeiBlockCopySrcAccessOrder
,
typename
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
,
index_t
OutThreadCopyDataPerAccess_B
>
struct
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
{
#if 1
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
in_n_c_hi_wi_global_desc
=
make_native_tensor_descriptor
(
InGlobalDesc
::
GetLengths
(),
InGlobalDesc
::
GetStrides
());
constexpr
auto
wei_k_c_y_x_global_desc
=
make_native_tensor_descriptor
(
WeiGlobalDesc
::
GetLengths
(),
WeiGlobalDesc
::
GetStrides
());
constexpr
auto
out_n_k_ho_wo_global_desc
=
make_native_tensor_descriptor
(
OutGlobalDesc
::
GetLengths
(),
OutGlobalDesc
::
GetStrides
());
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
E
=
C
*
Y
*
X
;
constexpr
index_t
B
=
N
*
Ho
*
Wo
;
        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

        // input tensor
        //   global mem
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        //   LDS mem
        //   be careful of LDS alignment
        constexpr auto in_e_b_block_desc =
            make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});

        // input blockwise copy
        auto blockwise_in_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(in_e_b_global_desc),
                                               decltype(in_e_b_block_desc),
                                               decltype(in_e_b_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_B,
                                               InBlockCopyClusterLengths_E_B,
                                               InBlockCopyThreadClusterArrangeOrder,
                                               InBlockCopySrcAccessOrder,
                                               InBlockCopyDstAccessOrder,
                                               1,
                                               1,
                                               InBlockCopyDataPerAccess_B,
                                               InBlockCopyDataPerAccess_B>(
                {0, b_block_data_on_global}, {0, 0});

        // weight tensor
        //   global mem
        constexpr auto wei_e_k_global_desc = transform_tensor_descriptor(
            wei_k_c_y_x_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
            make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        //   LDS
        //   be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // weight blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
                                               WeiBlockCopyThreadClusterArrangeOrder,
                                               WeiBlockCopySrcAccessOrder,
                                               WeiBlockCopyDstAccessOrder,
                                               0,
                                               1,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K>(
                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[EPerBlock, KPerBlock] is in LDS
        //     b_mtx[EPerBlock, BPerBlock] is in LDS
        //     c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
        constexpr index_t GemmNRepeat =
            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_b_block_mtx_desc),
            decltype(c_k0k1_b0b1_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];
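
        // max_align is the lcm of every vectorized access width used on LDS; rounding each
        // buffer's element count up to a multiple of it keeps whichever buffer follows it in
        // LDS aligned for the widest vectorized load/store.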
        // register allocation for output
        Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);

        for(index_t e_block_data_begin = 0; e_block_data_begin < E;
            e_block_data_begin += EPerBlock)
        {
            blockwise_in_copy.Run(p_in_global, p_in_block);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block);

            __syncthreads();

            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);

            __syncthreads();

            blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
            blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(EPerBlock, 0), True);
        }

        // copy output: register to global memory
        {
            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col;

            // src descriptor
            constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});

            // dst descriptor
            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

            constexpr index_t K0 = K / K1;
            constexpr index_t B0 = B / B1;

            constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
                out_n_k_ho_wo_global_desc,
                make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
                make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
                out_k_b_global_desc,
                make_tuple(Unmerge<Sequence<K0, K1>>{}, Unmerge<Sequence<B0, B1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // output threadwise copy
            auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(out_k0_k1_b0_b1_thread_desc),
                decltype(out_k0_k1_b0_b1_global_desc),
                decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 4, 1>::type,
                3,
                OutThreadCopyDataPerAccess_B,
                OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
                                              {k_thread_data_on_global / K1,
                                               k_thread_data_on_global % K1,
                                               b_thread_data_on_global / B1,
                                               b_thread_data_on_global % B1});

            threadwise_out_copy.Run(p_out_thread, p_out_global);
        }
    }
#else
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc =
            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
        constexpr auto wei_k_c_y_x_global_desc =
            make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
        constexpr auto out_n_k_ho_wo_global_desc =
            make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity-check for vectorized memory load
        static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
                          (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
        constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

        constexpr index_t K0 = K / K1;
        constexpr index_t B0 = B / B1;

        constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
            out_k_b_global_desc,
            make_tuple(Unmerge<Sequence<K0, K1>>{}, Unmerge<Sequence<B0, B1>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

#if 1
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_tensor_descriptor("in_e_b_global_desc: ", in_e_b_global_desc);
            print_tensor_descriptor("in_n_c_y_ho_x_wo_global_desc: ", in_n_c_y_ho_x_wo_global_desc);
            print_tensor_descriptor("in_n_c_hip_wip_global_desc: ", in_n_c_hip_wip_global_desc);
            print_tensor_descriptor("in_n_c_hi_wi_global_desc: ", in_n_c_hi_wi_global_desc);

            auto coord3 = make_tensor_coordinate_v2(in_e_b_global_desc, {1, 1});

            auto idx3 = coord3.GetIndex();
            auto idx2 = coord3.GetLowerCoordinate().GetIndex();
            auto idx1 = coord3.GetLowerCoordinate().GetLowerCoordinate().GetIndex();
            auto idx0 =
                coord3.GetLowerCoordinate().GetLowerCoordinate().GetLowerCoordinate().GetIndex();

            print_array("idx3: ", idx3);
            print_array("idx2: ", idx2);
            print_array("idx1: ", idx1);
            print_array("idx0: ", idx0);
        }

        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_tensor_descriptor("out_k0_k1_b0_b1_global_desc: ", out_k0_k1_b0_b1_global_desc);
            print_tensor_descriptor("out_k_b_global_desc: ", out_k_b_global_desc);
            print_tensor_descriptor("out_n_k_ho_wo_global_desc: ", out_n_k_ho_wo_global_desc);

            auto coord2 = make_tensor_coordinate_v2(out_k0_k1_b0_b1_global_desc, {1, 1, 1, 1});

            auto idx2 = coord2.GetIndex();
            auto idx1 = coord2.GetLowerCoordinate().GetIndex();
            auto idx0 = coord2.GetLowerCoordinate().GetLowerCoordinate().GetIndex();

            print_array("idx2: ", idx2);
            print_array("idx1: ", idx1);
            print_array("idx0: ", idx0);
        }
#endif
    }
#endif
};

} // namespace ck
#endif
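
As an aside on the output write-back above: the Unmerge&lt;Sequence&lt;K0, K1&gt;&gt; and Unmerge&lt;Sequence&lt;B0, B1&gt;&gt; transforms are plain division/modulo splits of the flat K and B coordinates. A minimal standalone sketch of the same arithmetic (the helper name and the sample numbers are illustrative, not from the commit):

#include <array>
#include <cstdint>

// Split a flat index along a merged dimension of size d0*d1 into (i0, i1),
// mirroring what Unmerge does for the output descriptor above.
std::array<std::int64_t, 2> unmerge2(std::int64_t flat, std::int64_t d1)
{
    return {flat / d1, flat % d1};
}

int main()
{
    // Hypothetical tile size: K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster = 64.
    const std::int64_t K1                      = 64;
    const std::int64_t k_thread_data_on_global = 200;

    const auto k = unmerge2(k_thread_data_on_global, K1);
    // k[0] == 3 (k0) and k[1] == 8 (k1): this thread writes into out[k0, k1, b0, b1].
    return (k[0] == 3 && k[1] == 8) ? 0 : 1;
}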
composable_kernel/include/kernel_algorithm/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
deleted
100644 → 0
View file @
98a2cfcc
#pragma once
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_direct_convolution.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "threadwise_direct_convolution.hpp"

namespace ck {

template <class TInWei,
          class TOut,
          class TAccum,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t ScalarPerVector,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead,
          index_t BlockSize,
          index_t GridSize>
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
    const typename vector_type<TInWei, ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
    const typename vector_type<TInWei, ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
    TOut* const __restrict__ p_out_global)
{
    using in_scalar_t     = TInWei;
    using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;

    using out_scalar_t = TOut;
    using accum_t      = TAccum;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_vec_global_desc  = InGlobalDesc{};
    constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
    constexpr auto out_nkhw_global_desc     = OutGlobalDesc{};

    constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);

    constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
    constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
    constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);

    constexpr auto wei_ke_vec_global_desc =
        make_ConstantTensorDescriptor(Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});

    constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
        Sequence<KPerBlock, CPerBlock * Y * X>{},
        Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy

    constexpr auto wei_kcyx_vec_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlock, Y, X>{},
        Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});

    // shared mem
    constexpr index_t in_block_element_size =
        in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
    constexpr index_t wei_block_element_size =
        wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

    constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                      ? InBlockCopyDataPerRead
                                      : WeiBlockCopyDataPerRead;

    __shared__ in_vector_mem_t
        p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
    __shared__ in_vector_mem_t
        p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

    // threadwise tensors
    constexpr index_t HiPerThread = HoPerThread + Y - 1;
    constexpr index_t WiPerThread = WoPerThread + X - 1;

    constexpr auto in_nchw_vec_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
        in_nchw_vec_block_desc.GetStrides());

    constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
        Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());

    constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);

    // register
    out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

    // divide block work
    constexpr index_t NBlockWork =
        (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr index_t KBlockWork =
        (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork =
        (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork =
        (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    const index_t block_id = blockIdx.x;

    index_t itmp                  = block_id;
    const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
    const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
    const index_t h_block_work_id = itmp / WBlockWork;
    const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

    const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
    const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

    // divide thread work
    constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
    constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
    constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
    constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

    const index_t thread_id = get_thread_local_1d_id();

    itmp                           = thread_id;
    const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
    itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
    const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
    itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
    const index_t h_thread_work_id = itmp / WThreadWork;
    const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

    const index_t n_thread_data_begin  = n_thread_work_id * NPerThread;
    const index_t k_thread_data_begin  = k_thread_work_id * KPerThread;
    const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
    const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

    const index_t hi_thread_data_begin = ho_thread_data_begin;
    const index_t wi_thread_data_begin = wo_thread_data_begin;

    constexpr auto blockwise_in_copy =
        Blockwise4dTensorCopy1<BlockSize,
                               in_vector_mem_t,
                               decltype(in_nchw_vec_global_desc),
                               decltype(in_nchw_vec_block_desc),
                               decltype(in_nchw_vec_block_desc.GetLengths()),
                               InBlockCopyDataPerRead>{};

#if 0
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize,
                               in_vector_mem_t,
                               decltype(wei_kcyx_vec_global_desc),
                               decltype(wei_kcyx_vec_block_desc),
                               decltype(wei_kcyx_vec_block_desc.GetLengths()),
                               1>{};
#elif 1
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy3<BlockSize,
                               in_vector_mem_t,
                               decltype(wei_ke_vec_global_desc),
                               decltype(wei_ke_vec_block_desc),
                               decltype(wei_ke_vec_block_desc.GetLengths()),
                               WeiBlockCopyDataPerRead>{};
#endif

#if 1 // debug
    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
#endif

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock, __syncthreads())
    {
        // copy input tensor to LDS
        blockwise_in_copy.Run(p_in_vec_global +
                                  in_nchw_vec_global_desc.GetOffsetFromMultiIndex(
                                      n_block_data_begin,
                                      c_block_data_begin,
                                      hi_block_data_begin,
                                      wi_block_data_begin),
                              p_in_vec_block);

        // copy weight tensor to LDS
        blockwise_wei_copy.Run(p_wei_vec_global +
                                   wei_kcyx_vec_global_desc.GetOffsetFromMultiIndex(
                                       k_block_data_begin, c_block_data_begin, 0, 0),
                               p_wei_vec_block);

        __syncthreads();

        for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
        {
            // threadwise convolution
#if 1
            threadwise_direct_convolution_2(
                in_nchw_vec_thread_block_desc,
                p_in_vec_block +
                    in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
                                                                   c_thread_data,
                                                                   hi_thread_data_begin,
                                                                   wi_thread_data_begin),
                wei_kcyx_vec_thread_block_desc,
                p_wei_vec_block +
                    wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
                        k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#elif 0
            threadwise_direct_convolution_3(
                in_nchw_vec_thread_block_desc,
                p_in_vec_block +
                    in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
                                                                   c_thread_data,
                                                                   hi_thread_data_begin,
                                                                   wi_thread_data_begin),
                wei_kcyx_vec_thread_block_desc,
                p_wei_vec_block +
                    wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
                        k_thread_data_begin, c_thread_data, 0, 0),
                out_nkhw_thread_desc,
                p_out_thread);
#endif
        }
    }

    // copy output tensor from register to global mem
    threadwise_4d_tensor_copy(out_nkhw_thread_desc,
                              p_out_thread,
                              out_nkhw_global_desc,
                              p_out_global +
                                  out_nkhw_global_desc.GetOffsetFromMultiIndex(
                                      n_block_data_begin + n_thread_data_begin,
                                      k_block_data_begin + k_thread_data_begin,
                                      ho_block_data_begin + ho_thread_data_begin,
                                      wo_block_data_begin + wo_thread_data_begin),
                              out_nkhw_thread_desc.GetLengths());
}

} // namespace ck
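
The block-work partitioning in the kernel above is a chain of divisions and subtractions on the flat block id. A small host-side sketch of the same arithmetic, with made-up work counts:

#include <array>
#include <cstdio>

// Successive div/mod decomposition of a flat block id into (n, k, h, w) work ids,
// matching the scheme used by gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw above.
std::array<int, 4> decompose_block_id(int block_id, int KBlockWork, int HBlockWork, int WBlockWork)
{
    int itmp    = block_id;
    const int n = itmp / (KBlockWork * HBlockWork * WBlockWork);
    itmp -= n * (KBlockWork * HBlockWork * WBlockWork);
    const int k = itmp / (HBlockWork * WBlockWork);
    itmp -= k * (HBlockWork * WBlockWork);
    const int h = itmp / WBlockWork;
    const int w = itmp - h * WBlockWork;
    return {n, k, h, w};
}

int main()
{
    // Hypothetical grid: 2 x 4 x 3 x 5 blocks of work; block 97 maps to n=1, k=2, h=1, w=2.
    const auto ids = decompose_block_id(97, 4, 3, 5);
    std::printf("n=%d k=%d h=%d w=%d\n", ids[0], ids[1], ids[2], ids[3]);
    return 0;
}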
composable_kernel/include/kernel_algorithm/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp
deleted
100644 → 0
View file @
98a2cfcc
#pragma once
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class LowerPads,
          class UpperPads,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t WeiBlockCopyThreadPerDim0,
          index_t WeiBlockCopyThreadPerDim1>
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
    const Float* const __restrict__ p_in_global,
    const Float* const __restrict__ p_wei_global,
    Float* const __restrict__ p_out_global)
{
    // NPerThread == NPerBlock, because the format of input in LDS is [C,Hi,Wi,N]:
    //   for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N";
    //   if we used [C,Hi,Wi/N,N] in LDS, then NPerThread could differ from NPerBlock
    static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread != 0");
    static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
                  "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_chwn_global_desc  = InGlobalDesc{};
    constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
    constexpr auto out_khwn_global_desc = OutGlobalDesc{};

    constexpr index_t C = in_chwn_global_desc.GetLength(I0);

    constexpr index_t K  = out_khwn_global_desc.GetLength(I0);
    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
    constexpr index_t N  = out_khwn_global_desc.GetLength(I3);

    constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
    constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);

    constexpr index_t HPadLow = LowerPads{}.Get(I0);
    constexpr index_t WPadLow = LowerPads{}.Get(I1);

    constexpr index_t HPadUp = UpperPads{}.Get(I0);
    constexpr index_t WPadUp = UpperPads{}.Get(I1);

    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;

    // divide block work: [K, Ho, Wo, N]
    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
    constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

    const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
    index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
    const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
    itmp -= h_block_work_id * (WBlockWork * NBlockWork);
    const index_t w_block_work_id = itmp / NBlockWork;
    const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

    const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
    const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

    // flattened (2d) tensor view of wei in global mem
    constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

    // tensor view of blockwise input and weight in LDS
    constexpr auto in_chwn_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
    constexpr auto wei_cyxk_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});

    // flattened (2d) tensor view of wei in LDS
    constexpr auto wei_ek_block_desc =
        make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});

    // tensor view of threadwise output in register
    constexpr auto out_hkwn_thread_desc =
        make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});

#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
        print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
        print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
    }
#endif

    // blockwise copy
    // input: format is [C, Hi, Wi, N]
    const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
    const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;

    const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
    const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;

#if 0
    if(get_thread_local_1d_id() == 0)
    ;
    {
        printf(
            "%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
            get_block_1d_id(),
            get_thread_local_1d_id(),
            h_block_pad_low,
            w_block_pad_low,
            h_block_pad_up,
            w_block_pad_up);
    }
#endif

    constexpr auto blockwise_in_copy =
        BlockwiseChwnTensorCopyPadded<BlockSize,
                                      Float,
                                      decltype(in_chwn_global_desc),
                                      decltype(in_chwn_block_desc),
                                      decltype(in_chwn_block_desc.GetLengths()),
                                      LowerPads>{};

#if 0
    // weight: format is [C,Y,X,K]
    constexpr auto blockwise_wei_copy =
        Blockwise4dTensorCopy1<BlockSize,
                               Float,
                               decltype(wei_cyxk_global_desc),
                               decltype(wei_cyxk_block_desc),
                               decltype(wei_cyxk_block_desc.GetLengths())>{};
#elif 0
    // weight: format is [C*Y*X,K]
    constexpr auto blockwise_wei_copy =
        Blockwise2dTensorCopy1<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths())>{};
#elif 1
    // weight: format is [C*Y*X,K]
    const auto blockwise_wei_copy =
        Blockwise2dTensorCopy2<BlockSize,
                               Float,
                               decltype(wei_ek_global_desc),
                               decltype(wei_ek_block_desc),
                               decltype(wei_ek_block_desc.GetLengths()),
                               WeiBlockCopyThreadPerDim0,
                               WeiBlockCopyThreadPerDim1>{};
#endif

    // a series of blockwise batched GEMM
    // C_matrix += transpose(A_matrix) * B_matrix
    //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
    //   A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
    //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
    //   C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
    constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
        Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

    constexpr auto b_cxwn_block_mtx_desc =
        make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                      Number<WoPerBlock * NPerBlock>{},
                                      Number<in_chwn_block_desc.GetStride(I0)>{});

    constexpr auto c_kxwn_thread_mtx_desc =
        make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});

    const auto blockwise_batch_gemm =
        Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
                                                         decltype(a_cxk_block_mtx_desc),
                                                         decltype(b_cxwn_block_mtx_desc),
                                                         decltype(c_kxwn_thread_mtx_desc),
                                                         true,
                                                         false,
                                                         false,
                                                         0,
                                                         in_chwn_block_desc.GetStride(I1),
                                                         out_hkwn_thread_desc.GetStride(I0),
                                                         HoPerBlock,
                                                         HoPerThread,
                                                         CPerThread,
                                                         true>{};

    // LDS
    constexpr index_t in_block_element_size  = in_chwn_block_desc.GetElementSpace();
    constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();

    __shared__ Float p_in_block[in_block_element_size];
    __shared__ Float p_wei_block[wei_block_element_size];

    // register
    Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];

    // set threadwise output tensor to 0
    threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);

    const Float* p_wei_global_block_begin =
        p_wei_global + wei_ek_global_desc.GetOffsetFromMultiIndex(0, k_block_data_begin);

    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
        c_block_data_begin += CPerBlock,
                p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
                __syncthreads())
    {
#if 1
        // input: global mem to LDS,
        blockwise_in_copy.Run(p_in_global,
                              c_block_data_begin,
                              ho_block_data_begin,
                              wo_block_data_begin,
                              n_block_data_begin,
                              p_in_block,
                              h_block_pad_low,
                              w_block_pad_low,
                              h_block_pad_up,
                              w_block_pad_up);
#endif

#if 1
        // weight: global mem to LDS,
        blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
#endif

        __syncthreads();

        // a series of batched GEMM
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

                blockwise_batch_gemm.Run(
                    p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                    p_in_block + in_chwn_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                    p_out_thread,
                    f_accum);
            }
        }
    }

    const auto matrix_c_index =
        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

    const index_t ho_thread_data_begin = matrix_c_index.batch;
    const index_t k_thread_data_begin  = matrix_c_index.row;
    const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
    const index_t n_thread_data_begin  = matrix_c_index.col - wo_thread_data_begin * NPerBlock;

#if 0
    printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
           get_block_1d_id(), get_thread_local_1d_id(),
           ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
           ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
           p_out_thread[0]);
#endif

    // output: register to global mem,
    //   convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
    constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};

    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
        out_hkwn_thread_desc,
        p_out_thread,
        out_khwn_global_desc,
        p_out_global +
            out_khwn_global_desc.GetOffsetFromMultiIndex(k_block_data_begin + k_thread_data_begin,
                                                         ho_block_data_begin + ho_thread_data_begin,
                                                         wo_block_data_begin + wo_thread_data_begin,
                                                         n_block_data_begin + n_thread_data_begin),
        out_hkwn_thread_desc.GetLengths(),
        reorder_khwn_from_hkwn);
}

} // namespace ck
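
The HiPerBlock = HoPerBlock + Y - 1 and WiPerBlock = WoPerBlock + X - 1 lines above encode the usual halo rule for a stride-1 tiled convolution: an output tile needs an input tile extended by (filter size - 1) in each dimension. A tiny sketch under that stride-1 assumption:

#include <cassert>

// For a stride-1 convolution, an output tile of extent `out_tile` needs an input tile of
// extent out_tile + filter - 1; this is where HiPerBlock/WiPerBlock in the kernel above come from.
constexpr int input_tile_extent(int out_tile, int filter) { return out_tile + filter - 1; }

int main()
{
    static_assert(input_tile_extent(16, 3) == 18, "a 3-wide filter adds a 2-wide halo");
    static_assert(input_tile_extent(8, 1) == 8, "a 1x1 filter needs no halo");
    assert(input_tile_extent(32, 5) == 36);
    return 0;
}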
composable_kernel/include/tensor_description/dimension.hpp
View file @ 9b280cc5
...
@@ -5,12 +5,6 @@
 namespace ck {

-template <index_t Length>
-struct Dimension
-{
-    __host__ __device__ static constexpr auto GetLength() { return Number<Length>{}; }
-};
-
 template <index_t Length, index_t Stride>
 struct NativeDimension
 {
...
composable_kernel/include/tensor_description/tensor_coordinate.hpp
View file @
9b280cc5
composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp
View file @
9b280cc5
composable_kernel/include/tensor_description/tensor_view.hpp
deleted
100644 → 0
View file @
98a2cfcc
#ifndef CK_TENSOR_VIEW_HPP
#define CK_TENSOR_VIEW_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_coordinate_deprecated.hpp"

namespace ck {

// TensorDesc is ConstantTensorDescriptor or ConstantMergedTensorDescriptor
template <class TensorDesc, class TData>
struct NormalTensorView
{
    using type             = NormalTensorView;
    using tensor_desc_type = TensorDesc;
    using coordinate_type  = typename NormalTensorCoordinate_deprecated<TensorDesc>::type;
    using data_type        = TData;

    static constexpr auto nDim = TensorDesc::GetNumOfDimension();

    __host__ __device__ constexpr NormalTensorView(TData* p_data) : mpData{p_data} {}

    __host__ __device__ constexpr NormalTensorView() : NormalTensorView{nullptr} {}

    __host__ __device__ static constexpr auto GetNumOfDimension() { return nDim; }

    __host__ __device__ static constexpr auto GetLengths() { return TensorDesc::GetLengths(); }

    __host__ __device__ const TData& operator[](coordinate_type coord) const
    {
        return mpData[coord.GetOffset()];
    }

    __host__ __device__ TData& operator()(coordinate_type coord) const
    {
        return mpData[coord.GetOffset()];
    }

    template <class IDim, class DataPerVector>
    __host__ __device__ static constexpr auto IsVectorizationAllowed(IDim, DataPerVector)
    {
        return TensorDesc::IsVectorizationAllowed(IDim{}, DataPerVector{});
    }

    template <class IDim, class DataPerVector>
    __host__ __device__ auto Vectorize(IDim idim, DataPerVector data_per_vector) const
    {
        static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!");

        using vector_t = typename vector_type<TData, data_per_vector>::MemoryType;

        return NormalTensorView<decltype(TensorDesc::Vectorize(idim, data_per_vector)), vector_t>(
            reinterpret_cast<vector_t*>(mpData));
    }

    template <index_t... Is>
    __host__ __device__ auto Slice(coordinate_type slice_origin, Sequence<Is...> slice_lengths)
    {
        static_assert(slice_lengths.GetSize() == nDim, "wrong!");

        return NormalTensorView<decltype(TensorDesc::Slice(slice_lengths)), TData>(
            mpData + slice_origin.GetOffset());
    }

    template <class IDim, class SliceLen>
    __host__ __device__ auto Slice(coordinate_type slice_origin, IDim idim, SliceLen slice_len) const
    {
        return NormalTensorView<decltype(TensorDesc::Slice(idim, slice_len)), TData>(
            mpData + slice_origin.GetOffset());
    }

    // slice_window is a slicing window on "*this"
    template <class SliceWindow, class T, bool PositiveDirection>
    __device__ void MoveSliceWindow(SliceWindow& slice_window,
                                    T step_sizes,
                                    integral_constant<bool, PositiveDirection>)
    {
        if(PositiveDirection)
        {
            slice_window.mpData += coordinate_type{step_sizes}.GetOffset();
        }
        else
        {
            slice_window.mpData -= coordinate_type{step_sizes}.GetOffset();
        }
    }

    // private:
    data_type* mpData;
};

template <class... Xs, class TData>
__host__ __device__ constexpr auto make_TensorView(ConstantTensorDescriptor<Xs...>, TData* p_data)
{
    return NormalTensorView<ConstantTensorDescriptor<Xs...>, TData>{p_data};
}

} // namespace ck
#endif
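
For readers unfamiliar with the tensor-view idea the deleted header expressed, here is a much-simplified, standalone host-side analogue (2-D only, with illustrative names); it only shows the coordinate-to-offset indexing pattern, not the Vectorize/Slice machinery:

#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// A coordinate knows how to turn a multi-index into a linear offset,
// and the view just indexes raw memory with that offset.
struct Coord2d
{
    std::array<std::size_t, 2> idx;
    std::array<std::size_t, 2> strides;
    std::size_t GetOffset() const { return idx[0] * strides[0] + idx[1] * strides[1]; }
};

template <class T>
struct SimpleTensorView2d
{
    T* p_data;
    std::array<std::size_t, 2> strides;

    T& operator()(std::array<std::size_t, 2> idx) const
    {
        return p_data[Coord2d{idx, strides}.GetOffset()];
    }
};

int main()
{
    std::vector<float> buf(3 * 4, 0.f);
    SimpleTensorView2d<float> view{buf.data(), {4, 1}}; // row-major 3x4
    view({2, 3}) = 7.f;
    assert(buf[2 * 4 + 3] == 7.f);
    return 0;
}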
composable_kernel/include/tensor_description/tensor_visit.hpp
deleted
100644 → 0
View file @
98a2cfcc
#ifndef CK_TENSOR_VISIT_HPP
#define CK_TENSOR_VISIT_HPP

#include "common_header.hpp"
#include "dimension.hpp"
#include "dimension_transform.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_coordinate.hpp"

namespace ck {

template <class TensorDescriptor>
struct TensorVisit
{
    using Index      = typename TensorDescriptor::Index;
    using Coordinate = typename TensorCoordinate<TensorDescriptor>::type;

    __host__ __device__ static void Run_v1(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        ford<TensorDescriptor::GetLengths()>{}([&](auto idx_diff) {
            index_t offset = (coord_begin + idx_diff).GetOffset();
        });
    }

    __host__ __device__ static void Run_v2(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        ford<TensorDescriptor::GetLengths()>{}([&](auto idx_diff) {
            index_t offset_diff = coord_begin.GetOffsetDiff(idx_diff);
            index_t offset      = coord_begin.GetOffset() + offset_diff;
        });
    }

    __host__ __device__ static void Run_v3(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        constexpr auto linear_dimensions    = TensorDescriptor::GetLinearDimensions();
        constexpr auto nonlinear_dimensions = TensorDescriptor::GetNonLinearDimensions();

        constexpr auto lengths = TensorDescriptor::GetLengths();

        constexpr auto linear_dimension_lengths_hack =
            lambda_HackLengths{}(lengths, linear_dimensions);
        constexpr auto nonlinear_dimension_lengths_hack =
            lambda_HackLengths{}(lengths, nonlinear_dimensions);

        ford<nonlinear_dimension_lengths_hack>{}([&](auto idx_diff_nonlinear_hack) {
            // run-time component
            index_t offset_diff_nonlinear = coord_begin.GetOffsetDiff(idx_diff_nonlinear_hack);

            ford<linear_dimension_lengths_hack>{}([&](auto idx_diff_linear_hack) {
                // compile-time component
                index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);

                index_t offset =
                    coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
            });
        });
    }

    __host__ __device__ static void Run_v4(Index idx_begin)
    {
        const auto coord_begin = Coordinate(idx_begin);

        constexpr auto linear_dimensions = TensorDescriptor::GetLinearDimensions();
        constexpr auto nonlinear_independent_dimension_groups =
            TensorDescriptor::GetNonLinearIndependentDimensionGroups();

        constexpr auto lengths = TensorDescriptor::GetLengths();

        constexpr auto linear_dimension_lengths = lambda_HackLengths{}(lengths, linear_dimensions);

        // run-time component
        index_t offset_diff_nonlinear = 0;

        template <index_t NGroup>
        struct f_recursion
        {
            template <index_t IGroup>
            __host__ __device__ void Run(Number<IGroup>)
            {
                constexpr auto nonlinear_independent_dimensions_igroup =
                    nonlinear_independent_dimension_groups.Get(igroup);

                constexpr auto nonlinear_independent_lengths_igroup =
                    lambda_HackLengths{}(lengths, nonlinear_independent_dimensions_igroup);

                ford<nonlinear_independent_lengths_igroup>{}(
                    [&](auto idx_diff_nonlinear_igroup_hack) {
                        // run-time component
                        offset_diff_nonlinear +=
                            coord_begin.GetOffsetDiff(idx_diff_nonlinear_igroup_hack);

                        Run(Number<IGroup + 1>{});
                    });
            };

            // inner-most work
            template <>
            __host__ __device__ void Run(Number<NGroup>)
            {
                ford<linear_dimension_lengths>{}([&](auto idx_diff_linear_hack) {
                    // compile-time component
                    index_t offset_diff_linear = coord_begin.GetOffsetDiff(idx_diff_linear_hack);

                    index_t offset =
                        coord_begin.GetOffset() + offset_diff_nonlinear + offset_diff_linear;
                });
            }
        };

        // run-time component
        index_t offset_diff_nonlinear = 0;

        f_recursion<nonlinear_independent_dimension_groups.GetSize()>{}.Run();
    }
};

} // namespace ck
#endif
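
TensorVisit::Run_v1 above is, at its core, "enumerate every multi-index and turn it into a linear offset". A plain host-side rendering of that idea, with made-up lengths and strides:

#include <array>
#include <cstdio>

int main()
{
    // Walk every multi-index of a small 2x3 tensor and compute its linear offset,
    // which is what the ford<Lengths>{} loop in Run_v1 expresses at compile time.
    const std::array<int, 2> lengths{2, 3};
    const std::array<int, 2> strides{3, 1};

    for(int i0 = 0; i0 < lengths[0]; ++i0)
    {
        for(int i1 = 0; i1 < lengths[1]; ++i1)
        {
            const int offset = i0 * strides[0] + i1 * strides[1];
            std::printf("(%d,%d) -> %d\n", i0, i1, offset);
        }
    }
    return 0;
}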
composable_kernel/include/tensor_operation/blockwise_2d_tensor_op.hpp
deleted
100644 → 0
View file @
98a2cfcc
#ifndef CK_BLOCKWISE_2D_TENSOR_OP_HPP
#define CK_BLOCKWISE_2D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto dst_desc = DstDesc{};

    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
    }
#endif

    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        const index_t did0 = is / desc.GetStride(I0);
        is -= did0 * desc.GetStride(I0);
        const index_t did1 = is / desc.GetStride(I1);

        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

        f(p_dst[dindex]);
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < desc.GetElementSize())
        {
            const index_t did0 = is / desc.GetStride(I0);
            is -= did0 * desc.GetStride(I0);
            const index_t did1 = is / desc.GetStride(I1);

            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

            f(p_dst[dindex]);
        }
    }
}

// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0,i1]
// TODO: in order to optimize mem access for different mem type,
//   need to write specialized version
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    const Float* __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        index_t did[2];

        did[0] = is / ref_desc.GetStride(I0);
        is -= did[0] * ref_desc.GetStride(I0);
        did[1] = is / ref_desc.GetStride(I1);

        const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
        const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);

        f(p_src[aindex], p_dst[bindex]);
    }

    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < ref_desc.GetElementSize())
        {
            index_t did[2];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);
            did[1] = is / ref_desc.GetStride(I1);

            const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
            const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);

            f(p_src[aindex], p_dst[bindex]);
        }
    }
}

template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
    auto f_set_zero = [](Float& v) { v = Float(0); };

    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                                     const Float* __restrict__ p_src,
                                                                     DstDesc,
                                                                     Float* __restrict__ p_dst,
                                                                     SrcOpLengths,
                                                                     MapDst2Src)
{
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };

    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise2dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise2dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride0 is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);

        static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);

        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);
            did[1] = is / ref_desc.GetStride(I1);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          index_t ThreadPerDim0,
          index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
    index_t mThreadId0;
    index_t mThreadId1;

    __device__ Blockwise2dTensorCopy2()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
                      "wrong! stride is not 1!\n");

        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float!\n");

        using Float4 = float4;
        using Float2 = float2;

        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        // check alignment
        constexpr bool align_v4 =
            src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;
        constexpr bool align_v2 =
            src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;

        constexpr index_t L0 = SrcOpLengths{}.Get(I0);
        constexpr index_t L1 = SrcOpLengths{}.Get(I1);

        constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);

        constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;

        constexpr index_t Dim1V2Loop =
            align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;

        constexpr index_t Dim1V1Loop =
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;

        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

        for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
        {
            index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;

            // v4
            for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
            {
                index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                    *(reinterpret_cast<const Float4*>(p_src + sindex));
            }

            // v2
            for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
            {
                index_t did1 =
                    Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                    *(reinterpret_cast<const Float2*>(p_src + sindex));
            }

            // v1
            for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
            {
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               d1v1loop * ThreadPerDim1 + mThreadId1;

                const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                p_dst[dindex] = p_src[sindex];
            }

            // dim-1 tail
            if(d1_has_tail)
            {
                index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
                               Dim1V1Loop * ThreadPerDim1 + mThreadId1;

                if(did1 < L1)
                {
                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    p_dst[dindex] = p_src[sindex];
                }
            }
        }

        // dim-0 tail
        if(d0_has_tail)
        {
            index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

            if(did0 < L0)
            {
                // v4
                for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
                {
                    index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    *(reinterpret_cast<Float4*>(p_dst + dindex)) =
                        *(reinterpret_cast<const Float4*>(p_src + sindex));
                }

                // v2
                for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
                {
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
                                   2 * mThreadId1;

                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    *(reinterpret_cast<Float2*>(p_dst + dindex)) =
                        *(reinterpret_cast<const Float2*>(p_src + sindex));
                }

                // v1
                for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
                {
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
                                   Dim1V2Loop * 2 * ThreadPerDim1 + d1v1loop * ThreadPerDim1 +
                                   mThreadId1;

                    const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                    const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                    p_dst[dindex] = p_src[sindex];
                }

                // tail
                if(d1_has_tail)
                {
                    index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 +
                                   Dim1V2Loop * 2 * ThreadPerDim1 + Dim1V1Loop * ThreadPerDim1 +
                                   mThreadId1;

                    if(did1 < L1)
                    {
                        const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
                        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);

                        p_dst[dindex] = p_src[sindex];
                    }
                }
            }
        }
    }
};

// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise2dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
                                      Array<index_t, 2> dst_block_data_multi_id_begin)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
                      "wrong! only support stride1 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        // we allow out-of-bound read from src in D1 dimension,
        //   but we need to make sure dst stride is big enough,
        //   so that the out-of-bound write won't contaminate next line in dst
        static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
        const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
            src_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
            dst_block_data_multi_id_begin +
            Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
                *(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
                                                    iloop * src_loop_stride));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    __device__ constexpr index_t GetRegisterBufferSize() const
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
    }

    __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
                                          Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
                *(reinterpret_cast<const vector_t*>(
                    &p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

    __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
                                           Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop) {
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
        };

        for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
        {
            f_copy(iloop);
        }

        constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);

        if(has_tail_d0)
        {
            constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;

            if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
            {
                f_copy(nloop_d0);
            }
        }
    }

#if CK_USE_AMD_INLINE_ASM
    __device__ void RunLoadRegisterBuffer_asm(const Float* __restrict__ p_src,
                                              Float* p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
        constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;

        constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
        constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;

        auto f_copy = [&](index_t iloop
)
{
#if 0
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
*(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
iloop * src_loop_stride]));
#else
static_assert
(
is_same
<
float
,
Float
>
{}
&&
DataPerRead
==
4
,
"global_load is only for float4"
);
global_load
(
reinterpret_cast
<
vector_t
&>
(
p_clipboard
[
iloop
*
DataPerRead
]),
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
mSrcMyThreadOffset
+
iloop
*
src_loop_stride
]));
#endif
};
for
(
index_t
iloop
=
0
;
iloop
<
nloop_d0
;
++
iloop
)
{
f_copy
(
iloop
);
}
constexpr
bool
has_tail_d0
=
(
L0
>
nloop_d0
*
thread_per_d0
);
if
(
has_tail_d0
)
{
constexpr
index_t
tail_d0
=
L0
-
nloop_d0
*
thread_per_d0
;
if
(
get_thread_local_1d_id
()
<
tail_d0
*
thread_per_d1
)
{
f_copy
(
nloop_d0
);
}
}
}
__device__
void
RunStoreRegisterBuffer_asm
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
index_t
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
index_t
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
index_t
thread_per_d1
=
(
L1
+
DataPerRead
-
1
)
/
DataPerRead
;
constexpr
index_t
thread_per_d0
=
BlockSize
/
thread_per_d1
;
constexpr
index_t
num_active_thread
=
thread_per_d0
*
thread_per_d1
;
if
(
BlockSize
>
num_active_thread
)
{
if
(
get_thread_local_1d_id
()
>=
num_active_thread
)
{
return
;
}
}
constexpr
index_t
nloop_d0
=
L0
/
thread_per_d0
;
constexpr
index_t
src_loop_stride
=
SrcDesc
{}.
GetStride
(
I0
)
*
thread_per_d0
;
constexpr
index_t
dst_loop_stride
=
DstDesc
{}.
GetStride
(
I0
)
*
thread_per_d0
;
auto
f_copy
=
[
&
](
index_t
iloop
)
{
#if 0
*(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]);
#else
static_assert
(
is_same
<
float
,
Float
>
{}
&&
DataPerRead
==
4
,
"ds_write_b128 is only for float4"
);
ds_write_b128
(
reinterpret_cast
<
const
vector_t
&>
(
p_clipboard
[
iloop
*
DataPerRead
]),
&
p_dst
[
mDstMyThreadOffset
+
iloop
*
dst_loop_stride
]);
#endif
};
for
(
index_t
iloop
=
0
;
iloop
<
nloop_d0
;
++
iloop
)
{
f_copy
(
iloop
);
}
constexpr
bool
has_tail_d0
=
(
L0
>
nloop_d0
*
thread_per_d0
);
if
(
has_tail_d0
)
{
constexpr
index_t
tail_d0
=
L0
-
nloop_d0
*
thread_per_d0
;
if
(
get_thread_local_1d_id
()
<
tail_d0
*
thread_per_d1
)
{
f_copy
(
nloop_d0
);
}
}
}
#endif
};
}
// namespace ck
#endif
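Editor's note: the copy above splits a BlockSize-thread workgroup over an L0 x L1 tile. Each thread reads DataPerRead contiguous elements along D1, whole rows are strided over D0, and a predicated tail pass handles the case where L0 is not a multiple of thread_per_d0. Below is a minimal host-side sketch of that work division only; the concrete numbers (BlockSize = 256, L0 = 8, L1 = 128, DataPerRead = 4) are assumptions for illustration, not taken from any kernel configuration in this commit.

// Standalone illustration of the 2-D work division used by the blockwise copy above.
#include <cstdio>

int main()
{
    const int BlockSize   = 256; // assumed
    const int L0          = 8;   // assumed number of rows in the tile
    const int L1          = 128; // assumed elements per row
    const int DataPerRead = 4;   // assumed vector width (float4-style read)

    const int thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; // reads needed per row
    const int thread_per_d0 = BlockSize / thread_per_d1;            // rows covered per pass
    const int num_active    = thread_per_d0 * thread_per_d1;        // threads that do work
    const int nloop_d0      = L0 / thread_per_d0;                   // full passes over D0
    const int tail_d0       = L0 - nloop_d0 * thread_per_d0;        // leftover rows

    std::printf("threads/row=%d rows/pass=%d active=%d full passes=%d tail rows=%d\n",
                thread_per_d1, thread_per_d0, num_active, nloop_d0, tail_d0);

    // A thread participates in the tail pass only if its id falls inside the leftover rows,
    // mirroring the "tid < tail_d0 * thread_per_d1" predicate in the kernel code.
    for(int tid = 0; tid < BlockSize; ++tid)
    {
        const bool active    = tid < num_active;
        const bool does_tail = active && tid < tail_d0 * thread_per_d1;
        (void)does_tail;
    }
    return 0;
}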
composable_kernel/include/tensor_operation/blockwise_3d_tensor_op.hpp (deleted, 100644 → 0)
#ifndef CK_BLOCKWISE_3D_TENSOR_OP_HPP
#define CK_BLOCKWISE_3D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise3dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise3dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
                      "wrong! only support stride2 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I1) % DataPerRead == 0,
                      "src and dst stride1 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D3 dimension,
        // but we need to make sure dst stride2 is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L2          = CopyLengths{}.Get(I2);
        constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);

        static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);

        constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[3];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;
            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          class ThreadPerDims,
          index_t DataPerRead>
struct Blockwise3dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise3dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I1) % DataPerRead == 0,
                      "wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        // we allow out-of-bound read from src in D2 dimension,
        // but we need to make sure dst stride is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
                      "wrong! L0, L1, L2 should be divided evenly!\n");

        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
                      "wrrong! BlockSize is not big enough for ThreadPerDims!");

        constexpr index_t num_active_thread =
            reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});

        const auto thread_multi_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
            thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
            thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
                    const index_t src_offset =
                        SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    const index_t dst_offset =
                        DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                        *(reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
                }
            }
        }
    }

    __device__ static constexpr index_t GetRegisterBufferSize()
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
    }

    __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
                                          Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        constexpr auto clipboard_desc =
            make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
                    const index_t src_offset =
                        SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                        iloop_d0, iloop_d1, iloop_d2 * DataPerRead);

                    *(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
                        *(reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
                }
            }
        }
    }

    __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
                                           Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);

        constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);

        constexpr auto clipboard_desc =
            make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
                    const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                        iloop_d0, iloop_d1, iloop_d2 * DataPerRead);

                    const index_t dst_offset =
                        DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
                                                          iloop_d1 * thread_per_d1,
                                                          iloop_d2 * thread_per_d2 * DataPerRead);

                    *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                        *(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
                }
            }
        }
    }
};

} // namespace ck
#endif
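Editor's note: Blockwise3dTensorCopy1 turns a flat work-item index into a (d0, d1, d2) multi-index by repeated division with the packed reference-descriptor strides, then recomposes offsets through the source and destination descriptors. A small standalone sketch of just that decomposition follows; the 4 x 3 x 5 lengths are an assumption for illustration.

// Standalone illustration of the flat-index -> multi-index decomposition used above.
#include <cstdio>

int main()
{
    // Assumed packed tensor lengths for the example only.
    const int len[3]    = {4, 3, 5};
    const int stride[3] = {len[1] * len[2], len[2], 1}; // packed (row-major) strides

    const int element_size = len[0] * len[1] * len[2];

    for(int is = 0; is < element_size; ++is)
    {
        int rem = is;
        const int d0 = rem / stride[0];
        rem -= d0 * stride[0];
        const int d1 = rem / stride[1];
        rem -= d1 * stride[1];
        const int d2 = rem / stride[2];

        // Recomposing through the same strides must give the original flat index back.
        if(d0 * stride[0] + d1 * stride[1] + d2 * stride[2] != is)
        {
            std::printf("mismatch at %d\n", is);
            return 1;
        }
    }
    std::printf("all %d indices decompose consistently\n", element_size);
    return 0;
}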
composable_kernel/include/tensor_operation/blockwise_4d_tensor_op.hpp (deleted, 100644 → 0)
#ifndef CK_BLOCKWISE_4D_TENSOR_OP_HPP
#define CK_BLOCKWISE_4D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"

namespace ck {

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto dst_desc = DstDesc{};

    constexpr auto desc = make_ConstantTensorDescriptor_packed(dst_desc.GetLengths());

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
    }
#endif

    constexpr index_t NLoop = desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        const index_t did0 = is / desc.GetStride(I0);
        is -= did0 * desc.GetStride(I0);

        const index_t did1 = is / desc.GetStride(I1);
        is -= did1 * desc.GetStride(I1);

        const index_t did2 = is / desc.GetStride(I2);
        is -= did2 * desc.GetStride(I2);

        const index_t did3 = is / desc.GetStride(I3);

        const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

        f(p_dst[dindex]);
    }

    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < desc.GetElementSize())
        {
            const index_t did0 = is / desc.GetStride(I0);
            is -= did0 * desc.GetStride(I0);

            const index_t did1 = is / desc.GetStride(I1);
            is -= did1 * desc.GetStride(I1);

            const index_t did2 = is / desc.GetStride(I2);
            is -= did2 * desc.GetStride(I2);

            const index_t did3 = is / desc.GetStride(I3);

            const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

            f(p_dst[dindex]);
        }
    }
}

// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          class F>
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    const Float* __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    MapDst2Src,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});

    constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

    for(index_t iloop = 0; iloop < NLoop; ++iloop)
    {
        index_t is = get_thread_local_1d_id() + iloop * BlockSize;

        index_t did[4];

        did[0] = is / ref_desc.GetStride(I0);
        is -= did[0] * ref_desc.GetStride(I0);

        did[1] = is / ref_desc.GetStride(I1);
        is -= did[1] * ref_desc.GetStride(I1);

        did[2] = is / ref_desc.GetStride(I2);
        is -= did[2] * ref_desc.GetStride(I2);

        did[3] = is / ref_desc.GetStride(I3);

        const index_t src_index = src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
        const index_t dst_index =
            dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

        f(p_src[src_index], p_dst[dst_index]);
    }

    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

        if(is < ref_desc.GetElementSize())
        {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);

            f(p_src[src_index], p_dst[dst_index]);
        }
    }
}

template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
    auto f_set_zero = [](Float& v) { v = Float(0); };

    blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                                     const Float* __restrict__ p_src,
                                                                     DstDesc,
                                                                     Float* __restrict__ p_dst,
                                                                     SrcOpLengths,
                                                                     MapDst2Src)
{
    auto f_copy = [](const Float& src, Float& dst) { dst = src; };

    blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          index_t DataPerRead>
struct Blockwise4dTensorCopy1
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    __device__ constexpr Blockwise4dTensorCopy1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I2) % DataPerRead == 0,
                      "src and dst stride2 should be multiple of DataPerRead to keep alignment");

        // we allow out-of-bound read from src in D3 dimension,
        // but we need to make sure dst stride2 is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t L3          = CopyLengths{}.Get(I3);
        constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);

        static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                      "wrong! out-of-bound write will contaminate next line!\n");
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);

        constexpr auto ref_desc =
            make_ConstantTensorDescriptor_packed(Sequence<L0, L1, L2, read_per_d3>{});

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        auto f_copy = [&](index_t is) {
            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t src_index =
                src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
            const index_t dst_index =
                dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);

            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                *(reinterpret_cast<const vector_t*>(p_src + src_index));
        };

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;
            f_copy(is);
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                f_copy(is);
            }
        }
    }
};

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class DstOpLengths,
          class GlobalLowerPads>
struct BlockwiseChwnTensorCopyPadded
{
    __device__ void Run(const Float* __restrict__ p_src,
                        index_t c_block_data_begin,
                        index_t ho_block_data_begin,
                        index_t wo_block_data_begin,
                        index_t n_block_data_begin,
                        Float* __restrict__ p_dst,
                        index_t h_block_pad_low,
                        index_t w_block_pad_low,
                        index_t h_block_pad_up,
                        index_t w_block_pad_up) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};
        constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(DstOpLengths{});

        constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
        constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);

        constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;

        const Float* p_src_tmp =
            p_src + src_desc.GetOffsetFromMultiIndex(
                        c_block_data_begin,
                        (ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
                        (wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
                        n_block_data_begin);

#if 0
        if(get_thread_local_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(src_desc, "src_desc: ");
            print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
            print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");

            printf("%u %u, \t"
                   "h_global_pad_low %u w_global_pad_low %u \t"
                   "h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
                   "\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   h_global_pad_low,
                   w_global_pad_low,
                   h_block_pad_low,
                   w_block_pad_low,
                   h_block_pad_up,
                   w_block_pad_up);
        }
#endif

        for(index_t iloop = 0; iloop < NLoop; ++iloop)
        {
            index_t is = get_thread_local_1d_id() + iloop * BlockSize;

            index_t did[4];

            did[0] = is / ref_desc.GetStride(I0);
            is -= did[0] * ref_desc.GetStride(I0);

            did[1] = is / ref_desc.GetStride(I1);
            is -= did[1] * ref_desc.GetStride(I1);

            did[2] = is / ref_desc.GetStride(I2);
            is -= did[2] * ref_desc.GetStride(I2);

            did[3] = is / ref_desc.GetStride(I3);

            const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);

            p_dst[bindex] =
                (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                 did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
                    ? Float(0)
                    : p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
        }

        constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

        if(has_tail)
        {
            index_t is = get_thread_local_1d_id() + NLoop * BlockSize;

            if(is < ref_desc.GetElementSize())
            {
                index_t did[4];

                did[0] = is / ref_desc.GetStride(I0);
                is -= did[0] * ref_desc.GetStride(I0);

                did[1] = is / ref_desc.GetStride(I1);
                is -= did[1] * ref_desc.GetStride(I1);

                did[2] = is / ref_desc.GetStride(I2);
                is -= did[2] * ref_desc.GetStride(I2);

                did[3] = is / ref_desc.GetStride(I3);

                const index_t bindex =
                    dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);

                p_dst[bindex] =
                    (did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
                     did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
                        ? Float(0)
                        : p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
            }
        }
    }
};

// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class CopyLengths,
          class ThreadPerDims,
          index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ Blockwise4dTensorCopy3()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        static_assert(DataPerRead == 1 ||
                          (SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
                      "wrong! only support stride3 == 1 if DataPerRead > 1!\n");

        static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                      "wrong! only support DataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
                          DstDesc{}.GetStride(I2) % DataPerRead == 0,
                      "wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        // we allow out-of-bound read from src in D3 dimension,
        // but we need to make sure dst stride is big enough,
        // so that the out-of-bound write won't contaminate next line in dst
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
                      "wrong! out-of-bound write will contaminate next line!\n");

        static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 &&
                          L2 % thread_per_d2 == 0,
                      "wrong! L0, L1, L2 should be divided evenly!\n");

        static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
                      "wrrong! BlockSize is not big enough for ThreadPerDims!");

        constexpr index_t num_active_thread =
            reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(ThreadPerDims{});

        const auto thread_multi_id =
            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());

        mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
                                                               thread_multi_id[1],
                                                               thread_multi_id[2],
                                                               thread_multi_id[3] * DataPerRead);

        mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
                                                               thread_multi_id[1],
                                                               thread_multi_id[2],
                                                               thread_multi_id[3] * DataPerRead);
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                            *(reinterpret_cast<const vector_t*>(
                                &p_src[src_offset + mSrcMyThreadOffset]));
                    }
                }
            }
        }
    }

    __device__ constexpr index_t GetRegisterBufferSize() const
    {
        static_assert(is_same<Float, float>{}, "wrong! only support float!\n");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
    }

    __device__ void RunLoadRegisterBuffer(const Float* __restrict__ p_src,
                                          Float* __restrict__ p_clipboard) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
            Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                            iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
                            *(reinterpret_cast<const vector_t*>(
                                &p_src[src_offset + mSrcMyThreadOffset]));
                    }
                }
            }
        }
    }

    __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
                                           Float* __restrict__ p_dst) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);
        constexpr index_t L2 = CopyLengths{}.Get(I2);
        constexpr index_t L3 = CopyLengths{}.Get(I3);

        constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
        constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
        constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
        constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);

        constexpr index_t num_active_thread =
            thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;

        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        constexpr index_t nloop_d0 = L0 / thread_per_d0;
        constexpr index_t nloop_d1 = L1 / thread_per_d1;
        constexpr index_t nloop_d2 = L2 / thread_per_d2;
        constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);

        constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
            Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});

#pragma unroll
        for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
        {
#pragma unroll
            for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
            {
#pragma unroll
                for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
                {
#pragma unroll
                    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
                    {
                        const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
                            iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);

                        const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
                            iloop_d0 * thread_per_d0,
                            iloop_d1 * thread_per_d1,
                            iloop_d2 * thread_per_d2,
                            iloop_d3 * thread_per_d3 * DataPerRead);

                        *(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
                            *(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
                    }
                }
            }
        }
    }
};

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
struct Blockwise4dTensorCopyReorder1
{
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        auto f_copy = [](const Float& src, Float& dst) { dst = src; };

        blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
            SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
    }
};

} // namespace
#endif
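Editor's note: all of these 4-D copies rely on the same mechanism as the 2-D/3-D ones: when the innermost stride is 1 and the enclosing stride is a multiple of DataPerRead, a row can be moved as aligned 4-wide vectors, which is exactly what the stride static_asserts above guard. A minimal host-side sketch of that vectorized row copy follows; the alignas(16) Vec4 stand-in and the 8 x 12 sizes are assumptions for illustration, not the library's vector_type.

// Standalone illustration of copying rows as 4-wide vectors when strides allow it.
#include <cstdio>
#include <vector>

struct alignas(16) Vec4 { float v[4]; }; // stand-in for a float4-style vector type

int main()
{
    const int rows = 8, cols = 12, data_per_read = 4; // assumed sizes
    const int row_stride = cols;                      // multiple of data_per_read keeps rows aligned

    std::vector<float> src(rows * row_stride), dst(rows * row_stride, 0.f);
    for(int i = 0; i < rows * row_stride; ++i) src[i] = float(i);

    const int reads_per_row = cols / data_per_read;
    for(int r = 0; r < rows; ++r)
    {
        for(int i = 0; i < reads_per_row; ++i)
        {
            // One 16-byte move per iteration, mirroring the reinterpret_cast<vector_t*> copies above.
            *reinterpret_cast<Vec4*>(&dst[r * row_stride + i * data_per_read]) =
                *reinterpret_cast<const Vec4*>(&src[r * row_stride + i * data_per_read]);
        }
    }

    std::printf("dst[0]=%g dst[last]=%g\n", dst[0], dst[rows * row_stride - 1]);
    return 0;
}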
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp (modified)
@@ -4,7 +4,6 @@
 #include "common_header.hpp"
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
-#include "tensor_view.hpp"
 #include "tensor_coordinate_deprecated.hpp"
 #include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
@@ -484,14 +483,8 @@ struct BlockwiseGenericTensorSliceCopy_v2
               address_space_t ThreadBufferAddressSpace = address_space_t::generic>
     __device__ void RunLoadThreadBuffer(const TData* p_block_src, TData* p_thread_buffer) const
     {
-#if 0
-        mThreadwiseLoad.Run<TData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(p_block_src,
-                                                                                   p_thread_buffer);
-#else
-        // tweaking
-        mThreadwiseLoad.template Run_optimized_address_calculation<TData,
-                                                                   BlockSrcAddressSpace,
-                                                                   ThreadBufferAddressSpace>(
-            p_block_src, p_thread_buffer);
-#endif
+        mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
     }

     template <typename TData,
@@ -499,14 +492,8 @@ struct BlockwiseGenericTensorSliceCopy_v2
               address_space_t BlockDstAddressSpace = address_space_t::generic>
     __device__ void RunStoreThreadBuffer(const TData* p_thread_buffer, TData* p_block_dst) const
     {
-#if 0
-        mThreadwiseStore.Run<TData, ThreadBufferAddressSpace, BlockDstAddressSpace>(p_thread_buffer,
-                                                                                    p_block_dst);
-#else
-        // tweaking
-        mThreadwiseStore.template Run_optimized_address_calculation<TData,
-                                                                    ThreadBufferAddressSpace,
-                                                                    BlockDstAddressSpace>(
-            p_thread_buffer, p_block_dst);
-#endif
+        mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
     }

     template <typename TData,
@@ -563,130 +550,6 @@ struct BlockwiseGenericTensorSliceCopy_v2
     ThreadwiseStore mThreadwiseStore;
 };

-// this version use TensorView and TensorCoordinate_deprecated
-template <index_t BlockSize,
-          typename SrcTensor,
-          typename DstTensor,
-          typename SliceLengths,
-          typename SubLengths,
-          typename ThreadClusterLengths,
-          typename ThreadClusterArrangeOrder,
-          typename SrcDimAccessOrder,
-          typename DstDimAccessOrder,
-          index_t SrcVectorAccessDim,
-          index_t DstVectorAccessDim,
-          index_t SrcDataPerAccess,
-          index_t DstDataPerAccess>
-struct BlockwiseGenericTensorSliceCopy_v3
-{
-    static constexpr index_t nDim = SrcTensor::GetNumOfDimension();
-
-    using data_type     = remove_cv_t<typename SrcTensor::data_type>;
-    using SrcCoordinate = typename SrcTensor::coordinate_type;
-    using DstCoordinate = typename DstTensor::coordinate_type;
-
-    __device__ constexpr BlockwiseGenericTensorSliceCopy_v3(SrcTensor src_block,
-                                                            SrcCoordinate src_block_slice_origin,
-                                                            DstTensor dst_block,
-                                                            DstCoordinate dst_block_slice_origin)
-        : mThreadBuffer{make_TensorView(ThreadBufferDesc{}, mpBuffer)}
-    {
-        static_assert(nDim == SrcTensor::GetNumOfDimension() &&
-                          nDim == DstTensor::GetNumOfDimension() &&
-                          nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
-                          nDim == ThreadClusterLengths::GetSize() &&
-                          nDim == ThreadClusterArrangeOrder::GetSize() &&
-                          nDim == SrcDimAccessOrder::GetSize() &&
-                          nDim == DstDimAccessOrder::GetSize(),
-                      "wrong! nDim not consistent");
-
-        static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
-                      "wrong! threads should be mapped to cover entire slicing window");
-
-        static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
-                              remove_cv_t<typename DstTensor::data_type>>{},
-                      "wrong! type conversion not supported yet");
-
-        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
-            ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
-
-        static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
-                      "wrong! BlockSize not consistent with ThreadClusterLengths");
-
-        const auto thread_cluster_id =
-            thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
-
-        const auto data_cluster_id =
-            reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
-
-        const auto thread_data_id_begin = data_cluster_id * SubLengths{};
-
-        mThreadwiseLoad = ThreadwiseLoad(src_block,
-                                         src_block_slice_origin + thread_data_id_begin,
-                                         mThreadBuffer,
-                                         make_zero_array<index_t, nDim>());
-
-        mThreadwiseStore = ThreadwiseStore(mThreadBuffer,
-                                           make_zero_array<index_t, nDim>(),
-                                           dst_block,
-                                           dst_block_slice_origin + thread_data_id_begin);
-    }
-
-    __device__ void RunLoadRegisterBuffer() { mThreadwiseLoad.Run(); }
-
-    __device__ void RunStoreRegisterBuffer() const { mThreadwiseStore.Run(); }
-
-    __device__ void Run()
-    {
-        mThreadwiseLoad.Run();
-        mThreadwiseStore.Run();
-    }
-
-    template <typename T, bool PositiveDirection>
-    __device__ void MoveSrcSliceWindow(T step_sizes,
-                                       integral_constant<bool, PositiveDirection> positive_direction)
-    {
-        mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction);
-    }
-
-    template <typename T, bool PositiveDirection>
-    __device__ void MoveDstSliceWindow(T step_sizes,
-                                       integral_constant<bool, PositiveDirection> positive_direction)
-    {
-        mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction);
-    }
-
-    private:
-    using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
-
-    using ThreadBufferTensor = NormalTensorView<ThreadBufferDesc, data_type>;
-
-    using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v3r1<SrcTensor,
-                                                                 ThreadBufferTensor,
-                                                                 SubLengths,
-                                                                 SrcDimAccessOrder,
-                                                                 SrcDimAccessOrder,
-                                                                 SrcVectorAccessDim,
-                                                                 SrcVectorAccessDim,
-                                                                 SrcDataPerAccess,
-                                                                 1>;
-
-    using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v3r1<ThreadBufferTensor,
-                                                                  DstTensor,
-                                                                  SubLengths,
-                                                                  DstDimAccessOrder,
-                                                                  DstDimAccessOrder,
-                                                                  DstVectorAccessDim,
-                                                                  DstVectorAccessDim,
-                                                                  1,
-                                                                  DstDataPerAccess>;
-
-    data_type mpBuffer[ThreadBufferDesc::GetElementSpace()];
-
-    ThreadBufferTensor mThreadBuffer;
-
-    ThreadwiseLoad mThreadwiseLoad;
-    ThreadwiseStore mThreadwiseStore;
-};
-
 } // namespace ck
 #endif
composable_kernel/include/tensor_operation/blockwise_tensor_slice_copy.hpp (deleted, 100644 → 0)
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcLengths
,
class
SrcSubLengths
,
class
SrcClusterLengths
,
class
MapDst2Src
,
class
MapThreadCluster2SrcCluster
,
index_t
SrcDataPerRead
,
index_t
DstDataPerWrite
>
struct
BlockwiseTensorSliceReorderCopy_v3
{
static
constexpr
index_t
nDim
=
SrcLengths
::
GetSize
();
index_t
mThreadSrcOffset
;
index_t
mThreadDstOffset
;
__device__
BlockwiseTensorSliceReorderCopy_v3
(
Array
<
index_t
,
nDim
>
src_block_data_multi_id_begin
,
Array
<
index_t
,
nDim
>
dst_block_data_multi_id_begin
)
{
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
constexpr
auto
src_lengths
=
SrcLengths
{};
constexpr
auto
map_dst2src
=
MapDst2Src
{};
constexpr
auto
src_sub_lengths
=
SrcSubLengths
{};
constexpr
auto
dst_sub_lengths
=
src_sub_lengths
.
ReorderGivenNew2Old
(
map_dst2src
);
constexpr
auto
map_thread_cluster_2_src_cluster
=
MapThreadCluster2SrcCluster
{};
constexpr
auto
src_cluster_lengths
=
SrcClusterLengths
{};
constexpr
auto
thread_cluster_lengths
=
src_cluster_lengths
.
ReorderGivenNew2Old
(
map_thread_cluster_2_src_cluster
);
constexpr
auto
thread_cluster_desc
=
make_ConstantTensorDescriptor_packed
(
thread_cluster_lengths
);
// sanity check: data type
static_assert
(
is_same
<
Float
,
float
>
{},
"wrong! only support float for now!
\n
"
);
// sanity check: nDim
static_assert
(
SrcDesc
::
GetNumOfDimension
()
==
nDim
&&
DstDesc
::
GetNumOfDimension
()
==
nDim
&&
SrcLengths
::
GetSize
()
==
nDim
&&
SrcSubLengths
::
GetSize
()
==
nDim
&&
SrcClusterLengths
::
GetSize
()
==
nDim
&&
MapDst2Src
::
GetSize
()
==
nDim
&&
MapThreadCluster2SrcCluster
::
GetSize
()
==
nDim
,
"wrong! nDim is not consistent
\n
"
);
// sanity check: BlockSize
constexpr
index_t
num_active_thread
=
thread_cluster_desc
.
GetElementSize
();
static_assert
(
BlockSize
>=
num_active_thread
,
"wrong! BlockSize is not big enough for ThreadPerDims!"
);
// sanity check: work division
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
IDim
)
{
constexpr
auto
I
=
decltype
(
IDim
){};
constexpr
index_t
src_len
=
src_lengths
.
Get
(
I
);
constexpr
index_t
src_sub_len
=
src_sub_lengths
.
Get
(
I
);
constexpr
index_t
src_cluster_len
=
src_cluster_lengths
.
Get
(
I
);
static_assert
(
src_len
%
(
src_sub_len
*
src_cluster_len
)
==
0
,
"wrong! cannot evenly divide Src tensor lengths"
);
});
// sanity check: src read
static_assert
(
SrcDataPerRead
==
1
||
SrcDataPerRead
==
2
||
SrcDataPerRead
==
4
,
"wrong! only support SrcDataPerRead == 1, 2 or 4!
\n
"
);
static_assert
(
SrcDataPerRead
==
1
||
src_desc
.
GetStride
(
Number
<
nDim
-
1
>
{})
==
1
,
"wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!
\n
"
);
static_assert
(
src_sub_lengths
.
Get
(
Number
<
nDim
-
1
>
{})
%
SrcDataPerRead
==
0
,
"wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0
\n
"
);
static_assert
(
src_desc
.
GetStride
(
Number
<
nDim
-
2
>
{})
%
SrcDataPerRead
==
0
,
"wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
"keep alignment"
);
// sanity check: dst write
static_assert
(
DstDataPerWrite
==
1
||
DstDataPerWrite
==
2
||
DstDataPerWrite
==
4
,
"wrong! only support DstDataPerWrite == 1, 2 or 4!
\n
"
);
static_assert
(
DstDataPerWrite
==
1
||
dst_desc
.
GetStride
(
Number
<
nDim
-
1
>
{})
==
1
,
"wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!
\n
"
);
static_assert
(
dst_sub_lengths
.
Get
(
Number
<
nDim
-
1
>
{})
%
DstDataPerWrite
==
0
,
"wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0
\n
"
);
static_assert
(
dst_desc
.
GetStride
(
Number
<
nDim
-
2
>
{})
%
DstDataPerWrite
==
0
,
"wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
"keep alignment"
);
// start dividing work
if
(
BlockSize
>
num_active_thread
)
{
if
(
get_thread_local_1d_id
()
>=
num_active_thread
)
{
return
;
}
}
const
auto
thread_multi_id
=
thread_cluster_desc
.
GetMultiIndexFrom1dIndex
(
get_thread_local_1d_id
());
// compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
// regsiters, or only one copy???
auto
src_data_multi_id
=
reorder_array_given_old2new
(
thread_multi_id
,
map_thread_cluster_2_src_cluster
);
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
IDim
)
{
constexpr
index_t
idim
=
IDim
;
// compiler: will it really compute index here, or be merged with
// GetOffsetFromMultiIndex and
// optimized away???
src_data_multi_id
(
idim
)
*=
src_sub_lengths
.
Get
(
IDim
);
});
// compiler: will it really compute index here, or be merged with GetOffsetFromMultiIndex
// and
// optimized away???
const
auto
dst_data_multi_id
=
reorder_array_given_new2old
(
src_data_multi_id
,
map_dst2src
);
mThreadSrcOffset
=
src_desc
.
GetOffsetFromMultiIndex
(
src_data_multi_id
+
src_block_data_multi_id_begin
);
mThreadDstOffset
=
dst_desc
.
GetOffsetFromMultiIndex
(
dst_data_multi_id
+
dst_block_data_multi_id_begin
);
#if 0
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
}
if(get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"thread_multi_id: %u %u, "
"src_block_data_multi_id_begin: %u %u, "
"src_data_multi_id: %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
thread_multi_id[0],
thread_multi_id[1],
src_block_data_multi_id_begin[0],
src_block_data_multi_id_begin[1],
src_data_multi_id[0],
src_data_multi_id[1],
mThreadSrcOffset,
mThreadDstOffset);
}
#endif
}
__device__
static
constexpr
index_t
GetRegisterBufferSize
()
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
src_data_per_cluster_per_dims
=
thread_sub_tensor_lengths
*
SrcClusterLengths
{};
constexpr
auto
repeat_lengths
=
transform_sequences
(
math
::
integer_divide_ceiler
<
index_t
>
{},
SrcLengths
{},
src_data_per_cluster_per_dims
);
constexpr
auto
thread_tensor_lengths
=
thread_sub_tensor_lengths
*
repeat_lengths
;
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor_packed
(
thread_tensor_lengths
);
return
thread_tensor_desc
.
GetElementSpace
();
}
__device__
void
RunLoadRegisterBuffer
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
src_data_per_cluster_per_dims
=
thread_sub_tensor_lengths
*
SrcClusterLengths
{};
constexpr
auto
repeat_lengths
=
transform_sequences
(
math
::
integer_divide_ceiler
<
index_t
>
{},
SrcLengths
{},
src_data_per_cluster_per_dims
);
constexpr
auto
thread_tensor_lengths
=
thread_sub_tensor_lengths
*
repeat_lengths
;
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor_packed
(
thread_tensor_lengths
);
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
constexpr
auto
repeat_multi_id
=
decltype
(
repeat_multi_id_
){};
constexpr
auto
src_data_multi_id
=
repeat_multi_id
*
src_data_per_cluster_per_dims
;
constexpr
auto
clipboard_data_multi_id
=
repeat_multi_id
*
thread_sub_tensor_lengths
;
constexpr
index_t
src_offset
=
SrcDesc
{}.
GetOffsetFromMultiIndex
(
src_data_multi_id
);
constexpr
index_t
clipboard_offset
=
thread_tensor_desc
.
GetOffsetFromMultiIndex
(
clipboard_data_multi_id
);
threadwise_tensor_slice_copy
(
SrcDesc
{},
p_src
+
src_offset
+
mThreadSrcOffset
,
thread_tensor_desc
,
p_clipboard
+
clipboard_offset
,
thread_sub_tensor_lengths
,
Number
<
SrcDataPerRead
>
{});
});
}
    __device__ void RunStoreRegisterBuffer(const Float* __restrict__ p_clipboard,
                                           Float* __restrict__ p_dst) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims =
            thread_sub_tensor_lengths * SrcClusterLengths{};

        constexpr auto repeat_lengths = transform_sequences(
            math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;

        constexpr auto thread_tensor_desc =
            make_ConstantTensorDescriptor_packed(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;

            constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;

            // reorder src_data_multi_id to get dst_data_multi_id
            constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});

            constexpr index_t clipboard_offset =
                thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);

            constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);

            // write in the order of dst
#if 1
            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset + mThreadDstOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{});
#else
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset + mThreadDstOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{},
                                                                  Number<DstDataPerWrite>{});
#endif
        });
    }
    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        Float p_clipboard[GetRegisterBufferSize()];

        RunLoadRegisterBuffer(p_src, p_clipboard);
        RunStoreRegisterBuffer(p_clipboard, p_dst);
    }
    // this function doesn't do a sanity check on whether the slicing window is out of the
    // boundary of the tensor being sliced
    template <index_t IDim_, index_t StepSize, bool PositiveDirection>
    __device__ void MoveSlicingWindowOnSourceTensor(
        Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
    {
        constexpr auto IDim = Number<IDim_>{};

        static_if<PositiveDirection>{}([&](auto fwd) {
            mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
        }).Else([&](auto fwd) {
            mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
        });
    }
};

} // namespace ck
#endif
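The struct above stages each thread's slice through a small register-resident "clipboard": RunLoadRegisterBuffer copies the source tile into it in source order, and RunStoreRegisterBuffer writes it back out in destination order, reordering dimensions via MapDst2Src. Below is a minimal host-side sketch of that two-phase pattern in plain C++, with made-up sizes and a 2-D transpose standing in for the general reorder; it is an illustration, not the CK API.

// Illustrative only: copy a 2x4 tile into a local "clipboard", then write it
// out transposed, mirroring RunLoadRegisterBuffer / RunStoreRegisterBuffer
// with MapDst2Src = {1, 0}.
#include <cstdio>

int main()
{
    const int H = 2, W = 4;
    float src[H * W], dst[W * H], clipboard[H * W];

    for(int i = 0; i < H * W; ++i)
        src[i] = static_cast<float>(i);

    // "load": copy the src tile into the register-like clipboard, in src order
    for(int h = 0; h < H; ++h)
        for(int w = 0; w < W; ++w)
            clipboard[h * W + w] = src[h * W + w];

    // "store": write the clipboard into dst, reordering dimensions (transpose)
    for(int h = 0; h < H; ++h)
        for(int w = 0; w < W; ++w)
            dst[w * H + h] = clipboard[h * W + w];

    printf("dst[1][0] = %f\n", dst[1 * H + 0]); // 1.0: element (0,1) of src
    return 0;
}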
composable_kernel/include/tensor_operation/threadwise_4d_tensor_op.hpp
deleted 100644 → 0
View file @ 98a2cfcc
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
#define CK_THREADWISE_4D_TENSOR_OP_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

#if 0
    if(get_thread_local_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
    }
#endif

    constexpr index_t nshift = NShift::mValue;

    constexpr index_t did0_end =
        is_same<decltype(I0), IDim>{} ? desc.GetLength(I0) - nshift : desc.GetLength(I0);

    constexpr index_t did1_end =
        is_same<decltype(I1), IDim>{} ? desc.GetLength(I1) - nshift : desc.GetLength(I1);

    constexpr index_t did2_end =
        is_same<decltype(I2), IDim>{} ? desc.GetLength(I2) - nshift : desc.GetLength(I2);

    constexpr index_t did3_end =
        is_same<decltype(I3), IDim>{} ? desc.GetLength(I3) - nshift : desc.GetLength(I3);

    for(index_t did0 = 0; did0 < did0_end; ++did0)
    {
        for(index_t did1 = 0; did1 < did1_end; ++did1)
        {
            for(index_t did2 = 0; did2 < did2_end; ++did2)
            {
                for(index_t did3 = 0; did3 < did3_end; ++did3)
                {
                    const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);

                    const index_t sindex = dindex + nshift * desc.GetStride(IDim{});

                    p[dindex] = p[sindex];
                }
            }
        }
    }
}

} // namespace ck
#endif
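The deleted helper shifts tensor contents toward lower indices along one dimension: each destination element takes the value nshift strides further along IDim, and the loop bound on that dimension shrinks by nshift so the reads stay in range. A plain C++ sketch of the same semantics on a small row-major 2-D array (illustrative values only, not the CK API):

// Shift down by `nshift` along dimension 1: element [i][j] is overwritten by
// element [i][j + nshift].
#include <cstdio>

int main()
{
    const int D0 = 2, D1 = 5, nshift = 2;
    int p[D0 * D1];
    for(int i = 0; i < D0 * D1; ++i)
        p[i] = i;

    for(int i0 = 0; i0 < D0; ++i0)
        for(int i1 = 0; i1 < D1 - nshift; ++i1)         // loop bound shrinks on the shifted dim
            p[i0 * D1 + i1] = p[i0 * D1 + i1 + nshift]; // stride of dim 1 is 1 here

    printf("%d\n", p[0]); // 2: old element [0][2] moved into [0][0]
    return 0;
}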
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp
View file @ 9b280cc5
...
@@ -4,7 +4,6 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
-#include "tensor_view.hpp"
#include "tensor_coordinate_deprecated.hpp"

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
...
@@ -600,18 +599,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
// Read vector from src.
// 1. Source code version can take src of all kinds of memory-space
-// 2. Inline asm versions using global_load or buffer_load can only take
+// 2. Intrinsic version using buffer_load can only take
//    src from global-memory
//
// Comment for loading from global-memory:
//   When:
//     1) using source code, in order for compiler to emit optimal
//        load instruction, or
-//     2) using inline asm (global_load or buffer_load), in order
-//        for inline asm to be valid,
+//     2) using buffer_load intrinsic, in order for ISA to be valid,
//   following assumptions need to be satisfied:
//     1. p_src need to be block-invariant (assumption)
-//     2. src_normal_offset must be calculated at compile time (guaranteed)
+//     2. src_normal_offset must be calculated at compile time (guaranteed by
+//        algorithm)
//     3. src_merged_offset can be runtime value (no assumption imposed)
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
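The assumptions in the comment above amount to splitting every address into a block-invariant base pointer, a compile-time "normal" offset, and a runtime "merged" offset. A minimal sketch of that decomposition in plain device-side C++, with hypothetical names (load_with_split_offset is illustrative, not a CK helper or an AMD intrinsic):

// Illustrative only: the offset is split into a part known at compile time and
// a part that may only be known at run time; the base pointer is assumed
// uniform across the workgroup ("block-invariant").
template <int NormalOffset> // compile-time ("normal") part of the offset
__device__ float load_with_split_offset(const float* p_base, // block-invariant base
                                         int merged_offset)  // runtime ("merged") part
{
    // A buffer-load style instruction can fold NormalOffset into an immediate
    // field while merged_offset stays in a vector register.
    return p_base[NormalOffset + merged_offset];
}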
...
@@ -698,18 +697,18 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
// Write vector into dst.
// 1. Source code version can take dst of all kinds of memory-space
-// 2. Inline asm versions using global_store or buffer_store can only take
+// 2. Intrinsic version using buffer_store can only take
//    dst from global-memory
//
// Comment for storing into global-memory:
//   When:
//     1) using source code, in order for compiler to emit optimal
//        store instruction, or
-//     2) using inline asm (global_store or buffer_store), in order
-//        for inline asm to be valid,
+//     2) using buffer_store intrinsic, in order for ISA to be valid,
//   following assumptions need to be satisfied:
//     1. p_dst need to be block-invariant (assumption)
-//     2. dst_normal_offset must be calculated at compile time (guaranteed)
+//     2. dst_normal_offset must be calculated at compile time (guaranteed by
+//        algorithm)
//     3. dst_merged_offset can be runtime value (no assumption imposed)
static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
...
@@ -751,152 +750,5 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
    DstCoordinate mDstSliceOrigin;
};
// this version use TensorView and TensorCoordinate_deprecated
template <typename SrcTensor,
          typename DstTensor,
          typename SliceLengths,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorAccessDim,
          index_t DstVectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v3r1
{
    static constexpr index_t nDim = SrcTensor::GetNumOfDimension();

    using data_type     = remove_cv_t<typename SrcTensor::data_type>;
    using SrcCoordinate = typename SrcTensor::coordinate_type;
    using DstCoordinate = typename DstTensor::coordinate_type;

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v3r1(SrcTensor src,
                                                               SrcCoordinate src_slice_origin,
                                                               DstTensor dst,
                                                               DstCoordinate dst_slice_origin)
        : mSrc{src},
          mDst{dst},
          mSrcSlice{src.Slice(src_slice_origin, SliceLengths{})},
          mDstSlice{dst.Slice(dst_slice_origin, SliceLengths{})}
    {
        static_assert(nDim == SrcTensor::GetNumOfDimension() &&
                          nDim == DstTensor::GetNumOfDimension() &&
                          nDim == SliceLengths::GetSize() &&
                          nDim == SrcDimAccessOrder::GetSize() &&
                          nDim == DstDimAccessOrder::GetSize(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
                          is_valid_sequence_map<DstDimAccessOrder>::value,
                      "wrong! map is not valid");

        static_assert(is_same<remove_cv_t<typename SrcTensor::data_type>,
                              remove_cv_t<typename DstTensor::data_type>>{},
                      "wrong! type conversion is not supported yet");

        static_assert(decltype(mSrcSlice)::IsVectorizationAllowed(Number<SrcVectorAccessDim>{},
                                                                  Number<SrcDataPerAccess>{}) &&
                          decltype(mDstSlice)::IsVectorizationAllowed(Number<DstVectorAccessDim>{},
                                                                      Number<DstDataPerAccess>{}),
                      "wrong! vectorized access is not allowed");
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v3r1()
        : ThreadwiseGenericTensorSliceCopy_v3r1(
              SrcTensor{}, SrcCoordinate{}, DstTensor{}, DstCoordinate{})
    {
    }

    __device__ void Run() const
    {
        // buffer
        constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SrcTensor::GetLengths());

        data_type p_buffer[buffer_desc.GetElementSpace()];
        auto buffer = make_TensorView(buffer_desc, p_buffer);

        // copy data from src into buffer
        {
            using src_vector_t = typename vector_type<data_type, SrcDataPerAccess>::MemoryType;

            constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
            constexpr auto src_data_per_access   = Number<SrcDataPerAccess>{};

            auto src_slice_vectorized =
                mSrcSlice.Vectorize(src_vector_access_dim, src_data_per_access);

            ford<decltype(src_slice_vectorized.GetLengths()), SrcDimAccessOrder>{}(
                [&](auto src_vector_id) {
                    // load vector from src
                    const src_vector_t vector_data = src_slice_vectorized[src_vector_id];

                    // unpack vector into buffer
                    auto src_scalar_id = src_vector_id;
                    src_scalar_id(src_vector_access_dim) *= src_data_per_access;

                    for(index_t i = 0; i < SrcDataPerAccess; ++i)
                    {
                        auto id                   = make_zero_array<index_t, nDim>();
                        id(src_vector_access_dim) = i;

                        buffer(src_scalar_id + id) =
                            reinterpret_cast<const data_type*>(&vector_data)[i];
                    }
                });
        }

        // copy data from buffer into dst
        {
            using dst_vector_t = typename vector_type<data_type, DstDataPerAccess>::MemoryType;

            constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
            constexpr auto dst_data_per_access   = Number<DstDataPerAccess>{};

            auto dst_slice_vectorized =
                mDstSlice.Vectorize(dst_vector_access_dim, dst_data_per_access);

            ford<decltype(dst_slice_vectorized.GetLengths()), DstDimAccessOrder>{}(
                [&](auto dst_vector_id) {
                    dst_vector_t vector_data{};

                    // pack vector from buffer
                    auto dst_scalar_id = dst_vector_id;
                    dst_scalar_id(dst_vector_access_dim) *= dst_data_per_access;

                    for(index_t i = 0; i < DstDataPerAccess; ++i)
                    {
                        auto id                   = make_zero_array<index_t, nDim>();
                        id(dst_vector_access_dim) = i;

                        reinterpret_cast<data_type*>(&vector_data)[i] = buffer[dst_scalar_id + id];
                    }

                    // write vector into dst
                    dst_slice_vectorized(dst_vector_id) = vector_data;
                });
        }
    }

    // T can be Sequence or Array
    template <typename T, bool PositiveDirection>
    __device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
    {
        mSrc.MoveSliceWindow(mSrcSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
    }

    template <typename T, bool PositiveDirection>
    __device__ void MoveDstSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
    {
        mDst.MoveSliceWindow(mDstSlice, step_sizes, integral_constant<bool, PositiveDirection>{});
    }

    private:
    using SrcSlice = decltype(SrcTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));
    using DstSlice = decltype(DstTensor{}.Slice(make_zero_array<index_t, nDim>(), SliceLengths{}));

    SrcTensor mSrc;
    DstTensor mDst;
    SrcSlice mSrcSlice;
    DstSlice mDstSlice;
};
} // namespace ck
#endif
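The removed ThreadwiseGenericTensorSliceCopy_v3r1 reads the source slice as vectors, unpacks them into a scalar staging buffer, then repacks and writes the destination as vectors of a possibly different width. A simplified host-side sketch of that flow, assuming a read width of 4 and a write width of 2 (plain C++, not the CK API):

// Illustrative only: vectorized read into a staging buffer, then vectorized
// write with a different access width.
#include <cstring>
#include <cstdio>

int main()
{
    float src[8], dst[8], buffer[8];
    for(int i = 0; i < 8; ++i)
        src[i] = static_cast<float>(i);

    // read side: vector width 4
    for(int v = 0; v < 8; v += 4)
    {
        float vec[4];
        std::memcpy(vec, &src[v], sizeof(vec)); // one 4-wide "load"
        for(int i = 0; i < 4; ++i)
            buffer[v + i] = vec[i];             // unpack into the staging buffer
    }

    // write side: vector width 2
    for(int v = 0; v < 8; v += 2)
    {
        float vec[2] = {buffer[v], buffer[v + 1]}; // pack from the staging buffer
        std::memcpy(&dst[v], vec, sizeof(vec));    // one 2-wide "store"
    }

    printf("%f\n", dst[5]); // 5.0
    return 0;
}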
composable_kernel/include/tensor_operation/threadwise_tensor_slice_copy.hpp
deleted 100644 → 0
View file @ 98a2cfcc
#ifndef CK_THREADWISE_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_TENSOR_SLICE_COPY_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

// need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_tensor_slice_copy(SrcDesc,
                                             const Float* __restrict__ p_src,
                                             DstDesc,
                                             Float* __restrict__ p_dst,
                                             SrcOpLengths,
                                             Number<DataPerRead>)
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    constexpr index_t nDim = SrcOpLengths::GetSize();

    static_assert(SrcDesc{}.GetNumOfDimension() == nDim && DstDesc{}.GetNumOfDimension() == nDim,
                  "wrong! dimension not consistent");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});

#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(src_desc, "src_desc");
        print_ConstantTensorDescriptor(dst_desc, "dst_desc");
        print_ConstantTensorDescriptor(ref_desc, "ref_desc");
    }
#endif

    static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
                                       DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
                  "wrong! only support stride[nDim-1] == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
                  "wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");

    constexpr index_t L_Back = SrcOpLengths{}.Back();

    static_assert(L_Back % DataPerRead == 0,
                  "wrong! lengths[nDim-1] should be evenly divided by DataPerRead");

    constexpr index_t nRead = L_Back / DataPerRead;

    static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
        static_for<0, nRead, 1>{}([&](auto IRead) {
            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead * DataPerRead>{});

            const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
            const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(multi_id);

            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
                *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
        });
    });
}
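A plain C++ sketch of what the static_asserts above require: unit stride in the last dimension on both sides, and a second-to-last stride that is a multiple of DataPerRead so each vector-sized transfer stays aligned. The sizes and strides below are made up for illustration; this is not the CK API.

#include <cstring>
#include <cstdio>

int main()
{
    const int L0 = 3, L1 = 8, DataPerRead = 4;
    const int src_stride0 = 12, dst_stride0 = 8; // both multiples of DataPerRead

    float src[L0 * src_stride0], dst[L0 * dst_stride0];
    for(int i = 0; i < L0 * src_stride0; ++i)
        src[i] = static_cast<float>(i);

    for(int i0 = 0; i0 < L0; ++i0)
        for(int i1 = 0; i1 < L1; i1 += DataPerRead)
            std::memcpy(&dst[i0 * dst_stride0 + i1],
                        &src[i0 * src_stride0 + i1],
                        DataPerRead * sizeof(float)); // one vector-sized transfer

    printf("%f\n", dst[1 * dst_stride0 + 3]); // 15.0 (= 1*12 + 3 in src)
    return 0;
}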
// access in order of src
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
                                                                      const SrcData* __restrict__ p_src,
                                                                      DstDesc,
                                                                      DstData* __restrict__ p_dst,
                                                                      SrcOpLengths,
                                                                      MapDst2Src)
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    ford<SrcOpLengths>{}([&](auto src_multi_id) {
        const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});

        const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
        const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);

        p_dst[dst_index] = p_src[src_index];
    });
}
// access in order of dst
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
                                                                      const SrcData* __restrict__ p_src,
                                                                      DstDesc,
                                                                      DstData* __restrict__ p_dst,
                                                                      SrcOpLengths,
                                                                      MapDst2Src)
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});

    ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
        const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

        const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
        const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);

        p_dst[dst_index] = p_src[src_index];
    });
}
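The difference between v1 and v2 above is only the iteration order; both rely on the same index mapping, here assumed to follow the convention that map_dst2src[i] names the source dimension feeding destination dimension i. A small stand-alone sketch of the two reorder directions (illustrative, not the CK reorder_array_* helpers):

// new2old: walk in src order and scatter into dst (dst_id[i] = src_id[map[i]]);
// old2new: walk in dst order and gather back into src (src_id[map[i]] = dst_id[i]).
#include <array>
#include <cstdio>

int main()
{
    std::array<int, 3> map_dst2src = {2, 0, 1};
    std::array<int, 3> src_id = {10, 20, 30};

    std::array<int, 3> dst_id{};
    for(int i = 0; i < 3; ++i)
        dst_id[i] = src_id[map_dst2src[i]];      // new2old: {30, 10, 20}

    std::array<int, 3> src_id_back{};
    for(int i = 0; i < 3; ++i)
        src_id_back[map_dst2src[i]] = dst_id[i]; // old2new: recovers {10, 20, 30}

    printf("%d %d %d\n", src_id_back[0], src_id_back[1], src_id_back[2]);
    return 0;
}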
// access in order of dst
// manually pack data into vector before write
template <class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          index_t DstDataPerWrite>
__device__ void threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
                                                                      const Float* __restrict__ p_src,
                                                                      DstDesc,
                                                                      Float* __restrict__ p_dst,
                                                                      SrcOpLengths,
                                                                      MapDst2Src,
                                                                      Number<DstDataPerWrite>)
{
    using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;

    constexpr index_t nDim = SrcOpLengths::GetSize();

    static_assert(DstDataPerWrite == 1 || DstDesc{}.GetStride(Number<nDim - 1>{}) == 1,
                  "wrong! only support dst.stride[nDim-1] == 1, if DstDataPerWrite != 1");

    static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                  "wrong! only support DstDataPerWrite == 1, 2 or 4");

    static_assert(DstDesc{}.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
                  "wrong! dst.stride[nDim-2] should be multiple of DstDataPerWrite to keep alignment");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});

    constexpr index_t L_Dst_Back = dst_op_lengths.Back();

    static_assert(L_Dst_Back % DstDataPerWrite == 0,
                  "wrong! dst.lengths[nDim-1] should be evenly divided by DstDataPerWrite");

    constexpr index_t nWrite = L_Dst_Back / DstDataPerWrite;

    ford<decltype(dst_op_lengths.PopBack())>{}([&](auto ids) {
        static_for<0, nWrite, 1>{}([&](auto IWrite) {
            vector_t dst_vec_data;

            // pack data
            static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
                const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite + IDstData);

                const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

                const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);

                vector_type<Float, DstDataPerWrite>::SetScalar(dst_vec_data, p_src[src_index], IDstData);
            });

            // write data
            const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite);

            const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);

            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
        });
    });
}
} // namespace ck
#endif
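The v3 variant differs from v2 only in that it gathers DstDataPerWrite scalars into a register-held vector and issues one wide store per group. A minimal sketch of that pack-then-write step in plain C++ (made-up sizes, not the CK vector_type API):

// Gather four strided source scalars into a local 4-wide vector, then store it
// with a single contiguous write.
#include <cstring>
#include <cstdio>

int main()
{
    float src[16], dst[4];
    for(int i = 0; i < 16; ++i)
        src[i] = static_cast<float>(i);

    float vec[4];
    for(int i = 0; i < 4; ++i)
        vec[i] = src[i * 4];            // gather: the source walk is strided / reordered

    std::memcpy(dst, vec, sizeof(vec)); // one vector-sized store to contiguous dst

    printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]); // 0 4 8 12
    return 0;
}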
composable_kernel/include/utility/config_nvidia.hpp.in
View file @ 9b280cc5
...
@@ -48,36 +48,6 @@ struct type_convert
    }
};

template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
{
    d += s0 * s1;
}

#if 0
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }

__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
{
    d += s0.x * s1.x;
    d += s0.y * s1.y;
}

__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
{
    d += s0.x * s1.x + s0.y * s1.y;
}

__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }

// TODO:: this interface is misleading, s0, s1 are actually int8x4
// need to make a better interface
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
{
    d = __dp4a(s0, s1, d);
}
#endif

} // namespace ck
#endif