Unverified Commit 344a0017 authored by Alexander Matveev's avatar Alexander Matveev Committed by GitHub
Browse files

[Performance] Dual stream execution of "shared_experts" and "selected_experts"...


[Performance] Dual stream execution of "shared_experts" and "selected_experts" inside FusedMoE (#26440)
Signed-off-by: default avatarAlexander Matveev <amatveev@redhat.com>
parent becb7de4
...@@ -213,6 +213,7 @@ if TYPE_CHECKING: ...@@ -213,6 +213,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = "" VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1379,6 +1380,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1379,6 +1380,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects # top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
"VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -57,7 +57,7 @@ from vllm.platforms import current_platform ...@@ -57,7 +57,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
...@@ -1082,6 +1082,17 @@ class FusedMoE(CustomOp): ...@@ -1082,6 +1082,17 @@ class FusedMoE(CustomOp):
n_shared_experts: int | None = None, n_shared_experts: int | None = None,
): ):
super().__init__() super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
self.shared_experts_stream = torch.cuda.Stream()
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -1332,6 +1343,10 @@ class FusedMoE(CustomOp): ...@@ -1332,6 +1343,10 @@ class FusedMoE(CustomOp):
def shared_experts(self) -> torch.nn.Module | None: def shared_experts(self) -> torch.nn.Module | None:
return None return None
@property
def gate(self) -> torch.nn.Module | None:
return None
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
...@@ -1390,6 +1405,11 @@ class FusedMoE(CustomOp): ...@@ -1390,6 +1405,11 @@ class FusedMoE(CustomOp):
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) )
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return False
def update_expert_map(self): def update_expert_map(self):
# ep_size and ep_rank should already be updated # ep_size and ep_rank should already be updated
assert self.expert_map is not None assert self.expert_map is not None
...@@ -2168,6 +2188,7 @@ class FusedMoE(CustomOp): ...@@ -2168,6 +2188,7 @@ class FusedMoE(CustomOp):
self, self,
full_hidden_states: torch.Tensor, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor, full_router_logits: torch.Tensor,
has_separate_shared_experts: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None assert self.batched_router_logits is not None
...@@ -2216,11 +2237,23 @@ class FusedMoE(CustomOp): ...@@ -2216,11 +2237,23 @@ class FusedMoE(CustomOp):
# If there are shared experts but we are not using a modular kernel, # If there are shared experts but we are not using a modular kernel,
# the shared experts must be called here # the shared experts must be called here
if ( if has_separate_shared_experts:
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) assert self.shared_experts is not None
and self.shared_experts is not None
): if self.shared_experts_stream is not None:
# For chunked, we start the shared experts stream here
# (Note that no concurrency with the router/gate)
self.shared_experts_stream.wait_stream(current_stream())
with torch.cuda.stream(self.shared_experts_stream):
# Note that staged_hidden_states clone() is necessary
# here to avoid conflict with the main stream
shared_output = self.shared_experts(
staged_hidden_states.clone()
)
else:
shared_output = self.shared_experts(staged_hidden_states) shared_output = self.shared_experts(staged_hidden_states)
else: else:
shared_output = None shared_output = None
...@@ -2249,9 +2282,14 @@ class FusedMoE(CustomOp): ...@@ -2249,9 +2282,14 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
) )
if shared_output is not None: if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple) assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None assert self.shared_experts is not None
# Here we finish the shared experts stream
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = ( final_hidden_states = (
shared_output, shared_output,
final_hidden_states, final_hidden_states,
...@@ -2321,8 +2359,33 @@ class FusedMoE(CustomOp): ...@@ -2321,8 +2359,33 @@ class FusedMoE(CustomOp):
self.ensure_moe_quant_config() self.ensure_moe_quant_config()
if self.use_dp_chunking: has_separate_shared_experts = (
return self.forward_impl_chunked(hidden_states, router_logits) not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
)
use_chunked_impl = self.use_dp_chunking
if (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
):
# Start the separate shared experts stream here since we want
# to run in parallel with the router/gate (next op below)
self.shared_experts_stream.wait_stream(current_stream())
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
return self.forward_impl_chunked(
hidden_states, router_logits, has_separate_shared_experts
)
do_naive_dispatch_combine: bool = ( do_naive_dispatch_combine: bool = (
self.dp_size > 1 and not self.quant_method.using_modular_kernel self.dp_size > 1 and not self.quant_method.using_modular_kernel
...@@ -2330,10 +2393,16 @@ class FusedMoE(CustomOp): ...@@ -2330,10 +2393,16 @@ class FusedMoE(CustomOp):
# If there are shared experts but we are not using a modular kernel, the # If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here # shared experts must be called here
if ( if has_separate_shared_experts:
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) assert self.shared_experts is not None
and self.shared_experts is not None
): if self.shared_experts_stream is not None:
# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states.clone())
else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
else: else:
shared_output = None shared_output = None
...@@ -2377,9 +2446,14 @@ class FusedMoE(CustomOp): ...@@ -2377,9 +2446,14 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
) )
if shared_output is not None: if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple) assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None assert self.shared_experts is not None
# Wait for the parallel shared experts stream to finish here
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = ( final_hidden_states = (
shared_output, shared_output,
final_hidden_states, final_hidden_states,
......
...@@ -18,25 +18,40 @@ class SharedFusedMoE(FusedMoE): ...@@ -18,25 +18,40 @@ class SharedFusedMoE(FusedMoE):
def __init__( def __init__(
self, self,
shared_experts: torch.nn.Module | None, shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True, use_overlapped: bool = True,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self._shared_experts = shared_experts self._shared_experts = shared_experts
# Disable shared expert overlap if EP is disabled or we are not using # Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case. # flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts # Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile. # from being hidden from torch.compile.
self.use_overlapped = ( self.use_overlapped = (
use_overlapped use_overlapped
and not (self.use_ep or self.use_flashinfer_cutlass_kernels) and not (
self.use_ep
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None and self._shared_experts is not None
) )
self._gate = gate
@property @property
def shared_experts(self) -> torch.nn.Module | None: def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts if self.use_overlapped else None return self._shared_experts if self.use_overlapped else None
@property
def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None
@property
def is_internal_router(self) -> bool:
return self.gate is not None
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -227,6 +227,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -227,6 +227,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_experts, shared_experts=self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts, num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
...@@ -264,9 +265,14 @@ class DeepseekV2MoE(nn.Module): ...@@ -264,9 +265,14 @@ class DeepseekV2MoE(nn.Module):
if self.is_sequence_parallel: if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits hidden_states=hidden_states, router_logits=router_logits
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment