Commit 0019ecdc authored by laibao's avatar laibao
Browse files

feat: Support shared expert overlap with expert.

parent 3ab9494d
......@@ -1939,6 +1939,25 @@ class ParallelConfig:
assert last_exc is not None
raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool:
......
......@@ -194,6 +194,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1265,7 +1266,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
),
}
# --8<-- [end:env-vars-definition]
......
......@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
......@@ -30,6 +31,7 @@ def get_config() -> Optional[dict[str, Any]]:
__all__ = [
"FusedMoE",
"SharedFusedMoE",
"FusedMoEConfig",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
......
......@@ -73,6 +73,26 @@ else:
logger = init_logger(__name__)
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
from vllm.platforms import current_platform
if _aux_stream is None and current_platform.is_cuda_alike():
_aux_stream = torch.cuda.Stream()
return _aux_stream
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
......@@ -686,6 +706,21 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = 1.0,
):
super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
......@@ -898,6 +933,19 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def shared_experts(self) -> torch.nn.Module | None:
return None
@property
def gate(self) -> torch.nn.Module | None:
return None
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return False
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
......@@ -1438,16 +1486,23 @@ class FusedMoE(torch.nn.Module):
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits)
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.shared_experts is None:
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output,
i_q, i_s)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output,
i_q, i_s)
if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......@@ -1524,13 +1579,58 @@ class FusedMoE(torch.nn.Module):
skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states
def _maybe_setup_shared_experts_stream(
self,
hidden_states: torch.Tensor,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
# current_platform.is_cuda()
True
and has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
# and (
# hidden_states.shape[0]
# <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
# )
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
# hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
# hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
return use_shared_experts_stream, hidden_states_clone
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_):
i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
use_shared_experts_stream, hidden_states_clone = self._maybe_setup_shared_experts_stream(hidden_states,
self.shared_experts is not None and self.shared_experts_stream is not None,
self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels)
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
......@@ -1592,24 +1692,48 @@ class FusedMoE(torch.nn.Module):
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
shared_output=None,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate,
)
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states
if use_shared_experts_stream:
assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
if envs.VLLM_ENABLE_TBO:
states = self.tbo_all_reduce(states)
else:
states = self.maybe_all_reduce_tensor_model_parallel(
states)
return states
if self.shared_experts is not None and not envs.USE_FUSED_RMS_QUANT:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(
......@@ -1694,3 +1818,34 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
out = self.forward_impl(hidden_states, router_logits)
return out
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class SharedFusedMoE(FusedMoE):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def __init__(
self,
shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
self._gate = gate
@property
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts
@property
def gate(self) -> torch.nn.Module | None:
return self._gate
@property
def is_internal_router(self) -> bool:
return self.gate is not None
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
)
# # ensure early TP reduction of shared expert outputs when required
# if (
# shared_out is not None
# and self.reduce_results
# and get_tensor_model_parallel_world_size() > 1
# and self.must_reduce_shared_expert_outputs()
# ):
# shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_out
\ No newline at end of file
......@@ -43,6 +43,7 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
......@@ -163,6 +164,7 @@ class DeepseekV2MoE(nn.Module):
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.enable_eplb = enable_eplb
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
......@@ -175,24 +177,6 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
......@@ -201,10 +185,51 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
reduce_results = self.is_sequence_parallel,
prefix=f"{prefix}.shared_experts",
)
self.enable_shared_experts_overlap = not (envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
or envs.USE_FUSED_RMS_QUANT
or envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
or config.n_shared_experts is None)
if self.enable_shared_experts_overlap:
self.experts = SharedFusedMoE(
shared_experts = self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
......@@ -215,39 +240,36 @@ class DeepseekV2MoE(nn.Module):
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.enable_shared_experts_overlap:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states)
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
if self.shared_experts is None:
assert shared_output is None
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
# if self.is_sequence_parallel:
# final_hidden_states = tensor_model_parallel_all_gather(
# final_hidden_states, 0
# )
# final_hidden_states = final_hidden_states[:num_tokens]
# elif self.tp_size > 1:
# final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
# final_hidden_states
# )
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
......@@ -256,59 +278,104 @@ class DeepseekV2MoE(nn.Module):
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
i_q, i_s = None, None
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else:
shared_output = self.shared_experts(hidden_states)
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
router_logits, _ = self.gate(hidden_states)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
i_q, i_s = None, None
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else:
if hidden_states.dtype != torch.float16:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
router_logits=router_logits,
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else:
return final_hidden_states.view(num_tokens, hidden_dim)
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment