Commit fc5eb9e1 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'dev_092_shared_expert_overlap' into 'v0.9.2-dev'

feat: enable shared expert overlap.

See merge request dcutoolkit/deeplearing/vllm!339
parents ffc00331 ee19dca6
......@@ -1939,6 +1939,24 @@ class ParallelConfig:
assert last_exc is not None
raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool:
......
......@@ -204,6 +204,7 @@ if TYPE_CHECKING:
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
def get_default_cache_root():
return os.getenv(
......@@ -1306,6 +1307,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
# Only quantized DeepSeek models supported.
# Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM":
......@@ -1318,6 +1320,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
}
# --8<-- [end:env-vars-definition]
......
......@@ -74,6 +74,26 @@ else:
logger = init_logger(__name__)
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
from vllm.platforms import current_platform
if _aux_stream is None and current_platform.is_cuda_alike():
_aux_stream = torch.cuda.Stream()
return _aux_stream
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
......@@ -698,6 +718,21 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = 1.0,
):
super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
......@@ -911,7 +946,7 @@ class FusedMoE(torch.nn.Module):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
def shared_experts(self) -> torch.nn.Module | None:
return None
def _load_per_tensor_weight_scale(self, shard_id: str,
......@@ -1451,6 +1486,7 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None, # for shared expert overlap
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
......@@ -1467,7 +1503,7 @@ class FusedMoE(torch.nn.Module):
i_q, i_s)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits,
self.layer_name, shared_output)
self.layer_name, hidden_states_copy, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......@@ -1547,10 +1583,22 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_):
i_s: Optional[torch.Tensor] = None, **_)-> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
enable_shared_experts_overlap = False
if (self.shared_experts_stream is not None
and hidden_states_copy is not None
and self.shared_experts is not None
and not self.moe_parallel_config.use_pplx_kernels):
enable_shared_experts_overlap = True
hidden_states_copy.record_stream(self.shared_experts_stream)
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
if (self.moe_parallel_config.use_pplx_kernels):
#or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
......@@ -1619,18 +1667,45 @@ class FusedMoE(torch.nn.Module):
use_fused_gate=self.use_fused_gate,
)
if enable_shared_experts_overlap:
assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
assert hidden_states_copy is not None
shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
states = self.tbo_all_reduce(states)
else:
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states
states = self.maybe_all_reduce_tensor_model_parallel(
states)
return states
if enable_shared_experts_overlap and not envs.USE_FUSED_RMS_QUANT:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(
......@@ -1720,18 +1795,20 @@ def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, shared_output)
return self.forward_impl(hidden_states, router_logits, hidden_states_copy, shared_output)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
......@@ -1742,7 +1819,7 @@ def moe_forward_shared_fake(
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
mutates_args=["hidden_states", "hidden_states_copy"],
fake_impl=moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
......@@ -34,7 +34,8 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
hidden_states_copy: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
......@@ -53,6 +54,6 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy,
)
return fused_out
......@@ -71,7 +71,6 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
class DeepseekV2MLP(nn.Module):
def __init__(
......@@ -187,8 +186,27 @@ class DeepseekV2MoE(nn.Module):
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_shared_experts_overlap = False
if not self.use_deepep:
self.experts = FusedMoE(
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results = False,
prefix=f"{prefix}.shared_experts",
)
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.USE_FUSED_RMS_QUANT
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None)
if self.enable_shared_experts_overlap:
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
......@@ -205,19 +223,24 @@ class DeepseekV2MoE(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekV2MLP(
else:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
)
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -249,6 +272,8 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts)
self.run_shared_expert_singlely = (self.n_shared_experts is not None and not self.enable_shared_experts_overlap)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
......@@ -261,8 +286,17 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
def shared_exprts_overlap_pass(
hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states_copy = hidden_states.clone()
return self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
hidden_states_copy = hidden_states_copy)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.n_shared_experts is not None:
if self.n_shared_experts is not None and not self.enable_shared_experts_overlap:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states)
......@@ -272,6 +306,18 @@ class DeepseekV2MoE(nn.Module):
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else:
if self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
......@@ -291,19 +337,10 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
if not self.enable_expert_parallel:
i_q, i_s = None, None
if self.n_shared_experts is not None:
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
......@@ -311,6 +348,18 @@ class DeepseekV2MoE(nn.Module):
router_logits, _ = self.gate(hidden_states)
if self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
......@@ -340,7 +389,6 @@ class DeepseekV2MoE(nn.Module):
* (1. / self.routed_scaling_factor)
else:
router_logits, _ = self.gate(hidden_states)
if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
......@@ -354,12 +402,23 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
if self.n_shared_experts is not None:
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
if self.enable_shared_experts_overlap:
assert self.shared_experts is not None
shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment