"vscode:/vscode.git/clone" did not exist on "2ce90e5b01aafca68cc821ff778de3ca65c75439"
Commit ee19dca6 authored by wanglong3's avatar wanglong3
Browse files

feat: enable shared expert overlap.

parent ffc00331
...@@ -1939,6 +1939,24 @@ class ParallelConfig: ...@@ -1939,6 +1939,24 @@ class ParallelConfig:
assert last_exc is not None assert last_exc is not None
raise last_exc raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
    """Tell whether the expert inputs should be made sequence parallel.

    The result is truthy only when all of the following hold:
    the configured all2all backend is one of the backends listed below,
    expert parallelism is enabled, and both the tensor-parallel and the
    data-parallel sizes are greater than one.
    """
    sp_backends = (
        "allgather_reducescatter",
        "naive",
        "deepep_high_throughput",
        "deepep_low_latency",
    )
    backend_selected = envs.VLLM_ALL2ALL_BACKEND in sp_backends
    # Keep the remaining checks in one short-circuiting chain so they are
    # evaluated lazily and in the same order as before.
    return (backend_selected
            and self.enable_expert_parallel
            and self.tensor_parallel_size > 1
            and self.data_parallel_size > 1)
@staticmethod @staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup", def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool: has_unfinished: bool) -> bool:
......
...@@ -204,6 +204,7 @@ if TYPE_CHECKING: ...@@ -204,6 +204,7 @@ if TYPE_CHECKING:
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1306,6 +1307,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1306,6 +1307,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
# Only quantized DeepSeek models supported. # Only quantized DeepSeek models supported.
# Unquantized versions are not supported. # Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM": "VLLM_USE_FUSED_QA_KVA_GEMM":
...@@ -1318,6 +1320,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1318,6 +1320,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY": "VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")), ("true", "1")),
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -28,8 +28,8 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -28,8 +28,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig) FusedMoEConfig, FusedMoEParallelConfig)
# yapf: enable # yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel, FusedMoEActivationFormat, FusedMoEModularKernel,
DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled) # is_rocm_aiter_moe_enabled)
...@@ -74,6 +74,26 @@ else: ...@@ -74,6 +74,26 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
# Global auxiliary stream for running operations in background streams.
# We have a single global auxiliary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
    """Return the module-level auxiliary stream, creating it on first use.

    The stream is stored in the module global ``_aux_stream`` so that every
    caller shares the same object. It is only created when the current
    platform reports itself as CUDA-alike; otherwise ``None`` is returned.
    """
    global _aux_stream
    # Imported inside the function, as in the original implementation.
    from vllm.platforms import current_platform

    # Fast path: the stream was already created by an earlier call.
    if _aux_stream is not None:
        return _aux_stream

    if current_platform.is_cuda_alike():
        _aux_stream = torch.cuda.Stream()
    return _aux_stream
class FusedMoeWeightScaleSupported(Enum): class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
...@@ -170,7 +190,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -170,7 +190,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
== current_platform.fp8_dtype() == current_platform.fp8_dtype()
and moe.quant_config.block_shape and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE) == DEEPEP_QUANT_BLOCK_SHAPE)
use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8 use_int8_dispatch = moe.quant_config.quant_dtype == torch.int8
# Note (varun): Whether to use FP8 dispatch or not needs some # Note (varun): Whether to use FP8 dispatch or not needs some
...@@ -698,6 +718,21 @@ class FusedMoE(torch.nn.Module): ...@@ -698,6 +718,21 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
): ):
super().__init__() super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testing with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
...@@ -814,7 +849,7 @@ class FusedMoE(torch.nn.Module): ...@@ -814,7 +849,7 @@ class FusedMoE(torch.nn.Module):
# please refer to the implementation in `Fp8MoEMethod`. # please refer to the implementation in `Fp8MoEMethod`.
raise NotImplementedError("EPLB is only supported for FP8 " raise NotImplementedError("EPLB is only supported for FP8 "
"quantization for now.") "quantization for now.")
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1 self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
...@@ -909,9 +944,9 @@ class FusedMoE(torch.nn.Module): ...@@ -909,9 +944,9 @@ class FusedMoE(torch.nn.Module):
@property @property
def use_deepep_ll_kernels(self): def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels return self.moe_parallel_config.use_deepep_ll_kernels
@property @property
def shared_experts(self) -> Optional[torch.nn.Module]: def shared_experts(self) -> torch.nn.Module | None:
return None return None
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
...@@ -1451,6 +1486,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1451,6 +1486,7 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None, # for shared expert overlap
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
...@@ -1458,7 +1494,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1458,7 +1494,7 @@ class FusedMoE(torch.nn.Module):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now" assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
if self.shared_experts is None: if self.shared_experts is None:
...@@ -1467,7 +1503,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1467,7 +1503,7 @@ class FusedMoE(torch.nn.Module):
i_q, i_s) i_q, i_s)
else: else:
return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits, return torch.ops.vllm.moe_forward_shared(hidden_states, router_logits,
self.layer_name, shared_output) self.layer_name, hidden_states_copy, shared_output)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1547,10 +1583,22 @@ class FusedMoE(torch.nn.Module): ...@@ -1547,10 +1583,22 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_): i_s: Optional[torch.Tensor] = None, **_)-> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
enable_shared_experts_overlap = False
if (self.shared_experts_stream is not None
and hidden_states_copy is not None
and self.shared_experts is not None
and not self.moe_parallel_config.use_pplx_kernels):
enable_shared_experts_overlap = True
hidden_states_copy.record_stream(self.shared_experts_stream)
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
if (self.moe_parallel_config.use_pplx_kernels): if (self.moe_parallel_config.use_pplx_kernels):
#or self.moe_parallel_config.use_deepep_ll_kernels): #or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
...@@ -1619,18 +1667,45 @@ class FusedMoE(torch.nn.Module): ...@@ -1619,18 +1667,45 @@ class FusedMoE(torch.nn.Module):
use_fused_gate=self.use_fused_gate, use_fused_gate=self.use_fused_gate,
) )
if do_naive_dispatch_combine: if enable_shared_experts_overlap:
final_hidden_states = get_ep_group().combine(final_hidden_states) assert self.shared_experts is not None
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
assert hidden_states_copy is not None
shared_output = self.shared_experts(hidden_states_copy)
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): def combine_output(states: torch.Tensor) -> torch.Tensor:
# Default set to False. (May have to add shared expert outputs. if do_naive_dispatch_combine:
if envs.VLLM_ENABLE_TBO: states = get_ep_group().combine(states)
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return final_hidden_states if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
if envs.VLLM_ENABLE_TBO:
states = self.tbo_all_reduce(states)
else:
states = self.maybe_all_reduce_tensor_model_parallel(
states)
return states
if enable_shared_experts_overlap and not envs.USE_FUSED_RMS_QUANT:
return (
final_hidden_states[0],
combine_output(final_hidden_states[1]),
)
else:
return combine_output(final_hidden_states)
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(
...@@ -1686,7 +1761,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1686,7 +1761,7 @@ class FusedMoE(torch.nn.Module):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None, layer_name: str, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None, i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor: i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
...@@ -1697,7 +1772,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -1697,7 +1772,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s) return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s)
else: else:
return self.forward_impl(hidden_states, router_logits, shared_output) return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
...@@ -1720,18 +1795,20 @@ def moe_forward_shared( ...@@ -1720,18 +1795,20 @@ def moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits, shared_output) return self.forward_impl(hidden_states, router_logits, hidden_states_copy, shared_output)
def moe_forward_shared_fake( def moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
hidden_states_copy: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None shared_output: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
...@@ -1742,7 +1819,7 @@ def moe_forward_shared_fake( ...@@ -1742,7 +1819,7 @@ def moe_forward_shared_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="moe_forward_shared", op_name="moe_forward_shared",
op_func=moe_forward_shared, op_func=moe_forward_shared,
mutates_args=["hidden_states"], mutates_args=["hidden_states", "hidden_states_copy"],
fake_impl=moe_forward_shared_fake, fake_impl=moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -34,7 +34,8 @@ class SharedFusedMoE(FusedMoE): ...@@ -34,7 +34,8 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> torch.Tensor: hidden_states_copy: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped: if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states)
...@@ -53,6 +54,6 @@ class SharedFusedMoE(FusedMoE): ...@@ -53,6 +54,6 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
hidden_states_copy = hidden_states_copy,
) )
return fused_out return fused_out
...@@ -70,8 +70,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -70,8 +70,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
...@@ -114,7 +113,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -114,7 +113,7 @@ class DeepseekV2MLP(nn.Module):
else: else:
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x, new_resi, i_q, _scales return x, new_resi, i_q, _scales
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs) gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
...@@ -180,32 +179,15 @@ class DeepseekV2MoE(nn.Module): ...@@ -180,32 +179,15 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts) self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start + self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts) self.n_local_physical_experts)
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
self.enable_expert_parallel = parallel_config.enable_expert_parallel self.enable_expert_parallel = parallel_config.enable_expert_parallel
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \ self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_shared_experts_overlap = False
if not self.use_deepep: if not self.use_deepep:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts) config.n_shared_experts)
...@@ -214,10 +196,51 @@ class DeepseekV2MoE(nn.Module): ...@@ -214,10 +196,51 @@ class DeepseekV2MoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=self.experts.must_reduce_shared_expert_outputs( reduce_results = False,
),
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
self.enable_shared_experts_overlap = (not envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM
and not envs.USE_FUSED_RMS_QUANT
and not envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
and config.n_shared_experts is not None)
if self.enable_shared_experts_overlap:
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else:
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
routed_scaling_factor=self.routed_scaling_factor)
else: else:
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -249,6 +272,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -249,6 +272,8 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
shared_experts=self.shared_experts) shared_experts=self.shared_experts)
self.run_shared_expert_singlely = (self.n_shared_experts is not None and not self.enable_shared_experts_overlap)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
...@@ -261,10 +286,19 @@ class DeepseekV2MoE(nn.Module): ...@@ -261,10 +286,19 @@ class DeepseekV2MoE(nn.Module):
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
def shared_exprts_overlap_pass(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call ``self.experts`` with an extra cloned copy of the activations.

    A clone of ``hidden_states`` is passed as ``hidden_states_copy``
    alongside the original tensor and the router logits. The return value
    of ``self.experts`` is forwarded unchanged; callers unpack it as
    ``(shared_output, final_hidden_states)``.
    """
    # Build the keyword arguments first so the clone is created before the
    # experts module is looked up and invoked.
    call_kwargs = dict(
        hidden_states=hidden_states,
        router_logits=router_logits,
        hidden_states_copy=hidden_states.clone(),
    )
    return self.experts(**call_kwargs)
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
if self.n_shared_experts is not None: if self.n_shared_experts is not None and not self.enable_shared_experts_overlap:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs) shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
...@@ -273,76 +307,90 @@ class DeepseekV2MoE(nn.Module): ...@@ -273,76 +307,90 @@ class DeepseekV2MoE(nn.Module):
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output) shared_output=shared_output)
else: else:
if hidden_states.dtype != torch.float16: if self.enable_shared_experts_overlap:
final_hidden_states = self.experts( assert self.shared_experts is not None
hidden_states=hidden_states, shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = self.experts(hidden_states=hidden_states,
* (1. / self.routed_scaling_factor) router_logits=router_logits)
if self.tp_size > 1: if shared_output is not None:
if envs.VLLM_ENABLE_TBO: if hidden_states.dtype != torch.float16:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = final_hidden_states + shared_output
else: else:
final_hidden_states = ( # Fix FP16 overflow
self.experts.maybe_all_reduce_tensor_model_parallel( # See DeepseekV2DecoderLayer for more details.
final_hidden_states)) final_hidden_states = final_hidden_states + shared_output \
return final_hidden_states.view(num_tokens, hidden_dim) * (1. / self.routed_scaling_factor)
else: else:
if not self.enable_expert_parallel: if not self.enable_expert_parallel:
i_q, i_s = None, None i_q, i_s = None, None
if self.n_shared_experts is not None: if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if self.enable_shared_experts_overlap:
final_hidden_states = self.experts( assert self.shared_experts is not None
hidden_states=hidden_states, shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
router_logits=router_logits, # Fix FP16 overflow
shared_output=shared_output, # See DeepseekV2DecoderLayer for more details.
i_q=i_q, i_s=i_s)
else:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states *= self.routed_scaling_factor
final_hidden_states += shared_output
else:
assert shared_output is not None
final_hidden_states += (shared_output * (1.0 / self.routed_scaling_factor))
else:
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor shared_output=shared_output,
i_q=i_q, i_s=i_s)
else: else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ # fp16 mode not fused quant
* (1. / self.routed_scaling_factor) final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else: else:
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if self.use_deepep: if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states, shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if shared_output is not None: if shared_output is not None:
...@@ -354,37 +402,48 @@ class DeepseekV2MoE(nn.Module): ...@@ -354,37 +402,48 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else: else:
if self.n_shared_experts is not None: if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
final_hidden_states = self.experts( if self.enable_shared_experts_overlap:
hidden_states=hidden_states, assert self.shared_experts is not None
router_logits=router_logits) shared_output, final_hidden_states = shared_exprts_overlap_pass(hidden_states, router_logits)
# Fix FP16 overflow
if shared_output is not None: # See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states += shared_output
else: else:
# Fix FP16 overflow assert shared_output is not None
# See DeepseekV2DecoderLayer for more details. final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
final_hidden_states = final_hidden_states + shared_output \ else:
* (1. / self.routed_scaling_factor) final_hidden_states = self.experts(
hidden_states=hidden_states,
if self.tp_size > 1: router_logits=router_logits)
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) if shared_output is not None:
else: if hidden_states.dtype != torch.float16:
final_hidden_states = ( final_hidden_states = final_hidden_states + shared_output
self.experts.maybe_all_reduce_tensor_model_parallel( else:
final_hidden_states)) # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if envs.USE_FUSED_RMS_QUANT: final_hidden_states = final_hidden_states + shared_output \
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s * (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else: else:
return final_hidden_states.view(num_tokens, hidden_dim) final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else:
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...@@ -546,7 +605,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -546,7 +605,7 @@ class DeepseekV2MLAAttention(nn.Module):
""" """
Main reference: DeepseekV2 paper, and FlashInfer Implementation Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
""" """
...@@ -623,7 +682,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -623,7 +682,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj") prefix=f"{prefix}.q_b_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank, self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -735,7 +794,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -735,7 +794,7 @@ class DeepseekV2MLAAttention(nn.Module):
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0] kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0]
kv_c, k_pe = kvc_kpe.split( kv_c, k_pe = kvc_kpe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
...@@ -763,7 +822,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -763,7 +822,7 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
...@@ -788,7 +847,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -788,7 +847,7 @@ class DeepseekV2MLAAttention(nn.Module):
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else: else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
...@@ -811,7 +870,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -811,7 +870,7 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
...@@ -823,7 +882,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -823,7 +882,7 @@ class DeepseekV2MLAAttention(nn.Module):
positions=positions, positions=positions,
weight=weight, weight=weight,
cos_sin_cache=cos_sin_cache) cos_sin_cache=cos_sin_cache)
packages_ = self.o_proj(attn_out, packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight, pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual, pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps, pa_rms_eps=pa_rms_eps,
...@@ -870,7 +929,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -870,7 +929,7 @@ class DeepseekV2MLAAttention(nn.Module):
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
...@@ -975,7 +1034,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -975,7 +1034,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
def forward_fused_rmsquant( def forward_fused_rmsquant(
self, self,
...@@ -985,7 +1044,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -985,7 +1044,7 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Fix residual FP16 overflow # Fix residual FP16 overflow
residual_fix_overflow = False residual_fix_overflow = False
assert self.input_layernorm.has_weight is True assert self.input_layernorm.has_weight is True
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -1004,7 +1063,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1004,7 +1063,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual = residual residual = residual
) )
residual = new_residual residual = new_residual
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# rmsnorm, and rmsnorm result would not affect by scale. # rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
...@@ -1013,8 +1072,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1013,8 +1072,8 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer. # first layer.
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states, hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states,
rms_weight=self.post_attention_layernorm.weight.data, rms_weight=self.post_attention_layernorm.weight.data,
residual=residual, residual=residual,
) )
...@@ -1029,9 +1088,9 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1029,9 +1088,9 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, new_resi return hidden_states, new_resi
def forward_fused_CRQ( def forward_fused_CRQ(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
residual_fix_overflow = False residual_fix_overflow = False
...@@ -1042,33 +1101,33 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1042,33 +1101,33 @@ class DeepseekV2DecoderLayer(nn.Module):
else: else:
hidden_states, resi_new = self.input_layernorm( hidden_states, resi_new = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
residual = resi_new residual = resi_new
new_hs, new_resi, xq, xs = self.self_attn( new_hs, new_resi, xq, xs = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
pa_rms_weight=self.post_attention_layernorm.weight.data, pa_rms_weight=self.post_attention_layernorm.weight.data,
pa_residual=residual, pa_residual=residual,
pa_rms_eps=self.post_attention_layernorm.variance_epsilon, pa_rms_eps=self.post_attention_layernorm.variance_epsilon,
pa_quant_dtype = torch.int8, pa_quant_dtype = torch.int8,
update_input=True update_input=True
) )
assert xq is not None and xs is not None assert xq is not None and xs is not None
if new_hs.dtype == torch.float16: # overflow处理逻辑 if new_hs.dtype == torch.float16: # overflow处理逻辑
new_hs *= 1. / self.routed_scaling_factor new_hs *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow: if self.layer_idx == 0 or residual_fix_overflow:
new_resi *= 1. / self.routed_scaling_factor new_resi *= 1. / self.routed_scaling_factor
hidden_states = self.mlp(new_hs, xqxs=(xq, xs)) hidden_states = self.mlp(new_hs, xqxs=(xq, xs))
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi return hidden_states, new_resi
def forward_default( def forward_default(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor] residual: Optional[torch.Tensor]
...@@ -1083,7 +1142,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1083,7 +1142,7 @@ class DeepseekV2DecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
if not self.is_mtp_layer: if not self.is_mtp_layer:
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \ DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
...@@ -1117,7 +1176,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1117,7 +1176,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual) hidden_states, residual)
if self.is_mtp_layer: if self.is_mtp_layer:
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1: DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
...@@ -1147,7 +1206,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1147,7 +1206,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual return hidden_states, residual
def choose_forward(self): def choose_forward(self):
if self.use_fused_rms_quant: if self.use_fused_rms_quant:
return self.forward_fused_rmsquant return self.forward_fused_rmsquant
...@@ -1212,7 +1271,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1212,7 +1271,7 @@ class DeepseekV2Model(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
self.dp_size = get_dp_group().world_size self.dp_size = get_dp_group().world_size
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
...@@ -1312,10 +1371,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1312,10 +1371,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.num_routed_experts = example_moe.n_routed_experts self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method self.tritonsingleton.quant_method=self.quant_method
...@@ -1371,22 +1430,22 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1371,22 +1430,22 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
dtype=dtype, dtype=dtype,
device=device), device=device),
}) })
def restore_qzeros_tensor(self, qzeros, qscales): def restore_qzeros_tensor(self, qzeros, qscales):
low_bits = qzeros & 0x0F low_bits = qzeros & 0x0F
high_bits = qzeros >> 4 high_bits = qzeros >> 4
zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1]) zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1])
zeors_int16 = zeors_tensor.to(torch.int16) zeors_int16 = zeors_tensor.to(torch.int16)
assert zeors_int16.shape == qscales.shape assert zeors_int16.shape == qscales.shape
uint16_tensor1 = zeors_int16.view(torch.uint16) uint16_tensor1 = zeors_int16.view(torch.uint16)
uint16_tensor2 = qscales.view(torch.uint16) uint16_tensor2 = qscales.view(torch.uint16)
uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16 uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16
uint32_tensor2 = uint16_tensor2.to(torch.int32) uint32_tensor2 = uint16_tensor2.to(torch.int32)
result_tensor = uint32_tensor1 + uint32_tensor2 result_tensor = uint32_tensor1 + uint32_tensor2
result_tensor =result_tensor.view(torch.uint32) result_tensor =result_tensor.view(torch.uint32)
result_tensor = result_tensor.transpose(1, 2).contiguous() result_tensor = result_tensor.transpose(1, 2).contiguous()
...@@ -1494,7 +1553,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1494,7 +1553,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank # However it's not mapped locally to this rank
# So we simply skip it # So we simply skip it
continue continue
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -1515,7 +1574,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1515,7 +1574,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attn.q_proj.weight", "self_attn.q_proj.weight",
...@@ -1533,19 +1592,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1533,19 +1592,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername in loaded_params: for layername in loaded_params:
weight = params_dict[layername] weight = params_dict[layername]
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
_weight = torch.zeros_like(weight.data) _weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment