"tests/python/common/test_heterograph-misc.py" did not exist on "a0390dde93aed5e81974ba6eab29bef393798836"
Commit ccd3fb94 authored by hlu1, committed by GitHub

[fix] Fix mxfp4 triton MoE tp bug (#9473)


Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
parent c9dd70fb
@@ -111,9 +111,8 @@ class FusedMoE(torch.nn.Module):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
-        renomalize: Whether to renormalize the logits in the fused_moe kernel
-        quant_config: Quantization configure.
+        reduce_results: Whether to apply all_reduce on the output of the layer
+        quant_config: Quantization configuration.
         inplace: suggestion to compute inplace (modify input activation).
         """
@@ -182,9 +181,6 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = torch.full(
             (self.num_experts,), -1, dtype=torch.int32, device="cpu"
         )
-        self.expert_map_cpu = torch.full(
-            (self.num_experts,), -1, dtype=torch.int32, device="cpu"
-        )
         # Create a expert map for the local experts
         self.expert_map_cpu[
             self.moe_ep_rank
......
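
The surviving `torch.full` allocation above initializes a global-to-local expert map for expert parallelism: -1 marks experts owned by other EP ranks. A self-contained sketch of the pattern; the slicing past the elision is an assumption based on the visible lines:

```python
import torch

def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    # -1 for experts hosted on other EP ranks, local expert index otherwise.
    num_local = num_experts // ep_size  # assumes experts divide evenly across ranks
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[ep_rank * num_local : (ep_rank + 1) * num_local] = torch.arange(
        num_local, dtype=torch.int32
    )
    return expert_map

# Example: 8 experts over 4 EP ranks; rank 1 owns global experts 2 and 3.
print(build_expert_map(8, 4, 1))  # tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)
```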
@@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size, 64
             )
+        elif has_triton_kernels:
+            # TODO: this is a hack to make
+            # intermediate_size_per_partition_after_pad the same as the
+            # per_rank_intermediate_size during weight loading
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size, mxfp4_block
+            )
         self.intermediate_size = intermediate_size_per_partition_after_pad
......
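
The new elif branch pads to the MXFP4 block size (32 elements per scaling block) rather than 64, so the padded width matches the per_rank_intermediate_size computed during weight loading. A sketch with hypothetical sizes; the round_up helper mirrors the usual implementation and is an assumption:

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple; the repo's own round_up is
    # assumed to behave the same way.
    return (x + multiple - 1) // multiple * multiple

mxfp4_block = 32  # MXFP4 groups 32 elements per scaling block

# Hypothetical per-rank shard: 2880 columns over 7 TP ranks ->
# ceil(90 / 7) = 13 mxfp4 blocks -> 416 columns.
per_rank_intermediate_size = 416

print(round_up(per_rank_intermediate_size, 64))           # 448: pads past the shard width
print(round_up(per_rank_intermediate_size, mxfp4_block))  # 416: matches weight loading
```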
@@ -793,12 +793,11 @@ class GptOssForCausalLM(nn.Module):
         assert (
             intermediate_size % mxfp4_block == 0
         ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        if _is_sm100_supported:
-            per_rank_intermediate_size_block = math.ceil(
-                intermediate_size_block / moe_tp_size
-            )
-        else:
-            per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
         # Calculate common slicing bounds for current rank
......
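
A worked example of the tp bug that the unified ceil fixes: when the block count is not a multiple of moe_tp_size, floor division silently drops trailing blocks of weights, while ceil division covers the full width (the last rank just carries padding). All sizes below are hypothetical:

```python
import math

mxfp4_block = 32
intermediate_size = 2880  # hypothetical; divisible by mxfp4_block as the assert requires
moe_tp_size = 7           # hypothetical TP degree that does not divide the block count

intermediate_size_block = intermediate_size // mxfp4_block  # 90 blocks

floor_blocks = intermediate_size_block // moe_tp_size           # 12 (old non-sm100 path)
ceil_blocks = math.ceil(intermediate_size_block / moe_tp_size)  # 13 (unified path)

print(floor_blocks * mxfp4_block * moe_tp_size)  # 2688 < 2880: six blocks of weights lost
print(ceil_blocks * mxfp4_block * moe_tp_size)   # 2912 >= 2880: full coverage with padding
```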