Commit 88412f9e authored by coderfeli
Browse files

impl sorting count eid

parent 4b91d1ce
...@@ -125,7 +125,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -125,7 +125,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1}); ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1}); ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1}); ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1}); ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1 + max_output_ids / unit_size}, {1});
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size}); ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host); ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
...@@ -205,7 +205,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -205,7 +205,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{ {
moe_buf_dev.FromDevice(moe_buf_host.data()); moe_buf_dev.FromDevice(moe_buf_host.data());
} }
sorted_expert_ids_host.savetxt("sorted_expert_ids_host.txt","int");
sorted_id_cnt_host.savetxt("sorted_id_cnt_host.txt","int");
bool rtn = true; bool rtn = true;
if(validate) if(validate)
{ {
......
...@@ -573,11 +573,13 @@ struct MoeSortingKernel ...@@ -573,11 +573,13 @@ struct MoeSortingKernel
{ {
int e_start = cumsum[tid]; int e_start = cumsum[tid];
int e_end = cumsum[tid + 1]; int e_end = cumsum[tid + 1];
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1;
int e_size = unit_size_mdiv.div(e_end - e_start + unit_size_mdiv.divisor - 1); int e_size = unit_size_mdiv.div(e_end - e_start + unit_size_mdiv.divisor - 1);
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{ {
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
p_sorted_expert_cnts[] p_sorted_expert_cnts[unit_size_mdiv.div(i)] = e_size;
printf("tid %d size %d \n", tid, e_size);
} }
} }
...@@ -866,7 +868,6 @@ struct MoeSortingKernel ...@@ -866,7 +868,6 @@ struct MoeSortingKernel
} }
__syncthreads(); __syncthreads();
} }
for(int i_e = tid; i_e < num_experts; i_e += block_size) for(int i_e = tid; i_e < num_experts; i_e += block_size)
{ {
int e_start = smem_cumsum(i_e); int e_start = smem_cumsum(i_e);
...@@ -894,10 +895,12 @@ struct MoeSortingKernel ...@@ -894,10 +895,12 @@ struct MoeSortingKernel
if(local_expert_mask[i_e] == 0) if(local_expert_mask[i_e] == 0)
continue; continue;
} }
index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1;
int e_size = unit_size_mdiv.div(e_end - e_start + unit_size_mdiv.divisor - 1);
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
{ {
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
p_sorted_expert_cnts[unit_size_mdiv.div(i)] = e_size;
} }
} }
smem_cumdup(num_experts) = smem_cumsum(num_experts); smem_cumdup(num_experts) = smem_cumsum(num_experts);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment