Unverified Commit f35cb46c authored by HAI, committed by GitHub

ROCm: Fix MoE padding for non-FP8 cases (#2111)

parent 7f8fcd39
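
For context: on the old side of this diff, both invoke_fused_moe_kernel and fused_experts subtract the module-level padding_size from the weights' K dimension unconditionally, even though the weights are only padded when FP8 is in use. A minimal sketch of the guard this commit introduces, assuming padding_size is the module-level FP8 padding constant (the value and tensor shapes below are made up for illustration):

import torch

padding_size = 128  # assumed module-level FP8 padding constant
use_fp8 = False     # the non-FP8 case this commit fixes

padded_size = padding_size
if not use_fp8:
    # MOE_PADDING FP8 only: non-FP8 weights carry no K padding,
    # so nothing may be subtracted when recovering K.
    padded_size = 0

B = torch.empty(2, 8, 4096)  # non-FP8 expert weights, unpadded K
K = B.shape[2] - padded_size
assert K == 4096  # the old code would have computed 4096 - 128 = 3968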
@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    padded_size = padding_size
     if not use_fp8:
         assert A_scale is None
         assert B_scale is None
+        # MOE_PADDING FP8 only
+        padded_size = 0
     else:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
-    K = B.shape[2] - padding_size
+    K = B.shape[2] - padded_size
     if K % config["BLOCK_SIZE_K"] == 0:
         even_ks = True
     else:
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2] - padding_size,
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -480,8 +483,12 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8:
+        # MOE_PADDING FP8 only
+        padded_size = 0
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
...
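
The same guard is mirrored in fused_experts, whose hidden-size assertion otherwise fails for unpadded non-FP8 weights. A sketch of that failure mode under the same made-up dimensions:

import torch

padding_size = 128  # assumed FP8 padding constant
hidden_states = torch.empty(4, 4096)
w1 = torch.empty(2, 8, 4096)  # non-FP8 weights: K is unpadded

# Old check: padding subtracted unconditionally, so 4096 != 4096 - 128.
assert hidden_states.shape[1] != w1.shape[2] - padding_size

# New check: padded_size is zeroed on the non-FP8 path, so it passes.
padded_size = 0
assert hidden_states.shape[1] == w1.shape[2] - padded_size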