Commit cf378e46 authored by Jing Zhang's avatar Jing Zhang
Browse files

rename hacks and use single stage adapter

parent 9fd0bb97
...@@ -81,16 +81,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -81,16 +81,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m = make_naive_tensor_descriptor_packed( constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})); make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
constexpr auto nrepeat_nwave_nperxdl_to_n = make_naive_tensor_descriptor_packed( make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
const index_t c_thread_m = make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
mrepeat_mwave_mperxdl_to_m.CalculateOffset(make_tuple(m0, waveId_m, blk_idx[I0])); make_tuple(Sequence<0>{}),
const index_t c_thread_n = make_tuple(Sequence<0, 1, 2>{}));
nrepeat_nwave_nperxdl_to_n.CalculateOffset(make_tuple(n0, waveId_n, blk_idx[I1]));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
......
...@@ -473,7 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -473,7 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks = CGridStepHacks{}; constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
auto c_thread_copy = auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
...@@ -504,7 +504,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -504,7 +504,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
return c_thread_idx_; return c_thread_idx_;
}; };
...@@ -520,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -520,7 +520,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto nrepeat_plus_copy = [&](auto c_thread_idx_) { auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
...@@ -534,7 +534,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -534,7 +534,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto mrepeat_minus_copy = [&](auto c_thread_idx_) { auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
...@@ -548,7 +548,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -548,7 +548,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto nrepeat_minus_copy = [&](auto c_thread_idx_) { auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
...@@ -562,7 +562,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -562,7 +562,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_iterator_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment