Unverified Commit 0dd5dee9 authored by xuebwang-amd's avatar xuebwang-amd Committed by GitHub
Browse files

[Bugfix][Kernel] fix bias adding in triton kernel implemented fused moe (#31676)


Signed-off-by: default avatarxuebwang-amd <xuebwang@amd.com>
parent 4614c5a5
...@@ -518,11 +518,7 @@ def fused_moe_kernel( ...@@ -518,11 +518,7 @@ def fused_moe_kernel(
# Advance the ptrs to the next K block. # Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16: if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type) accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
...@@ -533,6 +529,13 @@ def fused_moe_kernel( ...@@ -533,6 +529,13 @@ def fused_moe_kernel(
else: else:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
# Since bias is typically not quantized, it's added after dequantization.
if HAS_BIAS:
accumulator = accumulator + bias[None, :]
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment