Commit 58db931e authored by coderfeli
Browse files

Fix topk id: decode it per token from the upper bits of each sorted token id instead of once per block

parent 84b27d75
......@@ -1170,7 +1170,6 @@ struct GridwiseMoeGemmGather
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
const index_t topk_id = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t token_offset = (token_pos + m0 < max_token_id) ?
......@@ -1463,8 +1462,11 @@ struct GridwiseMoeGemmGather
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// FIXME: this is too hacky and too specific to the topk weights
const float *p_sorted_weights = p_ds_grid[I0];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = ((p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.TopK + topk_id) * problem.N;
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
scatter_offsets(m0) = ((fused_token & 0xffffff) * problem.TopK + (fused_token >> 24)) * problem.N;
scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment