"test/vscode:/vscode.git/clone" did not exist on "c51f20f9c5fd15e60702d3e4cbfe5d68c16a487a"
Commit 971220d8 authored by ltqin's avatar ltqin
Browse files

gridwise gemm data copy and blockwise gemm

parent a52e5a92
...@@ -264,23 +264,24 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -264,23 +264,24 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{})); decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGK0MK1GridDesc& a_k0_m_k1_grid_desc, const AGK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
const BGK0NK1GridDesc& b_k0_n_k1_grid_desc, const BGK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc& c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor) const CBlockClusterAdaptor& c_block_cluster_adaptor)
{ {
/* const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_g_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_g_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); p_c_grid, c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_g_k0_m_k1_grid_desc.GetLength(I1);
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -288,10 +289,11 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -288,10 +289,11 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
// HACK: this force m/n_block_data_idx_on_grid into SGPR // HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
const index_t g_idx = block_work_idx[I0];
// lds max alignment // lds max alignment
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -301,60 +303,68 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -301,60 +303,68 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
constexpr auto a_g_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
constexpr auto b_g_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>, Sequence<1, KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_G_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m_k1_grid_desc), decltype(a_g_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc), decltype(a_g_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<0, 2, 1, 3>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
2, 3,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k0_m_k1_grid_desc, true>(
make_multi_index(0, m_block_data_idx_on_grid, a_g_k0_m_k1_grid_desc,
0), a_k0_m_k1_block_desc, make_multi_index(0, 0, 0)); make_multi_index(g_idx, 0, m_block_data_idx_on_grid, 0),
a_g_k0_m_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>, Sequence<1, KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_G_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_k0_n_k1_grid_desc), decltype(b_g_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_g_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<0, 2, 1, 3>,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
2, 3,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
1, 1,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k0_n_k1_grid_desc, true>(
make_multi_index(0, n_block_data_idx_on_grid, b_g_k0_n_k1_grid_desc,
0), b_k0_n_k1_block_desc, make_multi_index(0, 0, 0)); make_multi_index(g_idx, 0, n_block_data_idx_on_grid, 0),
b_g_k0_n_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -376,8 +386,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -376,8 +386,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
K1>{}; K1>{};
constexpr auto c_mr_nr_blk_desc = constexpr auto c_mr_nr_blk_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
Number<NRepeat>{}));
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
...@@ -391,38 +400,39 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -391,38 +400,39 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
max_lds_align);
FloatAB* p_a_block = p_shared_block; FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size; FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise // hack to control index calculation when iterating over A and B matrix for threadwise copy
copy constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; constexpr auto constexpr auto a_g_k0_m_k1_grid_step_hacks = AGridStepHacks{};
b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; constexpr auto b_g_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for // hack to control index calculation when move slice window for A and B matrix for
// threadwise copy // threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = constexpr auto a_g_k0_m_k1_grid_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{}; constexpr auto AGridMoveSliceWindowStepHacks{};
b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; constexpr auto b_g_k0_n_k1_grid_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); p_a_block, a_g_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); p_b_block, b_g_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS // preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_blockwise_copy.RunRead(
a_k0_m_k1_grid_step_hacks); b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);
b_k0_n_k1_grid_step_hacks); b_blockwise_copy.RunRead(
b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);
} }
// main body // main body
...@@ -430,27 +440,27 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -430,27 +440,27 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
do do
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_g_k0_m_k1_grid_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack); a_g_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, b_blockwise_copy.MoveSrcSliceWindow(b_g_k0_n_k1_grid_desc,
b_block_slice_copy_step, b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack); b_g_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_blockwise_copy.RunRead(
a_k0_m_k1_grid_step_hacks); a_g_k0_m_k1_grid_desc, a_grid_buf, a_g_k0_m_k1_grid_step_hacks);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_blockwise_copy.RunRead(
b_k0_n_k1_grid_step_hacks); b_g_k0_n_k1_grid_desc, b_grid_buf, b_g_k0_n_k1_grid_step_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds(); block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_g_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_g_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock; k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock)); } while(k_block_data_begin < (K0 - KPerBlock));
...@@ -462,19 +472,21 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -462,19 +472,21 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
// output: register to global memory /* // output: register to global memory
{ {
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); constexpr auto M2 =
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); constexpr auto M3 =
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); constexpr auto M4 =
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0,
I0);
const index_t m_thread_data_on_grid = const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
...@@ -482,16 +494,16 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -482,16 +494,16 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
= CGridStepHacks{};
auto c_thread_copy = auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC, FloatC,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
Sequence<I1, I1, I1, I1, M2, I1, M4, I1>, Sequence<I1, I1, I1, I1,
CThreadTransferSrcDstAccessOrder, M2, I1, M4, I1>, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
...@@ -502,81 +514,81 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -502,81 +514,81 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
0, 0,
0, 0,
m_thread_data_on_grid / (M3 * M4), m_thread_data_on_grid / (M3 * M4),
m_thread_data_on_grid % (M3 * M4) / M4, m_thread_data_on_grid % (M3 * M4) /
m_thread_data_on_grid % M4, M4, m_thread_data_on_grid % M4, n_thread_data_on_grid)};
n_thread_data_on_grid)};
auto init_copy = [&](auto c_thread_idx_) { auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
return c_thread_idx_; return c_thread_idx_;
}; };
auto mrepeat_plus_copy = [&](auto c_thread_idx_) { auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); constexpr auto mrepeat_step_plus = make_multi_index(1, 0,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus); mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto nrepeat_plus_copy = [&](auto c_thread_idx_) { auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); constexpr auto nrepeat_step_plus = make_multi_index(0, 1,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus); nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto mrepeat_minus_copy = [&](auto c_thread_idx_) { auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); constexpr auto mrepeat_step_plus = make_multi_index(-1, 0,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus); mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto nrepeat_minus_copy = [&](auto c_thread_idx_) { auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); constexpr auto nrepeat_step_minus = make_multi_index(0, -1,
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus); nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0,
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) && NRepeat == 2) or (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
or (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or (MRepeat == 1 && (MRepeat == 2
&& NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or (MRepeat == 1 &&
NRepeat == 1), "wrong"); NRepeat == 1), "wrong");
if constexpr(MRepeat == 4 && NRepeat == 4) if constexpr(MRepeat == 4 && NRepeat == 4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment