Unverified Commit e9b92dcd authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernels] Overlap shared experts with send/recv (#23273)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent fa4311d8
......@@ -654,7 +654,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
......
......@@ -491,7 +491,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
......@@ -1366,7 +1366,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
......@@ -305,7 +305,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
from torch.nn.parameter import Parameter
......@@ -554,7 +554,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
......@@ -226,7 +226,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
......@@ -390,7 +390,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
......
......@@ -3,7 +3,7 @@
# Copyright © 2025, Oracle and/or its affiliates.
import os
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
......@@ -291,7 +291,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
if enable_eplb:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import (
SharedFusedMoE)
__all__ = ["SharedFusedMoE"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def __init__(
self,
shared_experts: torch.nn.Module,
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return self._shared_experts if self.use_overlapped else None
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if (self.reduce_results and self.tp_size > 1
and self.must_reduce_shared_expert_outputs()):
shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out
......@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
......@@ -147,6 +148,7 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
if config.n_shared_experts is None:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
......@@ -165,45 +167,66 @@ class DeepseekV2MoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
if config.n_shared_experts is not None:
self.shared_experts = None
else:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
# we do scaling outside, set factor to 1.0 to avoid double mul
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
fused_moe_out = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
if self.shared_experts is not None:
shared_output, final_hidden_states = fused_moe_out
else:
shared_output = None
final_hidden_states = fused_moe_out
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= (1. / self.routed_scaling_factor)
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.tp_size > 1:
final_hidden_states = (
......
......@@ -184,6 +184,8 @@ class Glm4MoE(nn.Module):
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
router_logits = self.gate(hidden_states.to(dtype=torch.float32))
final_hidden_states = self.experts(
hidden_states=hidden_states,
......
......@@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
......@@ -73,7 +74,18 @@ class Llama4MoE(nn.Module):
quant_config=None,
prefix=f"{prefix}.router")
self.experts = FusedMoE(
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False,
)
self.experts = SharedFusedMoE(
shared_experts=self.shared_expert,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
......@@ -83,22 +95,13 @@ class Llama4MoE(nn.Module):
reduce_results=False,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=self.experts.must_reduce_shared_expert_outputs(),
prefix=f"{prefix}.experts",
)
def forward(self, hidden_states):
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
shared_out, routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
......
......@@ -500,7 +500,8 @@ class Worker(WorkerBase):
parallel_config = self.vllm_config.parallel_config
moe_modules = [
module for module in self.model_runner.model.modules()
if module.__class__.__name__ == "FusedMoE"
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
]
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(module.moe_config.num_local_experts == num_local_experts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment