Commit 35267a40 authored by Bartlomiej Wroblewski

Review: Add blockwise doc, change function names to include dimension names

parent 9ca59788
@@ -10,6 +10,15 @@
 namespace ck {
 
+/**
+ * Blockwise GEMM that uses the DPP instruction modifier to limit the amount of data loaded by
+ * each thread by sharing the data between threads in a lanegroup.
+ *
+ * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`; there are
+ * `MRepeat` iterations along the `M` dimension and `NRepeat` along the `N` one.
+ * In total, the algorithm runs using
+ * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
+ */
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
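To make the wave count from the new doc comment concrete, here is a minimal host-side sketch; the tile sizes below are hypothetical stand-ins for the kernel's template arguments, not values from this commit:

```cpp
#include <cstdio>

// Hypothetical tile configuration, chosen only to trace the formula from the
// doc comment above; the real values are the kernel's template arguments.
constexpr int MPerBlock = 128, NPerBlock = 64;
constexpr int MPerDpp   = 32,  NPerDpp   = 32;
constexpr int MRepeat   = 2,   NRepeat   = 1;

// MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)
constexpr int MWaves = MPerBlock / (MRepeat * MPerDpp); // 128 / 64 = 2 waves along M
constexpr int NWaves = NPerBlock / (NRepeat * NPerDpp); // 64 / 32  = 2 waves along N
constexpr int WavesPerBlock = MWaves * NWaves;          // 4 waves in total

int main() { std::printf("waves per block: %d\n", WavesPerBlock); }
```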
@@ -69,20 +78,24 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
     }
 
-    __device__ static auto CalculateAThreadOriginDataIndex()
+    __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
 
-        const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex();
-        return make_tuple(0, waveId_m, dpp_a_idx[I1], KPerThread * dpp_a_idx[I0]);
+        const auto dpp_a_idx   = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
+        const auto dpp_a_idx_k = dpp_a_idx[I0];
+        const auto dpp_a_idx_m = dpp_a_idx[I1];
+        return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
     }
 
-    __device__ static auto CalculateBThreadOriginDataIndex()
+    __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_n = wave_idx[I1];
 
-        const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex();
-        return make_tuple(0, waveId_n, dpp_b_idx[I1], KPerThread * dpp_b_idx[I0]);
+        const auto dpp_b_idx   = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
+        const auto dpp_b_idx_k = dpp_b_idx[I0];
+        const auto dpp_b_idx_n = dpp_b_idx[I1];
+        return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
     }
 
     template <index_t m0, index_t n0>
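The renames encode the coordinate order in the function name: the DppGemm helpers return a (K, M) or (K, N) pair, and the blockwise wrapper reassembles it into (M0, M1, M2, K). Below is a standalone sketch of the A-side reassembly; the index values in `main` are made up, and reading M0 as the repeat slot, M1 as the wave slot, and M2 as the intra-tile slot is inferred from the adaptors in the next hunks:

```cpp
#include <cstdio>
#include <tuple>

// Stand-in for the KPerThread template parameter; hypothetical value.
constexpr int KPerThread = 4;

// Mirrors CalculateAThreadOriginDataIndex_M0_M1_M2_K: M0 (repeat) starts at 0,
// M1 is the wave's row, M2 is the thread's row inside the DPP tile, and the
// K origin is scaled by how many K elements each thread owns.
std::tuple<int, int, int, int>
a_thread_origin_m0_m1_m2_k(int waveId_m, int dpp_a_idx_m, int dpp_a_idx_k)
{
    return {0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k};
}

int main()
{
    const auto [m0, m1, m2, k] = a_thread_origin_m0_m1_m2_k(1, 5, 0);
    std::printf("M0=%d M1=%d M2=%d K=%d\n", m0, m1, m2, k); // M0=0 M1=1 M2=5 K=0
}
```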
@@ -91,7 +104,10 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
         const auto waveId_n = wave_idx[I1];
 
         const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
+        const auto blk_m_offset = blk_idx[I0];
+        const auto blk_n_offset = blk_idx[I1];
 
         constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
@@ -104,9 +120,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
             make_tuple(Sequence<0, 1, 2>{}));
 
         const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
-            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+            make_tuple(m0, waveId_m, blk_m_offset))[I0];
         const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
-            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+            make_tuple(n0, waveId_n, blk_n_offset))[I0];
 
         return make_tuple(c_thread_m, c_thread_n);
     }
@@ -324,8 +340,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
                                               B_K1,
                                               B_K1>;
 
-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
 };
 
 } // namespace ck
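For the C-coordinate hunks above: `make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))` splits a single M index into a (repeat, wave, intra-tile) triple, so `CalculateBottomIndex` recombines the triple into the flat M index. A plain-arithmetic equivalent, assuming the usual row-major unmerge order and hypothetical sizes:

```cpp
#include <cstdio>

// Hypothetical decomposition sizes; in the kernel these are the MWaves and
// MPerDpp compile-time constants.
constexpr int MWaves = 2, MPerDpp = 32;

// Plain-arithmetic reading of mrepeat_mwave_MPerDpp_to_m_adaptor: recombine
// (m0, waveId_m, blk_m_offset) into the flat M coordinate of the C tile.
constexpr int c_thread_m(int m0, int waveId_m, int blk_m_offset)
{
    return (m0 * MWaves + waveId_m) * MPerDpp + blk_m_offset;
}

int main()
{
    // Second repeat (m0 = 1), first wave, offset 3 inside the DPP tile:
    // (1 * 2 + 0) * 32 + 3 = 67.
    std::printf("c_thread_m = %d\n", c_thread_m(1, 0, 3));
}
```

The N side is symmetric, using NRepeat, NWaves, and NPerDpp.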
@@ -20,7 +20,7 @@ enum struct DppInstr
  * Structure representing DPP GEMM executed by a single wavefront.
  *
  * Each structure instantiation must contain the following fields:
- * - wave_size - number of threads that execute single DPP GEMM operation, usually equal to the
+ * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
  *   number of threads in a wavefront;
  * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
  *   it's 8 in case of DPP8;
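As a reference for the field list this doc comment begins, here is a sketch of what an instruction-traits instantiation could look like; the field names come from the doc comment and the code in this commit, but every numeric value is hypothetical (a DPP8-flavoured layout on a 32-lane wavefront), not taken from the repository:

```cpp
// Illustrative only: shape of an instruction-traits struct of the kind DppGemm
// reads through its dpp_instr member. Values are made up for illustration.
struct DppInstrExample
{
    static constexpr int wave_size       = 32; // threads running one DPP GEMM op
    static constexpr int lanegroup_size  = 8;  // lanes sharing data via DPP8
    static constexpr int m_per_wave      = 8;  // C rows covered by the wave
    static constexpr int n_per_wave      = 16; // C columns covered by the wave
    static constexpr int m_per_lanegroup = 4;  // C rows per lanegroup
    static constexpr int n_per_lanegroup = 8;  // C columns per lanegroup
    static constexpr int m_per_thread    = 4;  // C rows each thread accumulates
};
```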
@@ -254,16 +254,15 @@ struct DppGemm
         return make_tuple(m_dpp_idx, n_dpp_idx);
     }
 
-    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
+    __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
     {
         const auto laneId = get_thread_local_1d_id();
-        const auto wave_row = laneId / dpp_instr.n_per_wave;
-        auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
-        return make_tuple(0, m_idx % dpp_instr.m_per_wave);
+        return make_tuple(0, laneId % dpp_instr.m_per_lanegroup);
     }
 
-    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
+    __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
     {
         const auto laneId = get_thread_local_1d_id();
         return make_tuple(0, laneId % dpp_instr.n_per_wave);
@@ -271,13 +270,13 @@ struct DppGemm
     __device__ static CIndex GetBeginOfThreadBlk()
     {
-        const auto dpp_idx   = GetDppOpIdx();
-        const auto m_dpp_idx = dpp_idx[I0];
-        const auto n_dpp_idx = dpp_idx[I1];
+        const auto dpp_op_idx   = GetDppOpIdx();
+        const auto m_dpp_op_idx = dpp_op_idx[I0];
+        const auto n_dpp_op_idx = dpp_op_idx[I1];
 
-        index_t n_offset = n_dpp_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
-        index_t m_offset = m_dpp_idx * dpp_instr.m_per_lanegroup;
+        index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
+        index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
 
         return CIndex{m_offset, n_offset};
     }
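To trace the offset arithmetic in `GetBeginOfThreadBlk`, the sketch below walks a wave's lanes through the renamed variables. `GetDppOpIdx` is not shown in this diff, so the row-major lane-to-op split used here is an assumption made only to produce concrete numbers; the geometry reuses the hypothetical values from the traits sketch above:

```cpp
#include <cstdio>

// Hypothetical DPP8-style geometry (see the DppInstrExample sketch above).
constexpr int lanegroup_size  = 8;
constexpr int m_per_lanegroup = 4;
constexpr int n_per_lanegroup = 8;
constexpr int n_ops_per_row   = 2; // assumed lanegroups per wave row

int main()
{
    for(int lane = 0; lane < 32; ++lane)
    {
        const int lane_in_group = lane % lanegroup_size; // GetLaneIdInLaneGroup()
        const int dpp_op        = lane / lanegroup_size; // which DPP op the lane joins
        const int m_dpp_op_idx  = dpp_op / n_ops_per_row; // assumed row-major split
        const int n_dpp_op_idx  = dpp_op % n_ops_per_row;

        // These two lines mirror the + side of the last hunk.
        const int n_offset = n_dpp_op_idx * n_per_lanegroup + lane_in_group;
        const int m_offset = m_dpp_op_idx * m_per_lanegroup;

        std::printf("lane %2d -> C begin (%d, %d)\n", lane, m_offset, n_offset);
    }
}
```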