gaoqiong / composable_kernel · Commits

Commit 971220d8, authored Sep 16, 2021 by ltqin

    gridwise gemm data copy and blockwise gemm

Parent: a52e5a92
Showing 1 changed file with 447 additions and 435 deletions:

composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp (+447 / -435)
@@ -264,444 +264,456 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
    using CM0N0M1N1M2M3M4N2GridDesc =
        decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));

    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));

    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               FloatAB* __restrict__ p_shared_block,
                               const AGK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
                               const BGK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
                               const CM0N0M1N1M2M3M4N2GridDesc& c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                               const CBlockClusterAdaptor& c_block_cluster_adaptor)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_g_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_b_grid, b_g_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_c_grid, c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());

        const auto K0 = a_g_k0_m_k1_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        const index_t g_idx = block_work_idx[I0];

        // lds max alignment
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        constexpr auto a_g_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        constexpr auto b_g_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<1, KPerBlock, MPerBlock, K1>,
            ABlockTransferThreadSliceLengths_G_K0_M_K1,
            ABlockTransferThreadClusterLengths_G_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(a_g_k0_m_k1_grid_desc),
            decltype(a_g_k0_m_k1_block_desc),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 2, 1, 3>,
            ABlockTransferSrcVectorDim,
            3,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_K1,
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true>(a_g_k0_m_k1_grid_desc,
                  make_multi_index(g_idx, 0, m_block_data_idx_on_grid, 0),
                  a_g_k0_m_k1_block_desc,
                  make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4<
            BlockSize,
            InMemoryDataOperationEnum_t::Set,
            Sequence<1, KPerBlock, NPerBlock, K1>,
            BBlockTransferThreadSliceLengths_G_K0_N_K1,
            BBlockTransferThreadClusterLengths_G_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(b_g_k0_n_k1_grid_desc),
            decltype(b_g_k0_n_k1_block_desc),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 2, 1, 3>,
            BBlockTransferSrcVectorDim,
            3,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_K1,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_g_k0_n_k1_grid_desc,
                  make_multi_index(g_idx, 0, n_block_data_idx_on_grid, 0),
                  b_g_k0_n_k1_block_desc,
                  make_multi_index(0, 0, 0, 0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[KPerBlock, MPerBlock] is in LDS
        // b_mtx[KPerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
        const auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            decltype(a_k0_m_k1_block_desc),
            decltype(b_k0_n_k1_block_desc),
            MPerXDL,
            NPerXDL,
            MRepeat,
            NRepeat,
            K1>{};

        constexpr auto c_mr_nr_blk_desc =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
            blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();

        constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();

        StaticBuffer<AddressSpaceEnum_t::Vgpr,
                     vector_type<FloatAcc, CBlkSize>,
                     c_mr_nr_blk_desc.GetElementSpaceSize(),
                     true>
            c_thread_buf;

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block = p_shared_block;
        FloatAB* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);

        // hack to control index calculation when iterating over A and B matrix for threadwise copy
        constexpr auto a_g_k0_m_k1_grid_step_hacks = AGridStepHacks{};
        constexpr auto b_g_k0_n_k1_grid_step_hacks = BGridStepHacks{};

        // hack to control index calculation when move slice window for A and B matrix for
        // threadwise copy
        constexpr auto a_g_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
        constexpr auto b_g_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_a_block, a_g_k0_m_k1_block_desc.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
            p_b_block, b_g_k0_n_k1_block_desc.GetElementSpaceSize());

        // preload data into LDS
        {
            a_blockwise_copy.RunRead(a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);
            b_blockwise_copy.RunRead(b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);

            a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);
        }

        // main body
        index_t k_block_data_begin = 0;

        do
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_g_k0_m_k1_grid_desc,
                                                a_block_slice_copy_step,
                                                a_g_k0_m_k1_grid_move_slice_window_step_hack);
            b_blockwise_copy.MoveSrcSliceWindow(b_g_k0_n_k1_grid_desc,
                                                b_block_slice_copy_step,
                                                b_g_k0_n_k1_grid_move_slice_window_step_hack);

            a_blockwise_copy.RunRead(a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);

            block_sync_lds();

            b_blockwise_copy.RunRead(b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

            block_sync_lds();

            a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);

            k_block_data_begin += KPerBlock;
        } while(k_block_data_begin < (K0 - KPerBlock));

        // tail
        {
            block_sync_lds();

            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }

        /* // output: register to global memory
        {
            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();

            constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
            constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
            constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                FloatC,
                FloatC,
                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
                Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                      make_multi_index(0,
                                       0,
                                       0,
                                       0,
                                       m_thread_data_on_grid / (M3 * M4),
                                       m_thread_data_on_grid % (M3 * M4) / M4,
                                       m_thread_data_on_grid % M4,
                                       n_thread_data_on_grid)};

            auto init_copy = [&](auto c_thread_idx_) {
                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);

                return c_thread_idx_;
            };

            auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
                constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 mrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
                constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 nrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
                constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 mrepeat_step_plus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
                constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
                c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                                 nrepeat_step_minus);

                constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
                c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                  c_grid_buf,
                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
            };

            static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
                              (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
                              (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
                              (MRepeat == 1 && NRepeat == 1),
                          "wrong");

            if constexpr(MRepeat == 4 && NRepeat == 4)
            {
                init_copy(make_tuple(I0, I0));

                if constexpr(CAccessOrderMRepeatNRepeat)
                {
                    nrepeat_plus_copy(make_tuple(I0, I1));
                    nrepeat_plus_copy(make_tuple(I0, I2));
                    nrepeat_plus_copy(make_tuple(I0, I3));
                    mrepeat_plus_copy(make_tuple(I1, I3));
                    nrepeat_minus_copy(make_tuple(I1, I2));
                    nrepeat_minus_copy(make_tuple(I1, I1));
                    nrepeat_minus_copy(make_tuple(I1, I0));
                    mrepeat_plus_copy(make_tuple(I2, I0));
                    nrepeat_plus_copy(make_tuple(I2, I1));
                    nrepeat_plus_copy(make_tuple(I2, I2));
                    nrepeat_plus_copy(make_tuple(I2, I3));
                    mrepeat_plus_copy(make_tuple(I3, I3));
                    nrepeat_minus_copy(make_tuple(I3, I2));
                    nrepeat_minus_copy(make_tuple(I3, I1));
                    nrepeat_minus_copy(make_tuple(I3, I0));
                }
                else
                {
                    mrepeat_plus_copy(make_tuple(I1, I0));
                    mrepeat_plus_copy(make_tuple(I2, I0));
                    mrepeat_plus_copy(make_tuple(I3, I0));
                    nrepeat_plus_copy(make_tuple(I3, I1));
                    mrepeat_minus_copy(make_tuple(I2, I1));
                    mrepeat_minus_copy(make_tuple(I1, I1));
                    mrepeat_minus_copy(make_tuple(I0, I1));
                    nrepeat_plus_copy(make_tuple(I0, I2));
                    mrepeat_plus_copy(make_tuple(I1, I2));
                    mrepeat_plus_copy(make_tuple(I2, I2));
                    mrepeat_plus_copy(make_tuple(I3, I2));
                    nrepeat_plus_copy(make_tuple(I3, I3));
                    mrepeat_minus_copy(make_tuple(I2, I3));
                    mrepeat_minus_copy(make_tuple(I1, I3));
                    mrepeat_minus_copy(make_tuple(I0, I3));
                }
            }
            else if constexpr(MRepeat == 4 && NRepeat == 2)
            {
                init_copy(make_tuple(I0, I0));

                if constexpr(CAccessOrderMRepeatNRepeat)
                {
                    nrepeat_plus_copy(make_tuple(I0, I1));
                    mrepeat_plus_copy(make_tuple(I1, I1));
                    nrepeat_minus_copy(make_tuple(I1, I0));
                    mrepeat_plus_copy(make_tuple(I2, I0));
                    nrepeat_plus_copy(make_tuple(I2, I1));
                    mrepeat_plus_copy(make_tuple(I3, I1));
                    nrepeat_minus_copy(make_tuple(I3, I0));
                }
                else
                {
                    mrepeat_plus_copy(make_tuple(I1, I0));
                    mrepeat_plus_copy(make_tuple(I2, I0));
                    mrepeat_plus_copy(make_tuple(I3, I0));
                    nrepeat_plus_copy(make_tuple(I3, I1));
                    mrepeat_minus_copy(make_tuple(I2, I1));
                    mrepeat_minus_copy(make_tuple(I1, I1));
                    mrepeat_minus_copy(make_tuple(I0, I1));
                }
            }
            else if constexpr(MRepeat == 2 && NRepeat == 4)
            {
                init_copy(make_tuple(I0, I0));

                if constexpr(CAccessOrderMRepeatNRepeat)
                {
                    nrepeat_plus_copy(make_tuple(I0, I1));
                    nrepeat_plus_copy(make_tuple(I0, I2));
                    nrepeat_plus_copy(make_tuple(I0, I3));
                    mrepeat_plus_copy(make_tuple(I1, I3));
                    nrepeat_minus_copy(make_tuple(I1, I2));
                    nrepeat_minus_copy(make_tuple(I1, I1));
                    nrepeat_minus_copy(make_tuple(I1, I0));
                }
                else
                {
                    mrepeat_plus_copy(make_tuple(I1, I0));
                    nrepeat_plus_copy(make_tuple(I1, I1));
                    mrepeat_minus_copy(make_tuple(I0, I1));
                    nrepeat_plus_copy(make_tuple(I0, I2));
                    mrepeat_plus_copy(make_tuple(I1, I2));
                    nrepeat_plus_copy(make_tuple(I1, I3));
                    mrepeat_minus_copy(make_tuple(I0, I3));
                }
            }
            else if constexpr(MRepeat == 2 && NRepeat == 2)
            {
                init_copy(make_tuple(I0, I0));

                if constexpr(CAccessOrderMRepeatNRepeat)
                {
                    nrepeat_plus_copy(make_tuple(I0, I1));
                    mrepeat_plus_copy(make_tuple(I1, I1));
                    nrepeat_minus_copy(make_tuple(I1, I0));
                }
                else
                {
                    mrepeat_plus_copy(make_tuple(I1, I0));
                    nrepeat_plus_copy(make_tuple(I1, I1));
                    mrepeat_minus_copy(make_tuple(I0, I1));
                }
            }
            else if constexpr(MRepeat == 2 && NRepeat == 1)
            {
                init_copy(make_tuple(I0, I0));
                mrepeat_plus_copy(make_tuple(I1, I0));
            }
            else if constexpr(MRepeat == 1 && NRepeat == 2)
            {
                init_copy(make_tuple(I0, I0));
                nrepeat_plus_copy(make_tuple(I0, I1));
            }
            else if constexpr(MRepeat == 1 && NRepeat == 1)
            {
                init_copy(make_tuple(I0, I0));
            }
        }*/
    }
};

} // namespace ck
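The visible change in this hunk is the extra leading G dimension: `c_block_cluster_adaptor.CalculateBottomIndex` now yields a three-component index, where `block_work_idx[I0]` selects the group and `[I1]`/`[I2]` select the M and N block tiles, and the A/B blockwise copies start at `make_multi_index(g_idx, 0, m_or_n_block_data_idx_on_grid, 0)`. Below is a minimal host-side sketch of that decomposition, assuming a plain row-major [G, MBlocks, NBlocks] ordering; the real adaptor comes from `MakeCBlockClusterAdaptor` and may decompose the block id differently, so all names here are illustrative only.

    #include <array>
    #include <cassert>
    #include <cstdio>

    // Hypothetical stand-in for the block-cluster adaptor: maps a 1-D block id to
    // (g, m_block, n_block), assuming a row-major [G, MBlocks, NBlocks] layout.
    std::array<int, 3> calculate_bottom_index(int block_1d_id, int m_blocks, int n_blocks)
    {
        const int g       = block_1d_id / (m_blocks * n_blocks);
        const int m_block = block_1d_id % (m_blocks * n_blocks) / n_blocks;
        const int n_block = block_1d_id % n_blocks;
        return {g, m_block, n_block};
    }

    int main()
    {
        constexpr int MPerBlock = 128, NPerBlock = 128;
        constexpr int m_blocks = 4, n_blocks = 8; // C tiles per group

        const auto idx = calculate_bottom_index(37, m_blocks, n_blocks); // block id 37
        const int g_idx                    = idx[0];
        const int m_block_data_idx_on_grid = idx[1] * MPerBlock;
        const int n_block_data_idx_on_grid = idx[2] * NPerBlock;

        // block 37 -> group 1, tile (0, 5): element offsets (0, 640) in that group's C
        std::printf("g=%d m_off=%d n_off=%d\n",
                    g_idx, m_block_data_idx_on_grid, n_block_data_idx_on_grid);
        assert(g_idx == 1 && m_block_data_idx_on_grid == 0 && n_block_data_idx_on_grid == 640);
    }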
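For reference, the K loop above follows a preload / compute / refill schedule: one KPerBlock tile of A and B is staged in LDS, the next tile's global reads are issued, and `blockwise_gemm.Run` consumes the staged tile between the two `block_sync_lds()` barriers, with one tail iteration after the loop. The following scalar CPU sketch replays the same schedule serially, with plain arrays standing in for the grid and LDS buffers and a naive triple loop standing in for the blockwise GEMM; none of these names are the library's API, and the serial ordering (compute, then stage the next tile) is equivalent to the overlapped GPU ordering only in result, not in timing.

    #include <cstdio>
    #include <vector>

    // Scalar sketch of the blocked K loop: A is K x M, B is K x N, C += A^T * B.
    int main()
    {
        constexpr int M = 4, N = 3, K = 8, KPerBlock = 2;

        std::vector<float> a(K * M, 1.0f), b(K * N, 2.0f), c(M * N, 0.0f);
        std::vector<float> lds_a(KPerBlock * M), lds_b(KPerBlock * N); // "LDS" tiles

        auto load_tile = [&](int k0) { // stand-in for RunRead + RunWrite of one tile
            for(int i = 0; i < KPerBlock * M; ++i) lds_a[i] = a[k0 * M + i];
            for(int i = 0; i < KPerBlock * N; ++i) lds_b[i] = b[k0 * N + i];
        };
        auto gemm_on_tile = [&] { // stand-in for blockwise_gemm.Run on the staged tile
            for(int k = 0; k < KPerBlock; ++k)
                for(int m = 0; m < M; ++m)
                    for(int n = 0; n < N; ++n)
                        c[m * N + n] += lds_a[k * M + m] * lds_b[k * N + n];
        };

        load_tile(0);                      // preload the first tile
        int k_block_data_begin = 0;
        do
        {
            gemm_on_tile();                // consume the currently staged tile
            k_block_data_begin += KPerBlock;
            load_tile(k_block_data_begin); // stage the next tile
        } while(k_block_data_begin < (K - KPerBlock));
        gemm_on_tile();                    // tail: last staged tile

        std::printf("c[0] = %f (expected %f)\n", c[0], 2.0f * K); // 1*2 summed over K
    }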