Commit 58a36508 authored by wujl5's avatar wujl5
Browse files

perf:Deepseek v2模型增加rmsQuant和siluMulQuant融合

parent 7826240b
......@@ -1808,7 +1808,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "0"))),
lambda: (os.getenv("USE_FUSED_SILU_MUL_QUANT", "False").lower() in
("true", "1")),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
......
......@@ -6,7 +6,7 @@ import os
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from typing import Literal, cast, get_args, overload
from typing import Literal, cast, get_args, overload, Optional
import torch
import torch.nn.functional as F
......@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
......@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp):
)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name()
hidden_states, router_logits, encode_layer_name(),
i_q=i_q,
i_s=i_s
)
return (
reduce_output(shared_output)[..., :og_hidden_states],
......@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits)
return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def forward_impl_chunked(
self,
......@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
......@@ -2004,13 +2012,25 @@ class FusedMoE(CustomOp):
if self.capture is not None:
self.capture(topk_ids)
final_hidden_states = self.quant_method.apply(
if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe,
i_q=i_q,
i_s=i_s
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe,
use_nn_moe=self.use_nn_moe
)
if has_separate_shared_experts:
......@@ -2133,16 +2153,20 @@ def moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor:
self = get_layer_from_name(layer_name)
assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -2160,16 +2184,23 @@ def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
self = get_layer_from_name(layer_name)
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits)
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
......
......@@ -60,10 +60,13 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped:
if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states)
shared_out = self._shared_experts(hidden_states,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
......@@ -79,11 +82,15 @@ class SharedFusedMoE(FusedMoE):
fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
)
else:
shared_out, fused_out = super().forward(
hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
)
# ensure early TP reduction of shared expert outputs when required
if (
......
......@@ -370,7 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_nn_moe: bool | None = False, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
layer=layer,
......
......@@ -711,6 +711,31 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear.
"""
def forward(
self,
input_,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def __init__(
self,
......
......@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm import envs
@dataclass
......@@ -115,6 +116,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
q_c = None
kv_lora = None
......@@ -129,7 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None"
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
else:
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
......
......@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe,
i_q=i_q,
i_s=i_s
)
......
......@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
i_q=i_q,
i_s=i_s
)
def select_gemm_impl(
......
......@@ -94,6 +94,8 @@ from .utils import (
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from vllm.model_executor.layers.layernorm import FusedRMSNormQuant
logger = init_logger(__name__)
......@@ -169,6 +171,7 @@ class DeepseekAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -218,10 +221,23 @@ class DeepseekV2MLP(nn.Module):
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
def forward(self,
x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
):
if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
......@@ -334,7 +350,9 @@ class DeepseekV2MoE(nn.Module):
else None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
......@@ -528,6 +546,7 @@ class DeepseekV2Attention(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
......@@ -907,8 +926,9 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
return self.mla_attn(positions, hidden_states, llama_4_scaling)
return self.mla_attn(positions, hidden_states, llama_4_scaling, iqis=iqis)
class DeepseekV2DecoderLayer(nn.Module):
......@@ -989,13 +1009,91 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
if not envs.USE_FUSED_RMS_QUANT:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
else:
self.input_layernorm = FusedRMSNormQuant(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps
)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
def forward(
def forward_RQ(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None,
quant_dtype=torch.int8,
update_input=False
)
residual_fix_overflow = True
else:
i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=False
)
attn_kwargs = {
"positions": positions,
"hidden_states": hidden_states,
"iqis": (i_q, i_s)
}
if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling
hidden_states = self.self_attn(**attn_kwargs)
if (
not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16
):
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1.0 / self.routed_scaling_factor
if self.layer_idx == 0:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1.0 / self.routed_scaling_factor
# Fully Connected
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=update_hs
)
new_resi = residual
hidden_states = self.mlp(hidden_states,
# iqis=(_i_q, _i_s) # TODO:wjl
)
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1.0 / self.routed_scaling_factor
return hidden_states, new_resi
def forward_default(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
......@@ -1048,6 +1146,25 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual
def choose_forward(self):
if envs.USE_FUSED_RMS_QUANT:
return self.forward_RQ
else:
return self.forward_default
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
forward_func = self.choose_forward()
return forward_func(positions=positions,
hidden_states=hidden_states,
residual=residual,
llama_4_scaling=llama_4_scaling)
@support_torch_compile
class DeepseekV2Model(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment