Unverified Commit 8e374781 authored by guangzlu's avatar guangzlu Committed by GitHub
Browse files

modified grouped gemm addressing method (#307)



* modified grouped gemm addressing method

* modified addressing method in device_grouped_gemm_xdl.hpp
Co-authored-by: default avatarroot <root@dc-smc-13.amd.com>
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 63fd5da6
......@@ -46,13 +46,22 @@ __global__ void
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0;
for(index_t i = 0; i < group_count; i++)
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right)
{
group_id =
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
? i
: group_id;
if(block_id < gemm_desc_ptr[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment