Commit 58b6996a authored by ltqin

some gridwise gemm write to C matrix

parent 971220d8
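
Taken together, the changes below route the C-matrix write-out through descriptors that carry a leading G (group) dimension: the blockwise GEMM gains GetCG* thread and block descriptors, the gridwise GEMM adds I7 and indexes the destination with g_idx, and each thread's origin on the grid splits its blockwise M coordinate with a div/mod decomposition. The following standalone snippet is only an illustration of that index arithmetic; it is not part of the commit, and the concrete numbers are invented.

// Hypothetical, self-contained illustration (not from the commit): the div/mod
// decomposition that places a thread's output slice at
// make_multi_index(g_idx, 0, 0, 0, 0, m0, m3, m4, n) on the 9-D grid descriptor.
#include <cstdio>

int main()
{
    // Example values only; in the kernel, M3 and M4 are read from the block
    // descriptor via GetLength(I6) and GetLength(I7).
    const int M3                    = 4;
    const int M4                    = 4;
    const int m_thread_data_on_grid = 37;
    const int n_thread_data_on_grid = 8;
    const int g_idx                 = 2; // leading G (group) coordinate introduced by this change

    const int m0 = m_thread_data_on_grid / (M3 * M4);      // coarse M index
    const int m3 = m_thread_data_on_grid % (M3 * M4) / M4; // middle M index
    const int m4 = m_thread_data_on_grid % M4;             // finest M index

    std::printf("dst origin = (g=%d, 0, 0, 0, 0, %d, %d, %d, n=%d)\n",
                g_idx, m0, m3, m4, n_thread_data_on_grid);
    return 0;
}
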
@@ -131,6 +131,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
  return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
  }
+ __host__ __device__ static constexpr auto GetCGM0N0M1N1M2M3M4N2ThreadDescriptor()
+ {
+ constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+ constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+ constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
+ constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+ constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
+ return make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, I1, I1, I1, I1, M0, M1, M2, N));
+ }
  __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
  {
  constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
@@ -144,6 +155,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
  return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
  }
+ __host__ __device__ static constexpr auto GetCGM0N0M1N1M2M3M4N2BlockDescriptor()
+ {
+ constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
+ make_naive_tensor_descriptor_packed(make_tuple(Number<1>{},
+ Number<MRepeat>{},
+ Number<NRepeat>{},
+ Number<MWaves>{},
+ Number<NWaves>{},
+ Number<MPerXDL>{},
+ Number<NPerXDL>{}));
+ return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
+ }
  template <typename CMNGridDesc>
  __host__ __device__ static constexpr auto
  MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
......
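
The new GetCGM0N0M1N1M2M3M4N2ThreadDescriptor above differs from the existing 8-D thread descriptor only by the extra leading Number<1>{} length. Because make_naive_tensor_descriptor_packed describes a densely packed tensor, its element-space size is the product of the lengths, so the added length-1 G dimension should not change the per-thread buffer size (CBlkSize) read in the gridwise kernel below. A small standalone sketch of that size calculation; the example lengths are hypothetical and not taken from the code:

// Illustration only: a packed (naive) descriptor's element-space size is the
// product of its lengths, so a leading length-1 G dimension leaves it unchanged.
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t N>
constexpr std::size_t packed_element_space_size(const std::array<std::size_t, N>& lengths)
{
    std::size_t size = 1;
    for(std::size_t len : lengths)
        size *= len;
    return size;
}

int main()
{
    // Hypothetical per-thread lengths: (M0, M1, M2, N) = (4, 1, 4, 1).
    constexpr std::array<std::size_t, 8> m0n0m1n1m2m3m4n2{1, 1, 1, 1, 4, 1, 4, 1};
    constexpr std::array<std::size_t, 9> g_m0n0m1n1m2m3m4n2{1, 1, 1, 1, 1, 4, 1, 4, 1};

    std::printf("CBlkSize without G: %zu, with G: %zu\n",
                packed_element_space_size(m0n0m1n1m2m3m4n2),
                packed_element_space_size(g_m0n0m1n1m2m3m4n2));
    return 0;
}
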
@@ -142,6 +142,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
  static constexpr auto I4 = Number<4>{};
  static constexpr auto I5 = Number<5>{};
  static constexpr auto I6 = Number<6>{};
+ static constexpr auto I7 = Number<7>{};
  // K1 should be Number<...>
  static constexpr auto K1 = Number<K1Value>{};
@@ -388,9 +389,9 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
  constexpr auto c_mr_nr_blk_desc =
  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
- constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
- blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
- constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
+ constexpr auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
+ blockwise_gemm.GetCGM0N0M1N1M2M3M4N2ThreadDescriptor();
+ constexpr auto CBlkSize = c_g_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
  StaticBuffer<AddressSpaceEnum_t::Vgpr,
  vector_type<FloatAcc, CBlkSize>,
@@ -472,21 +473,19 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
  blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
  }
- /* // output: register to global memory
+ // output: register to global memory
  {
- constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
- blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
+ constexpr auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
+ blockwise_gemm.GetCGM0N0M1N1M2M3M4N2BlockDescriptor();
- constexpr auto M2 =
- c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); constexpr auto M3 =
- c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); constexpr auto M4 =
- c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
+ constexpr auto M2 = c_g_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
+ constexpr auto M3 = c_g_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
+ constexpr auto M4 = c_g_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
  // calculate origin of thread output tensor on global memory
  // blockwise GEMM c matrix starting index
  const auto c_thread_mtx_on_block =
- blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0,
- I0);
+ blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
  const index_t m_thread_data_on_grid =
  m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
@@ -494,102 +493,104 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
  const index_t n_thread_data_on_grid =
  n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
- constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
- = CGridStepHacks{};
+ constexpr auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
  auto c_thread_copy =
  ThreadwiseTensorSliceTransfer_v1r3<FloatC,
  FloatC,
- decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
- decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
- Sequence<I1, I1, I1, I1,
- M2, I1, M4, I1>, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
+ decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
+ decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
+ Sequence<I1, I1, I1, I1, I1, M2, I1, M4, I1>,
+ CThreadTransferSrcDstAccessOrder,
+ CThreadTransferSrcDstVectorDim,
  CThreadTransferDstScalarPerVector,
  CGlobalMemoryDataOperation,
  1,
  true>{
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
- make_multi_index(0,
+ c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ make_multi_index(g_idx,
+ 0,
  0,
  0,
  0,
  m_thread_data_on_grid / (M3 * M4),
- m_thread_data_on_grid % (M3 * M4) /
- M4, m_thread_data_on_grid % M4, n_thread_data_on_grid)};
+ m_thread_data_on_grid % (M3 * M4) / M4,
+ m_thread_data_on_grid % M4,
+ n_thread_data_on_grid)};
  auto init_copy = [&](auto c_thread_idx_) {
- constexpr auto blk_off =
- c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
- c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
- make_tuple(I0, I0, I0, I0, I0, I0, I0,
- I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ c_thread_copy.Run(c_g_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
+ c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ c_grid_buf,
+ c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
  return c_thread_idx_;
  };
- auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
- constexpr auto mrepeat_step_plus = make_multi_index(1, 0,
- 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ /* auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
+ constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
  mrepeat_step_plus);
- constexpr auto blk_off =
- c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
- make_tuple(I0, I0, I0, I0, I0, I0, I0,
- I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
+ c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ c_grid_buf,
  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
  };
  auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
- constexpr auto nrepeat_step_plus = make_multi_index(0, 1,
- 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
  nrepeat_step_plus);
- constexpr auto blk_off =
- c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
- make_tuple(I0, I0, I0, I0, I0, I0, I0,
- I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
+ c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ c_grid_buf,
  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
  };
  auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
- constexpr auto mrepeat_step_plus = make_multi_index(-1, 0,
- 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
  mrepeat_step_plus);
- constexpr auto blk_off =
- c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
- make_tuple(I0, I0, I0, I0, I0, I0, I0,
- I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
+ c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ c_grid_buf,
  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
  };
  auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
- constexpr auto nrepeat_step_minus = make_multi_index(0, -1,
- 0, 0, 0, 0, 0, 0); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
  nrepeat_step_minus);
- constexpr auto blk_off =
- c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
  c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
- make_tuple(I0, I0, I0, I0, I0, I0, I0,
- I0), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
- c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
+ c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+ c_grid_buf,
  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
  };
- static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4
- && NRepeat == 2) or (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
- (MRepeat == 2
- && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or (MRepeat == 1 &&
- NRepeat == 1), "wrong");
+ static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
+ (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
+ (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
+ (MRepeat == 1 && NRepeat == 1),
+ "wrong");
  if constexpr(MRepeat == 4 && NRepeat == 4)
  {
@@ -713,7 +714,7 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
  {
  init_copy(make_tuple(I0, I0));
  }
- }*/
+ */ }
  }
  }; // namespace ck
......
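
In the write-out path above, c_thread_buf holds one vector_type<FloatAcc, CBlkSize> accumulator block per (MRepeat, NRepeat) repeat, and init_copy selects a block with c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_). Assuming the usual row-major layout of that packed 2-D descriptor, the mapping is simply mr * NRepeat + nr. A minimal standalone sketch of that mapping; the names and values are illustrative only, not taken from the kernel:

// Sketch, assuming a row-major packed (MRepeat, NRepeat) descriptor: which
// accumulator block of c_thread_buf a repeat index (mr, nr) selects,
// mirroring c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_).
#include <cstdio>

constexpr int blk_offset(int mr, int nr, int n_repeat) { return mr * n_repeat + nr; }

int main()
{
    constexpr int MRepeat = 2; // one of the combinations the static_assert permits
    constexpr int NRepeat = 2;

    for(int mr = 0; mr < MRepeat; ++mr)
        for(int nr = 0; nr < NRepeat; ++nr)
            std::printf("repeat (%d, %d) -> c_thread_buf block %d\n",
                        mr, nr, blk_offset(mr, nr, NRepeat));
    return 0;
}

The mrepeat/nrepeat slice-window helpers remain commented out in this commit; when enabled, they would step the destination window by one repeat at a time between these copies.
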
@@ -255,24 +255,24 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
  Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
  constexpr auto out_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
- make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: M0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: N0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: M1
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: N1
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M2
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M3
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 7+: M4
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 8+: N2
- make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: M0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: N0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: M1
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: N1
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M2
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M3
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 7-: M4
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 8-: N2
+ make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: M0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: N0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: M1
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: N1
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M2
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M3
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 7+: M4
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 8+: N2
+ make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: M0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: N0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: M1
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: N1
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M2
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M3
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 7-: M4
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 8-: N2
  constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
......