Commit 59462dca authored by Jing Zhang
Browse files

use StaticBuffer of vector_type

parent 2cf1757e
...@@ -315,8 +315,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -315,8 +315,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr index_t BlkSize = OutputLayout.GetBlkSize(); constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks(); constexpr index_t NumBlks = OutputLayout.GetNumBlks();
constexpr auto c_mr_nr_nb_bk_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2( // constexpr auto c_mr_nr_nb_bk_thread_desc =
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumBlks>{}, Number<BlkSize>{})); // make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number<MRepeat>{},
// Number<NRepeat>{}, Number<NumBlks>{}, Number<BlkSize>{}));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
...@@ -337,9 +338,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -337,9 +338,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{} // Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto c_vec_size = c_mr_nr_nb_bk_thread_desc.GetElementSpaceSize(); StaticBuffer<AddressSpace::Vgpr, vector_type<float, NumBlks * BlkSize>, MRepeat * NRepeat>
c_thread_buf;
vector_type<float, c_vec_size> c_thread_buf;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...@@ -475,23 +475,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -475,23 +475,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr index_t M1 = OutputLayout.N1(); constexpr index_t M1 = OutputLayout.N1();
constexpr index_t M2 = OutputLayout.M0(); constexpr index_t M2 = OutputLayout.M0();
// static_assert(M0 == 4 && M1 == 2 && M2 == 4, "");
constexpr auto c_m0_m1_m2_n_thread_desc = constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{})); make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
// static_assert(BlkSize == 16 && NumBlks == 4, ""); StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) { static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) { static_for<0, NumBlks, 1>{}([&](auto blk_i) {
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) { static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) = c_thread_buf.template AsType< c_blk_buf_(j) =
float>()[Number<c_mr_nr_nb_bk_thread_desc.CalculateOffset( c_thread_buf[Number<mr_i * NRepeat + nr_i>{}]
make_tuple(mr_i, nr_i, blk_i, j))>{}]; .template AsType<float>()[Number<blk_i * BlkSize + j>{}];
}); });
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
...@@ -526,7 +522,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -526,7 +522,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
b_thread_data_on_global)} b_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc, .Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c_thread_buf_, c_blk_buf_,
c_m0_m1_m2_n_global_desc, c_m0_m1_m2_n_global_desc,
c_global_buf, c_global_buf,
c_m0_m1_m2_n_global_tensor_iterator_hacks); c_m0_m1_m2_n_global_tensor_iterator_hacks);
......
...@@ -791,24 +791,15 @@ struct XdlopsGemm ...@@ -791,24 +791,15 @@ struct XdlopsGemm
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops"); static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
vector_type<base_type, GetXdlopsInfo().GetNumCRegs()> t;
using c_type = decltype(GetXdlopsInfo().GetCType());
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)); constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0));
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) { static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0)); constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0)); constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0));
t.template AsType<c_type>()(Number<0>{}) = mfma_type.template run<MPerXdlops, NPerXdlops>(p_a_wave[Number<a_offset>{}],
p_c_thread.template AsType<c_type>()[Number<c_offset>{}]; p_b_wave[Number<b_offset>{}],
p_c_thread(Number<c_offset>{}));
mfma_type.template run<MPerXdlops, NPerXdlops>(
p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], t);
p_c_thread.template AsType<c_type>()(Number<c_offset>{}) =
t.template AsType<c_type>()[Number<0>{}];
}); });
} }
......
...@@ -106,7 +106,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -106,7 +106,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 64;
...@@ -115,13 +115,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -115,13 +115,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t MRepeat = 1; constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment