Commit b5e35626 authored by coderfeli
Browse files

fix typo

parent d12c3c6f
...@@ -187,7 +187,7 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -187,7 +187,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
const IndexDataType expert_first_token = __builtin_amdgcn_readfirstlane( const IndexDataType expert_first_token = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr)[sorted_tile_id * 32]); reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr)[sorted_tile_id * BlockShape::Block_M0]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
...@@ -209,9 +209,10 @@ struct FusedMoeGemmPipeline_FlatmmUk ...@@ -209,9 +209,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
}, },
number<row_ids_a.size()>{}); number<row_ids_a.size()>{});
if (expert_first_token&0xffffff >= kargs.num_tokens) if ((expert_first_token&0xffffff) >= kargs.num_tokens)
return; return;
// printf("tid %d %d, first %d\n", blockIdx.x, threadIdx.x,expert_first_token&0xffffff); // if (threadIdx.x %32==0)
// printf("block %d %d thread %d, expert %d firstt %x %d, sorted_tile_id %d\n", blockIdx.x,blockIdx.y, threadIdx.x,expert_id, expert_first_token,(expert_first_token&0xffffff), sorted_tile_id);
// for (int i = 0; i < row_ids_a.size(); i++) { // for (int i = 0; i < row_ids_a.size(); i++) {
// printf("%d bid %d tid %d rowid %d\n", i, blockIdx.x, threadIdx.x, row_ids_a[i]); // printf("%d bid %d tid %d rowid %d\n", i, blockIdx.x, threadIdx.x, row_ids_a[i]);
// } // }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment