Commit b80b9e29 authored by Bartlomiej Wroblewski
Browse files

Review: Fix tensor adaptor; remove unnecessary element

parent 2efa20ad
......@@ -283,7 +283,12 @@ struct DppGemm
});
}
// Returns get_thread_local_1d_id() modulo dpp_instr.wave_size.
// NOTE(review): this page is a rendered diff with the +/- markers stripped; this one-line
// definition looks like the removed form that the GetLaneIdInWave() definition on the
// following lines replaces (same body, new name) — confirm against the commit.
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % dpp_instr.wave_size; }
// Returns the calling thread's 1-D local id reduced modulo dpp_instr.wave_size, i.e. the
// thread's position inside its wave (a value in [0, wave_size) for non-negative ids).
__device__ static auto GetLaneIdInWave()
{
return get_thread_local_1d_id() % dpp_instr.wave_size;
}
// Returns the calling thread's 1-D local id divided by dpp_instr.wave_size, i.e. the index
// of the wave that the thread falls into (presumably integer division; the operand types
// are not visible here).
__device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
__device__ static auto GetLaneIdInLaneGroup()
{
......@@ -292,26 +297,25 @@ struct DppGemm
// Returns the id of the lane group the calling thread belongs to, obtained by dividing a
// thread/lane id by dpp_instr.lanegroup_size.
// NOTE(review): two return statements appear back to back because this page is a diff with
// the +/- markers stripped. The first (dividing the raw 1-D thread id) looks like the removed
// line and the second (dividing the lane id within the wave) like the added line — confirm
// against the commit. As literally written, only the first return would ever execute.
__device__ static auto GetLaneGroupIdInWave()
{
return get_thread_local_1d_id() / dpp_instr.lanegroup_size;
return GetLaneIdInWave() / dpp_instr.lanegroup_size;
}
// Maps the calling thread's lane-group id (GetLaneGroupIdInWave()) to a 2-D (m, n) index by
// running it through a single-stage tensor adaptor built from one merge transform, and
// returns the two components as a tuple.
// NOTE(review): this span is a rendered diff with the +/- markers stripped, so removed and
// added lines are interleaved (two function headers, two merge-transform openings, two
// Sequence<...> lines, two pairs of index reads). It does not compile as shown. The old/new
// labels below are inferred from the commit title ("remove unnecessary element") and from
// the GetDppOpIdx() call later on this page — confirm against the actual commit.
__device__ static auto GetDppIdx()
__device__ static auto GetDppOpIdx()
{
const auto lanegroupId = GetLaneGroupIdInWave();
constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
// Presumably the pre-change form: a three-way merge with a leading length-1 element.
make_merge_transform(make_tuple(1,
dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
// Presumably the post-change form: a two-way merge over m_per_wave / m_per_lanegroup and
// n_per_wave / n_per_lanegroup only.
make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
// Presumably pre-change dimension ids (three entries, matching the three-way merge).
make_tuple(Sequence<0, 1, 2>{}),
// Presumably post-change dimension ids (two entries, matching the two-way merge).
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
// Convert the 1-D lane-group id into the multi-component index produced by the adaptor.
const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
make_multi_index(lanegroupId));
// Presumably pre-change reads: components 1 and 2 of the three-component index.
const auto m_dpp_idx = dpp_idx[I1];
const auto n_dpp_idx = dpp_idx[I2];
// Presumably post-change reads: components 0 and 1 of the two-component index.
const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I1];
return make_tuple(m_dpp_idx, n_dpp_idx);
}
......@@ -333,7 +337,7 @@ struct DppGemm
__device__ static CIndex GetBeginOfThreadBlk()
{
const auto dpp_idx = GetDppIdx();
const auto dpp_idx = GetDppOpIdx();
const auto m_dpp_idx = dpp_idx[I0];
const auto n_dpp_idx = dpp_idx[I1];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment