Unverified Commit 849957bc authored by Yineng Zhang, committed by GitHub
Browse files

fix: tmp revert gpt oss tp sharding on hopper (#9469)

parent cded039b
......@@ -793,9 +793,12 @@ class GptOssForCausalLM(nn.Module):
intermediate_size % mxfp4_block == 0
), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
intermediate_size_block = intermediate_size // mxfp4_block
- per_rank_intermediate_size_block = math.ceil(
- intermediate_size_block / moe_tp_size
- )
+ if _is_sm100_supported:
+ per_rank_intermediate_size_block = math.ceil(
+ intermediate_size_block / moe_tp_size
+ )
+ else:
+ per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
# Calculate common slicing bounds for current rank
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment