Unverified Commit 1f3dbd95 authored by Jakub Zakrzewski's avatar Jakub Zakrzewski Committed by GitHub
Browse files

[Bugfix][Model] Fix gpt-oss batch invariance (#35404)


Signed-off-by: default avatarJakub Zakrzewski <jzakrzewski@nvidia.com>
parent 1d532f9d
......@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.utils import (
dispatch_unquantized_gemm,
is_layer_moe_router_gate,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
......@@ -257,11 +256,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if (
vllm_is_batch_invariant()
and current_platform.is_cuda_alike()
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
):
if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
......@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
......@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module):
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=f"{prefix}.router",
return_bias=False,
)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(
num_experts=config.num_local_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment