Unverified Commit ecc3dd66 authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Bugfix] Fix FusedMoE LoRA kernel offs_token out of bound value (#32279)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent 7e1f10d5
......@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
sorted_token_ids_ptr + token_ind,
mask=token_ind < max_loras * stride_tl,
other=num_valid_tokens,
)
token_mask = offs_token < num_valid_tokens
......@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment