Commit b80b9e29 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Review: Fix tensor adaptor; remove unnecessary element

parent 2efa20ad
...@@ -283,7 +283,12 @@ struct DppGemm ...@@ -283,7 +283,12 @@ struct DppGemm
}); });
} }
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % dpp_instr.wave_size; } __device__ static auto GetLaneIdInWave()
{
return get_thread_local_1d_id() % dpp_instr.wave_size;
}
__device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
__device__ static auto GetLaneIdInLaneGroup() __device__ static auto GetLaneIdInLaneGroup()
{ {
...@@ -292,26 +297,25 @@ struct DppGemm ...@@ -292,26 +297,25 @@ struct DppGemm
__device__ static auto GetLaneGroupIdInWave() __device__ static auto GetLaneGroupIdInWave()
{ {
return get_thread_local_1d_id() / dpp_instr.lanegroup_size; return GetLaneIdInWave() / dpp_instr.lanegroup_size;
} }
__device__ static auto GetDppIdx() __device__ static auto GetDppOpIdx()
{ {
const auto lanegroupId = GetLaneGroupIdInWave(); const auto lanegroupId = GetLaneGroupIdInWave();
constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
make_tuple( make_tuple(
make_merge_transform(make_tuple(1, make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))), dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
make_multi_index(lanegroupId)); make_multi_index(lanegroupId));
const auto m_dpp_idx = dpp_idx[I1]; const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I2]; const auto n_dpp_idx = dpp_idx[I1];
return make_tuple(m_dpp_idx, n_dpp_idx); return make_tuple(m_dpp_idx, n_dpp_idx);
} }
...@@ -333,7 +337,7 @@ struct DppGemm ...@@ -333,7 +337,7 @@ struct DppGemm
__device__ static CIndex GetBeginOfThreadBlk() __device__ static CIndex GetBeginOfThreadBlk()
{ {
const auto dpp_idx = GetDppIdx(); const auto dpp_idx = GetDppOpIdx();
const auto m_dpp_idx = dpp_idx[I0]; const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I1]; const auto n_dpp_idx = dpp_idx[I1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment