Unverified commit d71af5f5, authored by Wentao Ye, committed by GitHub
Browse files

[Feature] Enable TP + EP `shared_experts` overlap with router, 3.7% E2E...


[Feature] Enable TP + EP `shared_experts` overlap with router, 3.7% E2E performance improvement (#28164)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 90189c71
...@@ -1178,7 +1178,7 @@ class FusedMoE(CustomOp): ...@@ -1178,7 +1178,7 @@ class FusedMoE(CustomOp):
hidden_size: Input hidden state size of the transformer hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer reduce_results: Whether to all_reduce on the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure. quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer. enable_eplb: Whether to enable expert parallelism load balancer.
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
import torch import torch
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
...@@ -25,16 +28,13 @@ class SharedFusedMoE(FusedMoE): ...@@ -25,16 +28,13 @@ class SharedFusedMoE(FusedMoE):
super().__init__(**kwargs) super().__init__(**kwargs)
self._shared_experts = shared_experts self._shared_experts = shared_experts
# Disable shared expert overlap if EP is disabled or we are not using # Disable shared expert overlap if we are not using
# flashinfer + DP since there is nothing to be gained in this case. # flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts # Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile. # from being hidden from torch.compile.
self.use_overlapped = ( self.use_overlapped = (
use_overlapped use_overlapped
and not ( and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
self.use_ep
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None and self._shared_experts is not None
) )
...@@ -65,7 +65,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -65,7 +65,7 @@ class SharedFusedMoE(FusedMoE):
# should have been created with reduce_results=False. # should have been created with reduce_results=False.
if ( if (
self.reduce_results self.reduce_results
and self.tp_size > 1 and get_tensor_model_parallel_world_size() > 1
and self.must_reduce_shared_expert_outputs() and self.must_reduce_shared_expert_outputs()
): ):
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
...@@ -81,4 +81,12 @@ class SharedFusedMoE(FusedMoE): ...@@ -81,4 +81,12 @@ class SharedFusedMoE(FusedMoE):
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
# ensure early TP reduction of shared expert outputs when required
if (
shared_out is not None
and self.reduce_results
and get_tensor_model_parallel_world_size() > 1
and self.must_reduce_shared_expert_outputs()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out return shared_out, fused_out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment