Unverified commit 629584bf, authored by xuebwang-amd, committed by GitHub
Browse files

[Kernel][MoE] fix computation order of MoE weight multiplication and improve flow (#31962)


Signed-off-by: xuebwang-amd <xuebwang@amd.com>
parent 0a7dd237
...@@ -531,22 +531,37 @@ def fused_moe_kernel( ...@@ -531,22 +531,37 @@ def fused_moe_kernel(
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
# Router weight multiplication MUST happen in float32 before precision # Dequantization for supported quantization schemes:
# conversion for numerical stability (especially critical on ROCm). # - int8_w8a16
if MUL_ROUTED_WEIGHT: # - fp8_w8a8
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) # - int8_w8a8
accumulator = accumulator * moe_weight[:, None] # Accumulator and scalings are in float32 to preserve numerical accuracy.
if use_int8_w8a16: if use_int8_w8a16:
accumulator = accumulator * b_scale accumulator = accumulator * b_scale
elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0): elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
accumulator = accumulator * a_scale * b_scale accumulator = accumulator * a_scale * b_scale
# Bias is added AFTER dequantization since bias is typically stored in # Bias addition:
# the output dtype and should not be scaled by quantization factors. # Bias must be applied after dequantization:
# - Since bias is typically not quantized
# - Bias should not be scaled by quantization factors
if HAS_BIAS: if HAS_BIAS:
accumulator = accumulator + bias[None, :] accumulator += bias[None, :]
# Router (MoE) weight multiplication:
# This multiplication MUST be performed in float32 before any precision
# conversion to ensure numerical stability, which is especially critical
# on ROCm platforms.
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(
topk_weights_ptr + offs_token,
mask=token_mask,
other=0,
)
accumulator *= moe_weight[:, None]
# Final precision conversion:
# Cast once at the end to the desired compute/output dtype.
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
# ----------------------------------------------------------- # -----------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment