Unverified commit dae9a80f authored by hlu1, committed by GitHub

[fix] Fix mxfp4 weight loading bug with TP sharding in GPT-OSS (#9433)


Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent e85cb1ce
@@ -737,8 +737,8 @@ class ResponsesRequest(BaseModel):
         else:
             max_tokens = default_max_tokens
-        # Avoid exceed the context length by minus 1 token
-        max_tokens -= 1
+        # Avoid exceed the context length by minus 2 token
+        max_tokens -= 2
         # Get parameters with defaults
         temperature = self.temperature
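The hunk above widens the safety margin used when clamping `max_tokens` against the context window from one token to two. A minimal, standalone sketch of that clamping logic follows; the names `context_len` and `prompt_len` are illustrative and not necessarily the request class's actual fields:

```python
from typing import Optional

def clamp_max_tokens(requested: Optional[int], context_len: int, prompt_len: int) -> int:
    """Illustrative clamp: cap generation so it fits in the remaining context."""
    default_max_tokens = context_len - prompt_len
    if requested is not None:
        max_tokens = min(requested, default_max_tokens)
    else:
        max_tokens = default_max_tokens
    # Reserve two tokens (the change in this commit) so generation
    # cannot overrun the context length.
    max_tokens -= 2
    return max(max_tokens, 0)
```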
@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""
 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -788,18 +789,25 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()
         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size
         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )
         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
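The substance of the fix is in the hunk above: the per-rank slice of the MXFP4-blocked intermediate dimension is now rounded up with `math.ceil` instead of down with floor division, so when the block count is not divisible by the MoE TP size the trailing blocks are still assigned to the last rank (whose end index is clamped to `intermediate_size`) rather than silently dropped. A standalone sketch of the slicing arithmetic, using hypothetical example sizes rather than the actual loader code:

```python
import math

def tp_slice_bounds(intermediate_size: int, mxfp4_block: int,
                    moe_tp_size: int, moe_tp_rank: int) -> tuple:
    """Compute the [start, end) slice of the intermediate dim owned by one TP rank.

    Mirrors the arithmetic in the diff above: slices are aligned to mxfp4_block,
    rounded up per rank, and the last rank is clamped to intermediate_size.
    """
    assert intermediate_size % mxfp4_block == 0
    intermediate_size_block = intermediate_size // mxfp4_block
    # math.ceil (the fix) instead of floor division: with floor, a block count
    # that is not divisible by moe_tp_size leaves trailing blocks unassigned.
    per_rank_blocks = math.ceil(intermediate_size_block / moe_tp_size)
    per_rank_size = per_rank_blocks * mxfp4_block
    start = moe_tp_rank * per_rank_size
    end = min((moe_tp_rank + 1) * per_rank_size, intermediate_size)
    return start, end

# Hypothetical example: 90 blocks of 32 sharded across TP=4.
# Floor division gives 22 blocks per rank and covers only 88 of 90 blocks;
# ceil gives 23 blocks per rank, with the last rank clamped to the true size.
if __name__ == "__main__":
    for rank in range(4):
        print(rank, tp_slice_bounds(90 * 32, 32, 4, rank))
```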