Unverified commit c4041f37, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][LoRA] Fix MoE accuracy regression by preserving float32 router weight scaling (#31931)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent a79079fe
......@@ -519,6 +519,12 @@ def fused_moe_kernel(
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
+# Router weight multiplication MUST happen in float32 before precision
+# conversion for numerical stability (especially critical on ROCm).
+if MUL_ROUTED_WEIGHT:
+moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8:
......@@ -529,12 +535,10 @@ def fused_moe_kernel(
else:
accumulator = accumulator.to(compute_type)
-# Since bias is typically not quantized, it's added after dequantization.
+# Bias is added AFTER dequantization since bias is typically stored in
+# the output dtype and should not be scaled by quantization factors.
 if HAS_BIAS:
 accumulator = accumulator + bias[None, :]
-if MUL_ROUTED_WEIGHT:
-moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-accumulator = accumulator * moe_weight[:, None]
# -----------------------------------------------------------
# Write back the block of the output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment