"driver/vscode:/vscode.git/clone" did not exist on "b9fcd6653b3630821736b68048413d5278c61755"
Commit 35267a40 authored by Bartlomiej Wroblewski

Review: Add blockwise doc, change function names to include dimension names

parent 9ca59788
@@ -10,6 +10,15 @@
 namespace ck {
 
+/**
+ * Blockwise GEMM that uses the DPP instruction modifier to limit the amount of data loaded by
+ * each thread by sharing the data between threads in a lanegroup.
+ *
+ * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`; there are
+ * `MRepeat` iterations along the `M` dimension and `NRepeat` iterations along `N`.
+ * In total, the algorithm runs using
+ * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
+ */
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
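
To make the wave-count formula in the documentation above concrete, here is a minimal host-side sketch; all tile sizes are hypothetical values chosen for illustration and do not come from this commit.

```cpp
#include <cstdio>

int main()
{
    // Hypothetical tile configuration, used only to evaluate the formula from the doc comment.
    constexpr int MPerBlock = 128, NPerBlock = 64; // block tile (assumed values)
    constexpr int MPerDpp   = 16,  NPerDpp   = 16; // per-wave DPP tile (assumed values)
    constexpr int MRepeat   = 2,   NRepeat   = 2;  // iterations per wave (assumed values)

    // How many waves tile the block in each dimension.
    constexpr int MWaves = MPerBlock / (MRepeat * MPerDpp); // 128 / 32 = 4
    constexpr int NWaves = NPerBlock / (NRepeat * NPerDpp); // 64 / 32 = 2

    // Total waves per block, matching the formula in the doc comment.
    constexpr int WavesPerBlock = MWaves * NWaves; // 4 * 2 = 8
    std::printf("waves per block = %d\n", WavesPerBlock);
}
```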
@@ -69,20 +78,24 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
     }
 
-    __device__ static auto CalculateAThreadOriginDataIndex()
+    __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
 
-        const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex();
-        return make_tuple(0, waveId_m, dpp_a_idx[I1], KPerThread * dpp_a_idx[I0]);
+        const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
+        const auto dpp_a_idx_k = dpp_a_idx[I0];
+        const auto dpp_a_idx_m = dpp_a_idx[I1];
+        return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
     }
 
-    __device__ static auto CalculateBThreadOriginDataIndex()
+    __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_n = wave_idx[I1];
 
-        const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex();
-        return make_tuple(0, waveId_n, dpp_b_idx[I1], KPerThread * dpp_b_idx[I0]);
+        const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
+        const auto dpp_b_idx_k = dpp_b_idx[I0];
+        const auto dpp_b_idx_n = dpp_b_idx[I1];
+        return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
     }
 
     template <index_t m0, index_t n0>
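
The new `_M0_M1_M2_K` / `_N0_N1_N2_K` suffixes spell out the order of the coordinates in the returned tuple. As a reading aid only, the host-side sketch below rebuilds the A-side tuple with made-up values for `waveId_m`, the DPP indices, and `KPerThread`; none of these numbers come from the kernel.

```cpp
#include <cstdio>
#include <tuple>

int main()
{
    // Hypothetical per-thread values; in the kernel they come from GetWaveIdx()
    // and dpp_gemm.CalculateAThreadOriginDataIndex_K_M().
    const int waveId_m    = 1; // which wave along M
    const int dpp_a_idx_k = 0; // K index within the DPP op
    const int dpp_a_idx_m = 5; // M index within the DPP op
    const int KPerThread  = 4; // assumed K elements handled per thread

    // Same shape as the tuple built in CalculateAThreadOriginDataIndex_M0_M1_M2_K().
    const auto origin = std::make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);

    std::printf("A thread origin (M0, M1, M2, K) = (%d, %d, %d, %d)\n",
                std::get<0>(origin), std::get<1>(origin),
                std::get<2>(origin), std::get<3>(origin));
}
```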
@@ -91,7 +104,10 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
         const auto waveId_n = wave_idx[I1];
 
         const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
+        const auto blk_m_offset = blk_idx[I0];
+        const auto blk_n_offset = blk_idx[I1];
 
         constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
@@ -104,9 +120,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
             make_tuple(Sequence<0, 1, 2>{}));
 
         const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
-            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+            make_tuple(m0, waveId_m, blk_m_offset))[I0];
         const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
-            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+            make_tuple(n0, waveId_n, blk_n_offset))[I0];
 
         return make_tuple(c_thread_m, c_thread_n);
     }
@@ -324,8 +340,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
                                                B_K1,
                                                B_K1>;
 
-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
 };
 
 } // namespace ck
@@ -20,7 +20,7 @@ enum struct DppInstr
  * Structure representing DPP GEMM executed by a single wavefront.
  *
  * Each structure instantiation must contain the following fields:
- * - wave_size - number of threads that execute single DPP GEMM operation, usually equal to the
+ * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
  *   number of threads in a wavefront;
  * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
  *   it's 8 in case of DPP8;
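
To illustrate the two fields described above, the sketch below partitions a wave into lanegroups. `wave_size = 64` is an assumption (a typical wavefront size), and `lanegroup_size = 8` follows the DPP8 case mentioned in the comment; the variable names are made up for the example.

```cpp
#include <cstdio>

int main()
{
    // Assumed values: a 64-lane wavefront split into DPP8 lanegroups of 8 lanes each.
    constexpr int wave_size      = 64;
    constexpr int lanegroup_size = 8;

    // Each lane is identified by its lanegroup and its position inside that lanegroup;
    // with DPP8 the 8 lanes of one lanegroup can exchange register data directly.
    for(int lane = 0; lane < wave_size; lane += 17) // sample a few lanes
    {
        const int lanegroup_id      = lane / lanegroup_size;
        const int lane_in_lanegroup = lane % lanegroup_size;
        std::printf("lane %2d -> lanegroup %d, lane-in-lanegroup %d\n",
                    lane, lanegroup_id, lane_in_lanegroup);
    }
}
```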
@@ -254,16 +254,15 @@ struct DppGemm
         return make_tuple(m_dpp_idx, n_dpp_idx);
     }
 
-    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
+    __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
     {
         const auto laneId = get_thread_local_1d_id();
-        const auto wave_row = laneId / dpp_instr.n_per_wave;
-        auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
-        return make_tuple(0, m_idx % dpp_instr.m_per_wave);
+        return make_tuple(0, laneId % dpp_instr.m_per_lanegroup);
     }
 
-    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
+    __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
     {
         const auto laneId = get_thread_local_1d_id();
         return make_tuple(0, laneId % dpp_instr.n_per_wave);
@@ -271,13 +270,13 @@ struct DppGemm
     __device__ static CIndex GetBeginOfThreadBlk()
     {
-        const auto dpp_idx = GetDppOpIdx();
+        const auto dpp_op_idx = GetDppOpIdx();
 
-        const auto m_dpp_idx = dpp_idx[I0];
-        const auto n_dpp_idx = dpp_idx[I1];
+        const auto m_dpp_op_idx = dpp_op_idx[I0];
+        const auto n_dpp_op_idx = dpp_op_idx[I1];
 
-        index_t n_offset = n_dpp_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
-        index_t m_offset = m_dpp_idx * dpp_instr.m_per_lanegroup;
+        index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
+        index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
 
         return CIndex{m_offset, n_offset};
     }
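
The offsets computed in `GetBeginOfThreadBlk()` above place each thread's C sub-block inside the wave tile. The host-side sketch below mirrors that arithmetic with hypothetical instruction parameters; `m_per_lanegroup`, `n_per_lanegroup`, the DPP-op indices, and the lane id are made-up numbers, not values from the library.

```cpp
#include <cstdio>

int main()
{
    // Assumed DPP instruction parameters and thread position, for illustration only.
    constexpr int m_per_lanegroup = 8;
    constexpr int n_per_lanegroup = 8;
    const int m_dpp_op_idx        = 1; // which DPP op along M this thread belongs to
    const int n_dpp_op_idx        = 2; // which DPP op along N this thread belongs to
    const int lane_in_lanegroup   = 3; // GetLaneIdInLaneGroup() in the kernel

    // Same arithmetic as GetBeginOfThreadBlk(): lanes of one lanegroup spread along N.
    const int m_offset = m_dpp_op_idx * m_per_lanegroup;
    const int n_offset = n_dpp_op_idx * n_per_lanegroup + lane_in_lanegroup;

    std::printf("thread block begin in C: (m, n) = (%d, %d)\n", m_offset, n_offset);
}
```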