Unverified commit f35cb46c, authored by HAI and committed by GitHub

ROCm: Fix MoE padding for non-FP8 cases (#2111)

parent 7f8fcd39
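For context: with ROCm MoE weight padding enabled, FP8 expert weights are stored with extra elements along the K dimension for alignment, and the kernels subtract `padding_size` to recover the true K. Before this commit the offset was subtracted unconditionally, so the non-FP8 path, whose weights are never padded, computed a K that was too small. The diff below introduces `padded_size`, which is zeroed when `use_fp8` is false. A minimal sketch of the corrected logic, assuming weights laid out as [num_experts, N, K] and an illustrative padding constant (the real `padding_size` is a module-level value, and all shapes here are made up):

import torch

padding_size = 128  # illustrative; the real constant is defined at module level

def effective_k(weights: torch.Tensor, use_fp8: bool) -> int:
    # FP8 expert weights are stored as [E, N, K + padding_size] on ROCm;
    # non-FP8 weights are unpadded, so no offset may be subtracted for them.
    padded_size = padding_size if use_fp8 else 0
    return weights.shape[2] - padded_size

w_bf16 = torch.empty(8, 256, 4096, dtype=torch.bfloat16)  # [E, N, K], unpadded non-FP8 weights
assert effective_k(w_bf16, use_fp8=False) == 4096  # correct after this fix
# Before the fix, padding_size was subtracted unconditionally, yielding
# 3968 and silently truncating every expert weight's K dimension.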
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
+    padded_size = padding_size
     if not use_fp8:
         assert A_scale is None
         assert B_scale is None
+        # MOE_PADDING FP8 only
+        padded_size = 0
     else:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
-    K = B.shape[2] - padding_size
+    K = B.shape[2] - padded_size
     if K % config["BLOCK_SIZE_K"] == 0:
         even_ks = True
     else:
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2] - padding_size,
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
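The corrected `padded_size` feeds both the `even_ks` fast-path check above and the K argument passed to the Triton kernel. A small sketch, with assumed values, of how the wrongly shrunken K behaved for non-FP8 weights before this change:

BLOCK_SIZE_K = 64          # assumed tile size; real values come from the tuned config
K_correct = 4096           # B.shape[2] - padded_size, with padded_size == 0 (non-FP8)
K_wrong = 4096 - 128       # what the old code computed by subtracting padding_size

even_ks = K_correct % BLOCK_SIZE_K == 0  # True: the kernel may skip K-boundary masks
# K_wrong also happens to divide evenly here, but the kernel would have been
# told K = 3968 and silently ignored the last 128 columns of every expert weight.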
@@ -480,8 +483,12 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8:
+        # MOE_PADDING FP8 only
+        padded_size = 0
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
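`fused_experts` now applies the same guard before its shape check. A brief sketch of the corrected assertion on the non-FP8 path, with illustrative shapes:

import torch

padded_size = 0  # non-FP8 path after this fix
hidden_states = torch.empty(16, 4096, dtype=torch.bfloat16)  # [num_tokens, hidden], illustrative
w1 = torch.empty(8, 256, 4096, dtype=torch.bfloat16)         # [experts, N, hidden], unpadded

# Passes now; before the fix the right-hand side was w1.shape[2] - padding_size,
# which never matched an unpadded bf16 checkpoint's hidden size.
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"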
@@ -498,7 +505,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
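Finally, the tuned-config lookup keys off the logical (unpadded) w2 shape, so padded FP8 weights and unpadded non-FP8 weights of the same model resolve to the same kernel configuration. A sketch with illustrative shapes:

w2_shape = (8, 4096, 256)  # [experts, hidden, N], illustrative, non-FP8
padded_size = 0            # nothing to strip on the non-FP8 path
lookup_shape = (w2_shape[0], w2_shape[1], w2_shape[2] - padded_size)
assert lookup_shape == (8, 4096, 256)  # matches the shape the configs were tuned for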