Commit cca00f5c authored by zhangqha
Browse files

Merge branch 'v0.15.1-dev_RQ_SiluMulQuant' into 'v0.15.1-dev'

perf:Deepseek v2模型增加rmsQuant和siluMulQuant融合

See merge request dcutoolkit/deeplearing/vllm!468
parents 7826240b 58a36508
...@@ -1808,7 +1808,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1808,7 +1808,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use silu_mul_quant fused op # vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "0"))), lambda: (os.getenv("USE_FUSED_SILU_MUL_QUANT", "False").lower() in
("true", "1")),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import Literal, cast, get_args, overload from typing import Literal, cast, get_args, overload, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp): ...@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1] og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states: if self.hidden_size != og_hidden_states:
...@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp): ...@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp):
) )
else: else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared( shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name() hidden_states, router_logits, encode_layer_name(),
i_q=i_q,
i_s=i_s
) )
return ( return (
reduce_output(shared_output)[..., :og_hidden_states], reduce_output(shared_output)[..., :og_hidden_states],
...@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp): ...@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits) return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def forward_impl_chunked( def forward_impl_chunked(
self, self,
...@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp): ...@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
...@@ -2004,13 +2012,25 @@ class FusedMoE(CustomOp): ...@@ -2004,13 +2012,25 @@ class FusedMoE(CustomOp):
if self.capture is not None: if self.capture is not None:
self.capture(topk_ids) self.capture(topk_ids)
final_hidden_states = self.quant_method.apply( if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signature of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe,
i_q=i_q,
i_s=i_s
)
else:
final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signature of this is wrong due to the hack. x=x, # The type signature of this is wrong due to the hack.
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe
) )
if has_separate_shared_experts: if has_separate_shared_experts:
...@@ -2133,16 +2153,20 @@ def moe_forward( ...@@ -2133,16 +2153,20 @@ def moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is None assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def moe_forward_fake( def moe_forward_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -2160,16 +2184,23 @@ def moe_forward_shared( ...@@ -2160,16 +2184,23 @@ def moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is not None assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits) if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake( def moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
......
...@@ -60,10 +60,13 @@ class SharedFusedMoE(FusedMoE): ...@@ -60,10 +60,13 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None)
# Reduce shared expert outputs if necessary, since the MLP # Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False. # should have been created with reduce_results=False.
...@@ -79,11 +82,15 @@ class SharedFusedMoE(FusedMoE): ...@@ -79,11 +82,15 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
) )
else: else:
shared_out, fused_out = super().forward( shared_out, fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
) )
# ensure early TP reduction of shared expert outputs when required # ensure early TP reduction of shared expert outputs when required
if ( if (
......
...@@ -370,7 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -370,7 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor, x: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
layer=layer, layer=layer,
......
...@@ -711,6 +711,31 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -711,6 +711,31 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear. will be treated as a "Replicated" MergedLinear.
""" """
def forward(
    self,
    input_,
    *,
    iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    """Apply the merged column-parallel projection to ``input_``.

    Args:
        input_: Activation handed to ``self.quant_method.apply``.
        iqis: Optional pair forwarded to ``quant_method.apply`` as
            ``input_quant_args``. It is only used when
            ``envs.USE_FUSED_RMS_QUANT`` is enabled; otherwise it is ignored.
            Presumably (quantized input, scales) -- confirm with the producer.

    Returns:
        The projection output alone when ``self.return_bias`` is false,
        otherwise ``(output, bias)`` where ``bias`` is ``self.bias`` only if
        ``self.skip_bias_add`` is set (and ``None`` if it is not).
    """
    # The bias is applied inside the matmul unless the caller defers it.
    fused_bias = None if self.skip_bias_add else self.bias

    # Matrix multiply.
    assert self.quant_method is not None
    if envs.USE_FUSED_RMS_QUANT and iqis is not None:
        # Hand the caller-provided quantization arguments to the kernel.
        output_parallel = self.quant_method.apply(
            self, input_, fused_bias, input_quant_args=iqis
        )
    else:
        output_parallel = self.quant_method.apply(self, input_, fused_bias)

    output = output_parallel
    if self.gather_output and self.tp_size > 1:
        # All-gather across the partitions.
        output = tensor_model_parallel_all_gather(output_parallel)

    if self.return_bias:
        deferred_bias = self.bias if self.skip_bias_add else None
        return output, deferred_bias
    return output
def __init__( def __init__(
self, self,
......
...@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention ...@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm import envs
@dataclass @dataclass
...@@ -115,6 +116,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -115,6 +116,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None, llama_4_scaling: torch.Tensor | None = None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
q_c = None q_c = None
kv_lora = None kv_lora = None
...@@ -129,7 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -129,7 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, ( assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None" "q_b_proj is required when q_lora_rank is not None"
) )
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] if envs.USE_FUSED_RMS_QUANT and iqis is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
else:
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split( q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1, dim=-1,
......
...@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s
) )
......
...@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
i_q=i_q,
i_s=i_s
) )
def select_gemm_impl( def select_gemm_impl(
......
...@@ -94,6 +94,8 @@ from .utils import ( ...@@ -94,6 +94,8 @@ from .utils import (
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm.model_executor.layers.layernorm import FusedRMSNormQuant
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -169,6 +171,7 @@ class DeepseekAttention(nn.Module): ...@@ -169,6 +171,7 @@ class DeepseekAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -218,10 +221,23 @@ class DeepseekV2MLP(nn.Module): ...@@ -218,10 +221,23 @@ class DeepseekV2MLP(nn.Module):
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self,
gate_up, _ = self.gate_up_proj(x) x,
x = self.act_fn(gate_up) *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
x, _ = self.down_proj(x) ):
if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x return x
...@@ -334,7 +350,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -334,7 +350,9 @@ class DeepseekV2MoE(nn.Module):
else None, else None,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -528,6 +546,7 @@ class DeepseekV2Attention(nn.Module): ...@@ -528,6 +546,7 @@ class DeepseekV2Attention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None, llama_4_scaling: torch.Tensor | None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0] q = self.q_a_proj(hidden_states)[0]
...@@ -907,8 +926,9 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -907,8 +926,9 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None, llama_4_scaling: torch.Tensor | None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
return self.mla_attn(positions, hidden_states, llama_4_scaling) return self.mla_attn(positions, hidden_states, llama_4_scaling, iqis=iqis)
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
...@@ -989,13 +1009,91 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -989,13 +1009,91 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( if not envs.USE_FUSED_RMS_QUANT:
config.hidden_size, eps=config.rms_norm_eps self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
) self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
else:
self.input_layernorm = FusedRMSNormQuant(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps
)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
def forward_RQ(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: torch.Tensor | None,
    llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Decoder-layer forward that uses the fused RMSNorm+quant layernorms.

    ``input_layernorm`` is called with ``quant_dtype=torch.int8`` and its
    first two results ``(i_q, i_s)`` are passed to ``self_attn`` as ``iqis``.

    Args:
        positions: Forwarded unchanged to ``self.self_attn``.
        hidden_states: Layer input; also the ``x`` argument of both
            layernorm calls.
        residual: Residual stream, or ``None`` on the first call, in which
            case a clone of ``hidden_states`` becomes the residual.
        llama_4_scaling: Forwarded to ``self.self_attn`` only when
            ``self.use_mha`` is false.

    Returns:
        ``(hidden_states, residual)``: the MLP output and the residual
        returned by ``post_attention_layernorm``.
    """
    # Self Attention
    assert self.input_layernorm.has_weight is True
    if residual is None:
        residual = hidden_states.clone()
        i_q, i_s, _ = self.input_layernorm(
            x=hidden_states,
            residual=None,
            quant_dtype=torch.int8,
            update_input=False,
        )
    else:
        i_q, i_s, residual = self.input_layernorm(
            x=hidden_states,
            residual=residual,
            quant_dtype=torch.int8,
            update_input=False,
        )

    attn_kwargs = {
        "positions": positions,
        "hidden_states": hidden_states,
        "iqis": (i_q, i_s),
    }
    if not self.use_mha:
        attn_kwargs["llama_4_scaling"] = llama_4_scaling
    hidden_states = self.self_attn(**attn_kwargs)

    if (
        not isinstance(self.self_attn, DeepseekAttention)
        and hidden_states.dtype == torch.float16
    ):
        # Fix FP16 overflow
        # We scale both hidden_states and residual before
        # rmsnorm, and rmsnorm result would not affect by scale.
        hidden_states *= 1.0 / self.routed_scaling_factor
        if self.layer_idx == 0:
            # The residual is shared by all layers, we only scale it on
            # first layer.
            residual *= 1.0 / self.routed_scaling_factor

    # Fully Connected
    # Only the MoE path asks the layernorm to update its input.
    update_hs = isinstance(self.mlp, DeepseekV2MoE)
    assert self.post_attention_layernorm.has_weight is True
    _i_q, _i_s, residual = self.post_attention_layernorm(
        x=hidden_states,
        residual=residual,
        quant_dtype=torch.int8,
        update_input=update_hs,
    )
    # TODO(wjl): forward iqis=(_i_q, _i_s) to the MLP.
    # NOTE(review): until that happens, the DeepseekV2MLP path
    # (update_input=False) appears to receive hidden_states that this
    # layernorm call did not update -- confirm against FusedRMSNormQuant.
    hidden_states = self.mlp(hidden_states)

    if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
        # Fix FP16 overflow
        # Scaling the DeepseekV2MLP output, it is the input of
        # input_layernorm of next decoder layer.
        # The scaling of DeepseekV2MOE output would be done in the forward
        # of DeepseekV2MOE
        hidden_states *= 1.0 / self.routed_scaling_factor

    return hidden_states, residual
def forward_default(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -1048,6 +1146,25 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1048,6 +1146,25 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
def choose_forward(self):
    """Return the bound forward implementation for this layer.

    ``self.forward_RQ`` is selected when ``envs.USE_FUSED_RMS_QUANT`` is
    truthy, ``self.forward_default`` otherwise.
    """
    return self.forward_RQ if envs.USE_FUSED_RMS_QUANT else self.forward_default
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
forward_func = self.choose_forward()
return forward_func(positions=positions,
hidden_states=hidden_states,
residual=residual,
llama_4_scaling=llama_4_scaling)
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment