Unverified Commit d4154c35 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Bugfix] fix moe marlin `topk_weight` loading (#18080)


Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent 6685890d
...@@ -473,15 +473,15 @@ __global__ void Marlin( ...@@ -473,15 +473,15 @@ __global__ void Marlin(
if (mul_topk_weights) { if (mul_topk_weights) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int idx = tid4 * 4 + i;
idx = idx < block_num_valid_tokens ? idx : 0;
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[tid4 * 4 + i] = __hmul2( sh_block_topk_weights[idx] = __hmul2(
global_scale, global_scale, Dtype::num2num2(Dtype::float2num(
Dtype::num2num2(Dtype::float2num( topk_weights_ptr[sh_block_sorted_ids[idx]])));
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
} else { } else {
sh_block_topk_weights[tid4 * 4 + i] = sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::num2num2(Dtype::float2num( Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment