Unverified commit 1f3dbd95, authored by Jakub Zakrzewski, committed by GitHub
Browse files

[Bugfix][Model] Fix gpt-oss batch invariance (#35404)


Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
parent 1d532f9d
...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.base_config import (
) )
from vllm.model_executor.layers.utils import ( from vllm.model_executor.layers.utils import (
dispatch_unquantized_gemm, dispatch_unquantized_gemm,
is_layer_moe_router_gate,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
...@@ -257,11 +256,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -257,11 +256,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if ( if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
vllm_is_batch_invariant()
and current_platform.is_cuda_alike()
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
):
return linear_batch_invariant(x, layer.weight, bias) return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
...@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention ...@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
...@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module): ...@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=f"{prefix}.router",
return_bias=False,
)
assert config.intermediate_size % self.world_size == 0 assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE( self.experts = FusedMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment