Commit 2cf1757e authored by Jing Zhang
Browse files

tuning

parent 90ec6a19
...@@ -151,11 +151,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -151,11 +151,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}), Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})); Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk_gemmm_global_desc, return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc, in_gemmk_gemmn_global_desc,
......
...@@ -98,8 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -98,8 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
} }
} }
__device__ static CIndex template <index_t m0, index_t n0, index_t blk_i>
CalculateCThreadOriginDataIndex(const index_t m0, const index_t n0, const index_t blk_i) __device__ static CIndex CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<blk_i>)
{ {
const index_t waveId = get_thread_local_1d_id() / WaveSize; const index_t waveId = get_thread_local_1d_id() / WaveSize;
...@@ -109,10 +109,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -109,10 +109,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId_m = waveId / NWaves; const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves; const index_t waveId_n = waveId % NWaves;
const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row; const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col; const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{row, col}; return CIndex{m_offset, n_offset};
} }
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
...@@ -307,10 +307,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -307,10 +307,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const index_t waveId_m = waveId / NWaves; const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves; const index_t waveId_n = waveId % NWaves;
const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row; const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col; const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{row, col}; return CIndex{m_offset, n_offset};
} }
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
......
...@@ -301,10 +301,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -301,10 +301,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
// constexpr auto c_m0_m1_n0_n1_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
// Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize, BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB, FloatAB,
...@@ -314,6 +310,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -314,6 +310,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
MPerWave, MPerWave,
NPerWave, NPerWave,
KPerWave>{}; KPerWave>{};
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
constexpr auto c_mr_nr_nb_bk_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumBlks>{}, Number<BlkSize>{}));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
...@@ -334,7 +337,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -334,7 +337,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{} // Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto c_vec_size = MPerBlock * NPerBlock / BlockSize; constexpr auto c_vec_size = c_mr_nr_nb_bk_thread_desc.GetElementSpaceSize();
vector_type<float, c_vec_size> c_thread_buf; vector_type<float, c_vec_size> c_thread_buf;
...@@ -468,10 +471,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -468,10 +471,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory // output: register to global memory
{ {
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout(); constexpr index_t M0 = OutputLayout.M1();
constexpr index_t M0 = OutputLayout.M1(); constexpr index_t M1 = OutputLayout.N1();
constexpr index_t M1 = OutputLayout.N1(); constexpr index_t M2 = OutputLayout.M0();
constexpr index_t M2 = OutputLayout.M0();
// static_assert(M0 == 4 && M1 == 2 && M2 == 4, ""); // static_assert(M0 == 4 && M1 == 2 && M2 == 4, "");
...@@ -479,27 +481,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -479,27 +481,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{})); make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
// static_assert(BlkSize == 16 && NumBlks == 4, ""); // static_assert(BlkSize == 16 && NumBlks == 4, "");
static_for<0, MRepeat, 1>{}([&](auto m_i) { static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto n_i) { static_for<0, NRepeat, 1>{}([&](auto nr_i) {
// force unrolling the output loop to get ride of scratches static_for<0, NumBlks, 1>{}([&](auto blk_i) {
static_for<0, NumBlks, 1>{}([&](auto i) {
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_; StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) { static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) = c_thread_buf.template AsType< c_thread_buf_(j) = c_thread_buf.template AsType<
float>()[Number<m_i*(NRepeat * BlkSize * NumBlks) + float>()[Number<c_mr_nr_nb_bk_thread_desc.CalculateOffset(
n_i*(BlkSize * NumBlks) + i * BlkSize + j>{}]; make_tuple(mr_i, nr_i, blk_i, j))>{}];
}); });
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(m_i, n_i, i); blockwise_gemm.CalculateCThreadOriginDataIndex(mr_i, nr_i, blk_i);
const index_t k_thread_data_on_global = const index_t k_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0]; m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
...@@ -507,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -507,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1]; n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
CGlobalIteratorHacks{}; CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
...@@ -531,7 +529,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -531,7 +529,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
c_thread_buf_, c_thread_buf_,
c_m0_m1_m2_n_global_desc, c_m0_m1_m2_n_global_desc,
c_global_buf, c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks); c_m0_m1_m2_n_global_tensor_iterator_hacks);
}); });
}); });
}); });
......
...@@ -741,11 +741,7 @@ struct XdlopsGemm ...@@ -741,11 +741,7 @@ struct XdlopsGemm
} }
#endif #endif
struct MatrixIndex using CIndex = MultiIndex<2>;
{
index_t row;
index_t col;
};
__device__ static constexpr index_t GetNumBlksPerXdlops() __device__ static constexpr index_t GetNumBlksPerXdlops()
{ {
...@@ -795,14 +791,15 @@ struct XdlopsGemm ...@@ -795,14 +791,15 @@ struct XdlopsGemm
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops"); static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) { vector_type<base_type, GetXdlopsInfo().GetNumCRegs()> t;
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
vector_type<base_type, GetXdlopsInfo().GetNumCRegs()> t; using c_type = decltype(GetXdlopsInfo().GetCType());
using c_type = decltype(GetXdlopsInfo().GetCType()); constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0));
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0));
t.template AsType<c_type>()(Number<0>{}) = t.template AsType<c_type>()(Number<0>{}) =
p_c_thread.template AsType<c_type>()[Number<c_offset>{}]; p_c_thread.template AsType<c_type>()[Number<c_offset>{}];
...@@ -815,7 +812,7 @@ struct XdlopsGemm ...@@ -815,7 +812,7 @@ struct XdlopsGemm
}); });
} }
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i) __device__ static CIndex GetBeginOfThreadBlk(index_t i)
{ {
const index_t xdlops_i = i / GetNumBlksPerXdlops(); const index_t xdlops_i = i / GetNumBlksPerXdlops();
const index_t j = i % GetNumBlksPerXdlops(); const index_t j = i % GetNumBlksPerXdlops();
...@@ -838,7 +835,7 @@ struct XdlopsGemm ...@@ -838,7 +835,7 @@ struct XdlopsGemm
index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops; index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops; index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
return MatrixIndex{row, col}; return CIndex{row, col};
} }
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
......
...@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#else #else
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4; constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment