Commit 0019ecdc authored by laibao's avatar laibao
Browse files

feat: Support shared expert overlap with expert.

parent 3ab9494d
...@@ -1939,6 +1939,25 @@ class ParallelConfig: ...@@ -1939,6 +1939,25 @@ class ParallelConfig:
assert last_exc is not None assert last_exc is not None
raise last_exc raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
@staticmethod @staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup", def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool: has_unfinished: bool) -> bool:
......
...@@ -194,6 +194,7 @@ if TYPE_CHECKING: ...@@ -194,6 +194,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1265,7 +1266,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1265,7 +1266,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE": "VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")), ("true", "1")),
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( ...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None _config: Optional[dict[str, Any]] = None
...@@ -30,6 +31,7 @@ def get_config() -> Optional[dict[str, Any]]: ...@@ -30,6 +31,7 @@ def get_config() -> Optional[dict[str, Any]]:
__all__ = [ __all__ = [
"FusedMoE", "FusedMoE",
"SharedFusedMoE",
"FusedMoEConfig", "FusedMoEConfig",
"FusedMoEMethodBase", "FusedMoEMethodBase",
"FusedMoeWeightScaleSupported", "FusedMoeWeightScaleSupported",
......
...@@ -73,6 +73,26 @@ else: ...@@ -73,6 +73,26 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
from vllm.platforms import current_platform
if _aux_stream is None and current_platform.is_cuda_alike():
_aux_stream = torch.cuda.Stream()
return _aux_stream
class FusedMoeWeightScaleSupported(Enum): class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
...@@ -686,6 +706,21 @@ class FusedMoE(torch.nn.Module): ...@@ -686,6 +706,21 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
): ):
super().__init__() super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -898,6 +933,19 @@ class FusedMoE(torch.nn.Module): ...@@ -898,6 +933,19 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property
def shared_experts(self) -> torch.nn.Module | None:
return None
@property
def gate(self) -> torch.nn.Module | None:
return None
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return False
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -1438,16 +1486,23 @@ class FusedMoE(torch.nn.Module): ...@@ -1438,16 +1486,23 @@ class FusedMoE(torch.nn.Module):
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
): ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# TODO: Once the OOM issue for the TPU backend is resolved, we will if self.shared_experts is None:
# switch to using the moe_forward custom op. # TODO: Once the OOM issue for the TPU backend is resolved, we will
if current_platform.is_tpu(): # switch to using the moe_forward custom op.
assert i_q is None and i_s is None, "moe.quant fused not support TPU now" if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output,
i_q, i_s)
else: else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, if current_platform.is_tpu():
self.layer_name, shared_output, assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
i_q, i_s) return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1524,13 +1579,58 @@ class FusedMoE(torch.nn.Module): ...@@ -1524,13 +1579,58 @@ class FusedMoE(torch.nn.Module):
skip_result_store=chunk_start_ >= num_tokens) skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states return full_final_hidden_states
def _maybe_setup_shared_experts_stream(
self,
hidden_states: torch.Tensor,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
# current_platform.is_cuda()
True
and has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
# and (
# hidden_states.shape[0]
# <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
# )
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
# hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
# hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
return use_shared_experts_stream, hidden_states_clone
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_): i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
use_shared_experts_stream, hidden_states_clone = self._maybe_setup_shared_experts_stream(hidden_states,
self.shared_experts is not None and self.shared_experts_stream is not None,
self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels)
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
...@@ -1592,24 +1692,48 @@ class FusedMoE(torch.nn.Module): ...@@ -1592,24 +1692,48 @@ class FusedMoE(torch.nn.Module):
expert_load_view=self.expert_load_view, expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map, logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
shared_output=shared_output, shared_output=None,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate, use_fused_gate=self.use_fused_gate,
) )
if use_shared_experts_stream:
if do_naive_dispatch_combine: assert self.shared_experts is not None
final_hidden_states = get_ep_group().combine(final_hidden_states) # Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # sync end point immediately after it is done. This is
# Default set to False. (May have to add shared expert outputs. # important to avoid excessive stream allocations by the cuda
if envs.VLLM_ENABLE_TBO: # graph replay later.
final_hidden_states = self.tbo_all_reduce(final_hidden_states) with torch.cuda.stream(self.shared_experts_stream):
else: # Note that hidden_states clone() is necessary here to avoid
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( # conflict with the main stream
final_hidden_states) shared_output = self.shared_experts(hidden_states)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
return final_hidden_states
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
if envs.VLLM_ENABLE_TBO:
states = self.tbo_all_reduce(states)
else:
states = self.maybe_all_reduce_tensor_model_parallel(
states)
return states
if self.shared_experts is not None and not envs.USE_FUSED_RMS_QUANT:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(
...@@ -1694,3 +1818,34 @@ direct_register_custom_op( ...@@ -1694,3 +1818,34 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
out = self.forward_impl(hidden_states, router_logits)
return out
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def __init__(
self,
shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self._gate = gate
@property
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts
@property
def gate(self) -> torch.nn.Module | None:
return self._gate
@property
def is_internal_router(self) -> bool:
return self.gate is not None
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
# # ensure early TP reduction of shared expert outputs when required
# if (
# shared_out is not None
# and self.reduce_results
# and get_tensor_model_parallel_world_size() > 1
# and self.must_reduce_shared_expert_outputs()
# ):
# shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out
\ No newline at end of file
...@@ -43,6 +43,7 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group, ...@@ -43,6 +43,7 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
...@@ -163,6 +164,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -163,6 +164,7 @@ class DeepseekV2MoE(nn.Module):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.enable_eplb = enable_eplb self.enable_eplb = enable_eplb
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.n_redundant_experts = parallel_config.num_redundant_experts self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts self.n_logical_experts = self.n_routed_experts
...@@ -175,24 +177,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -175,24 +177,6 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_end = (self.physical_expert_start + self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts) self.n_local_physical_experts)
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
...@@ -201,10 +185,51 @@ class DeepseekV2MoE(nn.Module): ...@@ -201,10 +185,51 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs( reduce_results = self.is_sequence_parallel,
),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
self.enable_shared_experts_overlap = not (envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
or envs.USE_FUSED_RMS_QUANT
or envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
or config.n_shared_experts is None)
if self.enable_shared_experts_overlap:
self.experts = SharedFusedMoE(
shared_experts = self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -215,39 +240,36 @@ class DeepseekV2MoE(nn.Module): ...@@ -215,39 +240,36 @@ class DeepseekV2MoE(nn.Module):
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if self.enable_shared_experts_overlap:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
if self.shared_experts is None:
assert shared_output is None
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: # Fix FP16 overflow
final_hidden_states = self.experts( # See DeepseekV2DecoderLayer for more details.
hidden_states=hidden_states, if hidden_states.dtype != torch.float16:
router_logits=router_logits, final_hidden_states *= self.routed_scaling_factor
shared_output=shared_output) elif self.shared_experts is not None:
else: assert shared_output is not None
if hidden_states.dtype != torch.float16: shared_output *= 1.0 / self.routed_scaling_factor
final_hidden_states = self.experts(
hidden_states=hidden_states, if self.shared_experts is not None:
router_logits=router_logits) * self.routed_scaling_factor assert shared_output is not None
else: final_hidden_states += shared_output
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # if self.is_sequence_parallel:
final_hidden_states = self.experts(hidden_states=hidden_states, # final_hidden_states = tensor_model_parallel_all_gather(
router_logits=router_logits) # final_hidden_states, 0
# )
if shared_output is not None: # final_hidden_states = final_hidden_states[:num_tokens]
if hidden_states.dtype != torch.float16: # elif self.tp_size > 1:
final_hidden_states = final_hidden_states + shared_output # final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
else: # final_hidden_states
# Fix FP16 overflow # )
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = self.tbo_all_reduce(final_hidden_states)
...@@ -256,59 +278,104 @@ class DeepseekV2MoE(nn.Module): ...@@ -256,59 +278,104 @@ class DeepseekV2MoE(nn.Module):
self.experts.maybe_all_reduce_tensor_model_parallel( self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)) final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) else:
i_q, i_s = None, None if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.n_shared_experts is not None: num_tokens, hidden_dim = hidden_states.shape
if envs.USE_FUSED_RMS_QUANT: hidden_states = hidden_states.view(-1, hidden_dim)
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else: else:
shared_output = self.shared_experts(hidden_states) if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
router_logits, _ = self.gate(hidden_states) if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
i_q, i_s = None, None
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor shared_output=shared_output,
i_q=i_q, i_s=i_s)
else: else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ # fp16 mode not fused quant
* (1. / self.routed_scaling_factor) final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if self.tp_size > 1: if envs.USE_FUSED_RMS_QUANT:
if envs.VLLM_ENABLE_TBO: return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else: else:
final_hidden_states = ( return final_hidden_states.view(num_tokens, hidden_dim)
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else:
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment