Commit 2cf1757e authored by Jing Zhang's avatar Jing Zhang
Browse files

tuning

parent 90ec6a19
......@@ -151,11 +151,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
Sequence<0, 0, 2, 0, 0>{}));
return make_tuple(wei_gemmk_gemmm_global_desc,
in_gemmk_gemmn_global_desc,
......
......@@ -98,8 +98,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
}
}
__device__ static CIndex
CalculateCThreadOriginDataIndex(const index_t m0, const index_t n0, const index_t blk_i)
template <index_t m0, index_t n0, index_t blk_i>
__device__ static CIndex CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<blk_i>)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
......@@ -109,10 +109,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{row, col};
return CIndex{m_offset, n_offset};
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
......@@ -307,10 +307,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t row = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
const index_t col = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{row, col};
return CIndex{m_offset, n_offset};
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
......
......@@ -301,10 +301,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
// constexpr auto c_m0_m1_n0_n1_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
// Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
......@@ -314,6 +310,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
MPerWave,
NPerWave,
KPerWave>{};
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
constexpr auto c_mr_nr_nb_bk_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumBlks>{}, Number<BlkSize>{}));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
......@@ -334,7 +337,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto c_vec_size = MPerBlock * NPerBlock / BlockSize;
constexpr auto c_vec_size = c_mr_nr_nb_bk_thread_desc.GetElementSpaceSize();
vector_type<float, c_vec_size> c_thread_buf;
......@@ -468,7 +471,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory
{
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t M0 = OutputLayout.M1();
constexpr index_t M1 = OutputLayout.N1();
constexpr index_t M2 = OutputLayout.M0();
......@@ -479,27 +481,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
// static_assert(BlkSize == 16 && NumBlks == 4, "");
static_for<0, MRepeat, 1>{}([&](auto m_i) {
static_for<0, NRepeat, 1>{}([&](auto n_i) {
// force unrolling of the output loop to get rid of scratch registers
static_for<0, NumBlks, 1>{}([&](auto i) {
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
static_for<0, BlkSize, 1>{}([&](auto j) {
c_thread_buf_(j) = c_thread_buf.template AsType<
float>()[Number<m_i*(NRepeat * BlkSize * NumBlks) +
n_i*(BlkSize * NumBlks) + i * BlkSize + j>{}];
float>()[Number<c_mr_nr_nb_bk_thread_desc.CalculateOffset(
make_tuple(mr_i, nr_i, blk_i, j))>{}];
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(m_i, n_i, i);
blockwise_gemm.CalculateCThreadOriginDataIndex(mr_i, nr_i, blk_i);
const index_t k_thread_data_on_global =
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
......@@ -507,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
const index_t b_thread_data_on_global =
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks =
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v1r3<
......@@ -531,7 +529,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
c_thread_buf_,
c_m0_m1_m2_n_global_desc,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
c_m0_m1_m2_n_global_tensor_iterator_hacks);
});
});
});
......
......@@ -741,11 +741,7 @@ struct XdlopsGemm
}
#endif
// (row, col) coordinate pair locating a thread's output block within the
// C matrix of an xdlops GEMM.
// NOTE(review): this commit removes the bespoke struct below and replaces it
// with the generic `MultiIndex<2>` alias, so callers index the coordinate as
// idx[I0] / idx[I1] instead of .row / .col (see the matching changes in
// CalculateCThreadOriginDataIndex and GetBeginOfThreadBlk).
struct MatrixIndex
{
index_t row;
index_t col;
};
// 2-component multi-index replacement: component 0 = row (M), component 1 = col (N).
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlksPerXdlops()
{
......@@ -795,15 +791,16 @@ struct XdlopsGemm
static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
vector_type<base_type, GetXdlopsInfo().GetNumCRegs()> t;
using c_type = decltype(GetXdlopsInfo().GetCType());
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0));
static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0));
t.template AsType<c_type>()(Number<0>{}) =
p_c_thread.template AsType<c_type>()[Number<c_offset>{}];
......@@ -815,7 +812,7 @@ struct XdlopsGemm
});
}
__device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
__device__ static CIndex GetBeginOfThreadBlk(index_t i)
{
const index_t xdlops_i = i / GetNumBlksPerXdlops();
const index_t j = i % GetNumBlksPerXdlops();
......@@ -838,7 +835,7 @@ struct XdlopsGemm
index_t col = col_blk * mfma_type.n + blk_td + n_i * NPerXdlops;
index_t row = row_blk * mfma_type.m + blk_id * mfma_type.group_size + m_i * MPerXdlops;
return MatrixIndex{row, col};
return CIndex{row, col};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
......
......@@ -104,25 +104,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
#else
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerWave = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment