Commit 88412f9e authored by coderfeli
Browse files

impl sorting count eid

parent 4b91d1ce
......@@ -125,7 +125,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1 + max_output_ids / unit_size}, {1});
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
......@@ -205,7 +205,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{
moe_buf_dev.FromDevice(moe_buf_host.data());
}
sorted_expert_ids_host.savetxt("sorted_expert_ids_host.txt","int");
sorted_id_cnt_host.savetxt("sorted_id_cnt_host.txt","int");
bool rtn = true;
if(validate)
{
......
......@@ -573,11 +573,13 @@ struct MoeSortingKernel
{
int e_start = cumsum[tid];
int e_end = cumsum[tid + 1];
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1;
int e_size = unit_size_mdiv.div(e_end - e_start + unit_size_mdiv.divisor - 1);
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
p_sorted_expert_cnts[]
p_sorted_expert_cnts[unit_size_mdiv.div(i)] = e_size;
printf("tid %d size %d \n", tid, e_size);
}
}
......@@ -866,7 +868,6 @@ struct MoeSortingKernel
}
__syncthreads();
}
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
int e_start = smem_cumsum(i_e);
......@@ -894,10 +895,12 @@ struct MoeSortingKernel
if(local_expert_mask[i_e] == 0)
continue;
}
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1;
int e_size = unit_size_mdiv.div(e_end - e_start + unit_size_mdiv.divisor - 1);
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
p_sorted_expert_cnts[unit_size_mdiv.div(i)] = e_size;
}
}
smem_cumdup(num_experts) = smem_cumsum(num_experts);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment