Commit bad2a99c authored by zhangqha's avatar zhangqha
Browse files

Merge branch 'v0.15.1-dev-wm-1' into 'v0.15.1-dev'

[perf]glm4_moe模型适配rmsquant和silu_quant融合算子

See merge request dcutoolkit/deeplearing/vllm!467
parents f38f6c1e 110bbdd5
...@@ -281,6 +281,7 @@ if TYPE_CHECKING: ...@@ -281,6 +281,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
...@@ -1804,6 +1805,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1804,6 +1805,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))), lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "0"))),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -14,6 +15,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -14,6 +15,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm import envs from vllm import envs
...@@ -298,6 +300,96 @@ class RMSNorm(CustomOp): ...@@ -298,6 +300,96 @@ class RMSNorm(CustomOp):
return s return s
class FusedRMSNormQuant(nn.Module):
"""Fuse Root mean square normalization and int8 quant.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: int | None = None,
has_weight: bool = True,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
quant_dtype: torch.dtype = torch.int8,
update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = torch.ops.vllm.fused_rmsquant(input=x,
weight=self.weight,
epsilon=self.variance_epsilon,
quant_dtype=quant_dtype,
residual=residual,
update_input=update_input)
return i_q, i_s, residual
def fused_rmsquant_impl(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor] = None,
update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, device=input.device, dtype=quant_dtype)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
from lightop.op import rms_norm_dynamic_per_token_quant as ligtop_rms_norm_dynamic_per_token_quant
ligtop_rms_norm_dynamic_per_token_quant(output, input, weight,
scales, epsilon,
residual, update_input)
return output, scales
def fused_rmsquant_fake(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor] = None,
update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fake implementation for torch.compile"""
output = torch.empty_like(input, dtype=quant_dtype)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
return output, scales
# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op(
op_name="fused_rmsquant",
op_func=fused_rmsquant_impl,
mutates_args=[],
fake_impl=fused_rmsquant_fake,
)
# --8<-- [start:gemma_rms_norm] # --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm") @CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp): class GemmaRMSNorm(CustomOp):
......
...@@ -654,11 +654,16 @@ class ColumnParallelLinear(LinearBase): ...@@ -654,11 +654,16 @@ class ColumnParallelLinear(LinearBase):
def forward( def forward(
self, self,
input_, input_,
*,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
if iqis is not None:
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
else:
output_parallel = self.quant_method.apply(self, input_, bias) output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1: if self.gather_output and self.tp_size > 1:
...@@ -1523,6 +1528,8 @@ class RowParallelLinear(LinearBase): ...@@ -1523,6 +1528,8 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, self,
input_, input_,
*,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
...@@ -1537,6 +1544,9 @@ class RowParallelLinear(LinearBase): ...@@ -1537,6 +1544,9 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if iqis is not None:
output_parallel = self.quant_method.apply(self, input_parallel, bias_, input_quant_args=iqis)
else:
output_parallel = self.quant_method.apply(self, input_parallel, bias_) output_parallel = self.quant_method.apply(self, input_parallel, bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import envs
try: try:
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
...@@ -167,6 +168,15 @@ def apply_int8_linear( ...@@ -167,6 +168,15 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_zp =None
x_q, x_scale = silu_quant_args
else:
symmetric = azp_adj is None symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True: if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input) x_q, x_scale=per_token_quant_int8(input)
......
...@@ -44,7 +44,7 @@ from vllm.distributed import ( ...@@ -44,7 +44,7 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm, FusedRMSNormQuant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -108,8 +108,14 @@ class Glm4MoeMLP(nn.Module): ...@@ -108,8 +108,14 @@ class Glm4MoeMLP(nn.Module):
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x,
gate_up, _ = self.gate_up_proj(x) iqis: tuple[torch.Tensor, torch.Tensor] | None = None):
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else:
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x return x
...@@ -321,8 +327,9 @@ class Glm4MoeAttention(nn.Module): ...@@ -321,8 +327,9 @@ class Glm4MoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states, iqis=iqis)
if not envs.VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE: if not envs.VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -408,7 +415,16 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -408,7 +415,16 @@ class Glm4MoeDecoderLayer(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
if envs.USE_FUSED_RMS_QUANT:
self.input_layernorm = FusedRMSNormQuant(config.hidden_size, eps=config.rms_norm_eps)
else:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if envs.USE_FUSED_RMS_QUANT and isinstance(self.mlp, Glm4MoeMLP):
self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps
)
else:
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
...@@ -420,12 +436,33 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -420,12 +436,33 @@ class Glm4MoeDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: torch.Tensor | None, residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not envs.USE_FUSED_RMS_QUANT:
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
else:
if residual is None:
residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None,
quant_dtype=torch.int8,
update_input=False
)
else:
i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=False
)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, iqis=(i_q, i_s))
if envs.USE_FUSED_RMS_QUANT and isinstance(self.mlp, Glm4MoeMLP):
i_q, i_s, _ = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states, iqis=(i_q, i_s))
else:
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment