"...resnet50_tensorflow.git" did not exist on "f67822f5ff1d9e3495da069b31aa7643430e02c3"
Commit 971220d8 authored by ltqin
Browse files

gridwise gemm data copy and blockwise gemm

parent a52e5a92
...@@ -264,444 +264,456 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -264,444 +264,456 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{})); decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void
const FloatAB* __restrict__ p_b_grid, Run(const FloatAB* __restrict__ p_a_grid,
FloatC* __restrict__ p_c_grid, const FloatAB* __restrict__ p_b_grid,
FloatAB* __restrict__ p_shared_block, FloatC* __restrict__ p_c_grid,
const AGK0MK1GridDesc& a_k0_m_k1_grid_desc, FloatAB* __restrict__ p_shared_block,
const BGK0NK1GridDesc& b_k0_n_k1_grid_desc, const AGK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const BGK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor) const CM0N0M1N1M2M3M4N2GridDesc& c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{ {
/* const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_g_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_g_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); p_c_grid, c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_g_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR // HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
const index_t g_idx = block_work_idx[I0];
// lds max alignment
constexpr auto max_lds_align = K1; // lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( // be careful of LDS alignment
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment constexpr auto a_g_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// A matrix blockwise copy constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
auto a_blockwise_copy = make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, constexpr auto b_g_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock, K1>, make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
ABlockTransferThreadSliceLengths_K0_M_K1, // A matrix blockwise copy
ABlockTransferThreadClusterLengths_K0_M_K1, auto a_blockwise_copy =
ABlockTransferThreadClusterArrangeOrder, BlockwiseTensorSliceTransfer_v4<BlockSize,
FloatAB, InMemoryDataOperationEnum_t::Set,
FloatAB, Sequence<1, KPerBlock, MPerBlock, K1>,
decltype(a_k0_m_k1_grid_desc), ABlockTransferThreadSliceLengths_G_K0_M_K1,
decltype(a_k0_m_k1_block_desc), ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferSrcAccessOrder, ABlockTransferThreadClusterArrangeOrder,
Sequence<1, 0, 2>, FloatAB,
ABlockTransferSrcVectorDim, FloatAB,
2, decltype(a_g_k0_m_k1_grid_desc),
ABlockTransferSrcScalarPerVector, decltype(a_g_k0_m_k1_block_desc),
ABlockTransferDstScalarPerVector_K1, ABlockTransferSrcAccessOrder,
1, Sequence<0, 2, 1, 3>,
1, ABlockTransferSrcVectorDim,
AThreadTransferSrcResetCoordinateAfterRun, 3,
true>(a_k0_m_k1_grid_desc, ABlockTransferSrcScalarPerVector,
make_multi_index(0, m_block_data_idx_on_grid, ABlockTransferDstScalarPerVector_K1,
0), a_k0_m_k1_block_desc, make_multi_index(0, 0, 0)); 1,
1,
// B matrix blockwise copy AThreadTransferSrcResetCoordinateAfterRun,
auto b_blockwise_copy = true>(
BlockwiseTensorSliceTransfer_v4<BlockSize, a_g_k0_m_k1_grid_desc,
InMemoryDataOperationEnum_t::Set, make_multi_index(g_idx, 0, m_block_data_idx_on_grid, 0),
Sequence<KPerBlock, NPerBlock, K1>, a_g_k0_m_k1_block_desc,
BBlockTransferThreadSliceLengths_K0_N_K1, make_multi_index(0, 0, 0, 0));
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, // B matrix blockwise copy
FloatAB, auto b_blockwise_copy =
FloatAB, BlockwiseTensorSliceTransfer_v4<BlockSize,
decltype(b_k0_n_k1_grid_desc), InMemoryDataOperationEnum_t::Set,
decltype(b_k0_n_k1_block_desc), Sequence<1, KPerBlock, NPerBlock, K1>,
BBlockTransferSrcAccessOrder, BBlockTransferThreadSliceLengths_G_K0_N_K1,
Sequence<1, 0, 2>, BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferSrcVectorDim, BBlockTransferThreadClusterArrangeOrder,
2, FloatAB,
BBlockTransferSrcScalarPerVector, FloatAB,
BBlockTransferDstScalarPerVector_K1, decltype(b_g_k0_n_k1_grid_desc),
1, decltype(b_g_k0_n_k1_block_desc),
1, BBlockTransferSrcAccessOrder,
BThreadTransferSrcResetCoordinateAfterRun, Sequence<0, 2, 1, 3>,
true>(b_k0_n_k1_grid_desc, BBlockTransferSrcVectorDim,
make_multi_index(0, n_block_data_idx_on_grid, 3,
0), b_k0_n_k1_block_desc, make_multi_index(0, 0, 0)); BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
// GEMM definition 1,
// c_mtx += transpose(a_mtx) * b_mtx 1,
// a_mtx[KPerBlock, MPerBlock] is in LDS BThreadTransferSrcResetCoordinateAfterRun,
// b_mtx[KPerBlock, NPerBlock] is in LDS true>(
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in b_g_k0_n_k1_grid_desc,
// register make_multi_index(g_idx, 0, n_block_data_idx_on_grid, 0),
// sanity check b_g_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0));
const auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, // GEMM definition
FloatAB, // c_mtx += transpose(a_mtx) * b_mtx
decltype(a_k0_m_k1_block_desc), // a_mtx[KPerBlock, MPerBlock] is in LDS
decltype(b_k0_n_k1_block_desc), // b_mtx[KPerBlock, NPerBlock] is in LDS
MPerXDL, // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
NPerXDL, // register
MRepeat, // sanity check
NRepeat,
K1>{}; const auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
constexpr auto c_mr_nr_blk_desc = FloatAB,
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, decltype(a_k0_m_k1_block_desc),
Number<NRepeat>{})); decltype(b_k0_n_k1_block_desc),
MPerXDL,
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = NPerXDL,
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); MRepeat,
constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); NRepeat,
K1>{};
StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, CBlkSize>, constexpr auto c_mr_nr_blk_desc =
c_mr_nr_blk_desc.GetElementSpaceSize(), make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
true>
c_thread_buf; constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
// LDS allocation for A and B: be careful of alignment constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), StaticBuffer<AddressSpaceEnum_t::Vgpr,
max_lds_align); vector_type<FloatAcc, CBlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize(),
FloatAB* p_a_block = p_shared_block; true>
FloatAB* p_b_block = p_shared_block + a_block_space_size; c_thread_buf;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); // LDS allocation for A and B: be careful of alignment
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
// hack to control index calculation when iterating over A and B matrix for threadwise
copy constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; constexpr auto FloatAB* p_a_block = p_shared_block;
b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; FloatAB* p_b_block = p_shared_block + a_block_space_size;
// hack to control index calculation when move slice window for A and B matrix for constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
// threadwise copy constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{}; constexpr auto // hack to control index calculation when iterating over A and B matrix for threadwise copy
b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; constexpr auto a_g_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_g_k0_n_k1_grid_step_hacks = BGridStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); // hack to control index calculation when move slice window for A and B matrix for
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( // threadwise copy
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); constexpr auto a_g_k0_m_k1_grid_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{};
// preload data into LDS constexpr auto b_g_k0_n_k1_grid_move_slice_window_step_hack =
{ BGridMoveSliceWindowStepHacks{};
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf,
a_k0_m_k1_grid_step_hacks); b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
b_k0_n_k1_grid_step_hacks); p_a_block, a_g_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); p_b_block, b_g_k0_n_k1_block_desc.GetElementSpaceSize());
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
} // preload data into LDS
{
// main body a_blockwise_copy.RunRead(
index_t k_block_data_begin = 0; a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(
do b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
a_block_slice_copy_step, b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);
a_k0_m_k1_grid_move_slice_window_step_hack); }
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step, // main body
b_k0_n_k1_grid_move_slice_window_step_hack); index_t k_block_data_begin = 0;
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, do
a_k0_m_k1_grid_step_hacks); {
a_blockwise_copy.MoveSrcSliceWindow(a_g_k0_m_k1_grid_desc,
block_sync_lds(); a_block_slice_copy_step,
a_g_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_blockwise_copy.MoveSrcSliceWindow(b_g_k0_n_k1_grid_desc,
b_k0_n_k1_grid_step_hacks); b_block_slice_copy_step,
b_g_k0_n_k1_grid_move_slice_window_step_hack);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
a_blockwise_copy.RunRead(
block_sync_lds(); a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); block_sync_lds();
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
b_blockwise_copy.RunRead(
k_block_data_begin += KPerBlock; b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);
} while(k_block_data_begin < (K0 - KPerBlock));
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
// tail
{ block_sync_lds();
block_sync_lds();
a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);
}
k_block_data_begin += KPerBlock;
// output: register to global memory } while(k_block_data_begin < (K0 - KPerBlock));
{
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = // tail
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); {
block_sync_lds();
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); }
// calculate origin of thread output tensor on global memory /* // output: register to global memory
// blockwise GEMM c matrix starting index {
const auto c_thread_mtx_on_block = constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
const index_t m_thread_data_on_grid = constexpr auto M2 =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); constexpr auto M3 =
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); constexpr auto M4 =
const index_t n_thread_data_on_grid = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
// calculate origin of thread output tensor on global memory
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
auto c_thread_copy = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0,
ThreadwiseTensorSliceTransfer_v1r3<FloatC, I0);
FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), const index_t m_thread_data_on_grid =
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder, const index_t n_thread_data_on_grid =
CThreadTransferSrcDstVectorDim, n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
1, = CGridStepHacks{};
true>{
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, auto c_thread_copy =
make_multi_index(0, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
0, FloatC,
0, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
0, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
m_thread_data_on_grid / (M3 * M4), Sequence<I1, I1, I1, I1,
m_thread_data_on_grid % (M3 * M4) / M4, M2, I1, M4, I1>, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
m_thread_data_on_grid % M4, CThreadTransferDstScalarPerVector,
n_thread_data_on_grid)}; CGlobalMemoryDataOperation,
1,
auto init_copy = [&](auto c_thread_idx_) { true>{
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_multi_index(0,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), 0,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), 0,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, 0,
c_grid_buf, m_thread_data_on_grid / (M3 * M4),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); m_thread_data_on_grid % (M3 * M4) /
M4, m_thread_data_on_grid % M4, n_thread_data_on_grid)};
return c_thread_idx_;
}; auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off =
auto mrepeat_plus_copy = [&](auto c_thread_idx_) { c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0,
mrepeat_step_plus); I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), return c_thread_idx_;
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), };
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); constexpr auto mrepeat_step_plus = make_multi_index(1, 0,
}; 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); constexpr auto blk_off =
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
nrepeat_step_plus); c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0,
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), };
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); constexpr auto nrepeat_step_plus = make_multi_index(0, 1,
}; 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus);
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); constexpr auto blk_off =
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
mrepeat_step_plus); c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0,
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), };
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); constexpr auto mrepeat_step_plus = make_multi_index(-1, 0,
}; 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); constexpr auto blk_off =
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
nrepeat_step_minus); c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0,
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), };
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); constexpr auto nrepeat_step_minus = make_multi_index(0, -1,
}; 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus);
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) constexpr auto blk_off =
or (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or (MRepeat == 1 && c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0,
I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4
&& NRepeat == 2) or (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
(MRepeat == 2
&& NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or (MRepeat == 1 &&
NRepeat == 1), "wrong"); NRepeat == 1), "wrong");
if constexpr(MRepeat == 4 && NRepeat == 4) if constexpr(MRepeat == 4 && NRepeat == 4)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat) if constexpr(CAccessOrderMRepeatNRepeat)
{ {
nrepeat_plus_copy(make_tuple(I0, I1)); nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2)); nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3)); nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3)); mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2)); nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1)); nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0)); nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0)); mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1)); nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I2)); nrepeat_plus_copy(make_tuple(I2, I2));
nrepeat_plus_copy(make_tuple(I2, I3)); nrepeat_plus_copy(make_tuple(I2, I3));
mrepeat_plus_copy(make_tuple(I3, I3)); mrepeat_plus_copy(make_tuple(I3, I3));
nrepeat_minus_copy(make_tuple(I3, I2)); nrepeat_minus_copy(make_tuple(I3, I2));
nrepeat_minus_copy(make_tuple(I3, I1)); nrepeat_minus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0)); nrepeat_minus_copy(make_tuple(I3, I0));
} }
else else
{ {
mrepeat_plus_copy(make_tuple(I1, I0)); mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0)); mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0)); mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1)); nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1)); mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1)); mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1)); mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2)); nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2)); mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I2, I2)); mrepeat_plus_copy(make_tuple(I2, I2));
mrepeat_plus_copy(make_tuple(I3, I2)); mrepeat_plus_copy(make_tuple(I3, I2));
nrepeat_plus_copy(make_tuple(I3, I3)); nrepeat_plus_copy(make_tuple(I3, I3));
mrepeat_minus_copy(make_tuple(I2, I3)); mrepeat_minus_copy(make_tuple(I2, I3));
mrepeat_minus_copy(make_tuple(I1, I3)); mrepeat_minus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3)); mrepeat_minus_copy(make_tuple(I0, I3));
} }
} }
else if constexpr(MRepeat == 4 && NRepeat == 2) else if constexpr(MRepeat == 4 && NRepeat == 2)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat) if constexpr(CAccessOrderMRepeatNRepeat)
{ {
nrepeat_plus_copy(make_tuple(I0, I1)); nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1)); mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0)); nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0)); mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1)); nrepeat_plus_copy(make_tuple(I2, I1));
mrepeat_plus_copy(make_tuple(I3, I1)); mrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0)); nrepeat_minus_copy(make_tuple(I3, I0));
} }
else else
{ {
mrepeat_plus_copy(make_tuple(I1, I0)); mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0)); mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0)); mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1)); nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1)); mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1)); mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1)); mrepeat_minus_copy(make_tuple(I0, I1));
} }
} }
else if constexpr(MRepeat == 2 && NRepeat == 4) else if constexpr(MRepeat == 2 && NRepeat == 4)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat) if constexpr(CAccessOrderMRepeatNRepeat)
{ {
nrepeat_plus_copy(make_tuple(I0, I1)); nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2)); nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3)); nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3)); mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2)); nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1)); nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0)); nrepeat_minus_copy(make_tuple(I1, I0));
} }
else else
{ {
mrepeat_plus_copy(make_tuple(I1, I0)); mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1)); nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1)); mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2)); nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2)); mrepeat_plus_copy(make_tuple(I1, I2));
nrepeat_plus_copy(make_tuple(I1, I3)); nrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3)); mrepeat_minus_copy(make_tuple(I0, I3));
} }
} }
else if constexpr(MRepeat == 2 && NRepeat == 2) else if constexpr(MRepeat == 2 && NRepeat == 2)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat) if constexpr(CAccessOrderMRepeatNRepeat)
{ {
nrepeat_plus_copy(make_tuple(I0, I1)); nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1)); mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0)); nrepeat_minus_copy(make_tuple(I1, I0));
} }
else else
{ {
mrepeat_plus_copy(make_tuple(I1, I0)); mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1)); nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1)); mrepeat_minus_copy(make_tuple(I0, I1));
} }
} }
else if constexpr(MRepeat == 2 && NRepeat == 1) else if constexpr(MRepeat == 2 && NRepeat == 1)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
mrepeat_plus_copy(make_tuple(I1, I0)); mrepeat_plus_copy(make_tuple(I1, I0));
} }
else if constexpr(MRepeat == 1 && NRepeat == 2) else if constexpr(MRepeat == 1 && NRepeat == 2)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
nrepeat_plus_copy(make_tuple(I0, I1)); nrepeat_plus_copy(make_tuple(I0, I1));
} }
else if constexpr(MRepeat == 1 && NRepeat == 1) else if constexpr(MRepeat == 1 && NRepeat == 1)
{ {
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
} }
}*/ }*/
} }
}; // namespace ck }; // namespace ck
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment