Commit a1628458 authored by 王敏
Browse files

[feat]添加rmsnorm+int8 quant融合module

parent c9733a54
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -14,8 +15,11 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -14,8 +15,11 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm import envs from vllm import envs
from lightop import rms_norm_dynamic_per_token_quant
def rms_norm( def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
...@@ -298,6 +302,100 @@ class RMSNorm(CustomOp): ...@@ -298,6 +302,100 @@ class RMSNorm(CustomOp):
return s return s
class FusedRMSNormQuant(nn.Module):
"""Fuse Root mean square normalization and int8 quant.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: int | None = None,
has_weight: bool = True,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
quant_dtype: torch.dtype = torch.int8
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
x, x_scales = fused_rmsquant(x, self.weight,
self.variance_epsilon,
quant_dtype, residual)
return x, x_scales, residual
def fused_rmsquant_impl(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation backing the fused RMSNorm + quant custom op.

    Forwards every argument positionally, in declaration order, to
    ``rms_norm_dynamic_per_token_quant`` and returns its two results as a
    tuple ``(quantized_output, scales)``.
    """
    # NOTE(review): the meaning of `update_input` is defined by the lightop
    # kernel and is not visible here -- confirm before relying on it.
    kernel_args = (input, weight, epsilon, quant_dtype, residual, update_input)
    quantized, token_scales = rms_norm_dynamic_per_token_quant(*kernel_args)
    return quantized, token_scales
def fused_rmsquant_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    residual: Optional[torch.Tensor] = None,
    update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile.

    Produces placeholder tensors with the expected metadata: a zero tensor
    shaped like ``input`` with dtype ``quant_dtype``, and a float32 tensor of
    ones with one row per leading-dimension element and a single column, on
    ``input``'s device.
    """
    # One scale row for every slice along the last dimension.
    num_rows = input.numel() // input.shape[-1]
    fake_scales = torch.ones(
        (num_rows, 1), device=input.device, dtype=torch.float32
    )
    fake_output = torch.zeros_like(input, dtype=quant_dtype)
    return fake_output, fake_scales
# Register the fused RMSNorm + quant kernel as a vLLM custom op. The op name
# must match the attribute looked up by the `fused_rmsquant` wrapper, which
# dispatches through `torch.ops.vllm.fused_rmsquant`.
# NOTE(review): `mutates_args` is empty although the op accepts
# `update_input` and a `residual` tensor -- if the kernel writes to either
# tensor in place, they must be listed here for torch.compile correctness.
direct_register_custom_op(
    op_name="fused_rmsquant",
    op_func=fused_rmsquant_impl,
    mutates_args=[],
    fake_impl=fused_rmsquant_fake,
)
def fused_rmsquant(input: torch.Tensor,
                   rms_weight: torch.Tensor,
                   epsilon: float,
                   quant_dtype: torch.dtype,
                   residual: Optional[torch.Tensor] = None,
                   update_input: Optional[bool] = True):
    """Call the ``torch.ops.vllm.fused_rmsquant`` custom op.

    ``rms_weight`` is passed to the op as its ``weight`` argument; all other
    arguments are forwarded by keyword under their own names. Returns the
    tuple ``(quantized_output, scales)`` produced by the op.
    """
    op_kwargs = dict(
        input=input,
        weight=rms_weight,
        epsilon=epsilon,
        quant_dtype=quant_dtype,
        residual=residual,
        update_input=update_input,
    )
    quantized, scales = torch.ops.vllm.fused_rmsquant(**op_kwargs)
    return quantized, scales
# --8<-- [start:gemma_rms_norm] # --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm") @CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp): class GemmaRMSNorm(CustomOp):
......
...@@ -32,6 +32,8 @@ import torch ...@@ -32,6 +32,8 @@ import torch
from torch import nn from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig from transformers.models.glm4_moe import Glm4MoeConfig
from vllm import envs
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
...@@ -43,7 +45,7 @@ from vllm.distributed import ( ...@@ -43,7 +45,7 @@ from vllm.distributed import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm, FusedRMSNormQuant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -71,6 +73,10 @@ from .utils import ( ...@@ -71,6 +73,10 @@ from .utils import (
make_layers, make_layers,
maybe_prefix, maybe_prefix,
) )
from vllm.utils.torch_utils import direct_register_custom_op
if envs.VLLM_USE_FUSED_RMS_QUANT:
from lightop import rms_norm_dynamic_per_token_quant
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -113,6 +119,44 @@ class Glm4MoeMLP(nn.Module): ...@@ -113,6 +119,44 @@ class Glm4MoeMLP(nn.Module):
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x return x
class Glm4MoeQuantedMLP(nn.Module):
    """Gated MLP whose first projection also receives activation scales.

    Structure: ``gate_up_proj`` (merged column-parallel, two outputs of
    ``intermediate_size``) -> ``SiluAndMul`` -> ``down_proj`` (row-parallel).
    Only the ``silu`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are merged into one column-parallel layer.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        activation_supported = hidden_act == "silu"
        if not activation_supported:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x, x_scales):
        """Run the MLP on ``x``, passing ``x_scales`` to ``gate_up_proj``.

        NOTE(review): ``x`` / ``x_scales`` are presumably the quantized
        activations and scales produced by ``FusedRMSNormQuant`` -- confirm
        against the caller and the linear layer's forward signature.
        """
        projected, _ = self.gate_up_proj(x, x_scales)
        activated = self.act_fn(projected)
        result, _ = self.down_proj(activated)
        return result
class Glm4MoE(nn.Module): class Glm4MoE(nn.Module):
def __init__( def __init__(
...@@ -342,6 +386,8 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -342,6 +386,8 @@ class Glm4MoeDecoderLayer(nn.Module):
layer_idx = int(prefix.split(sep=".")[-1]) layer_idx = int(prefix.split(sep=".")[-1])
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.self_attn = Glm4MoeAttention( self.self_attn = Glm4MoeAttention(
config=config, config=config,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -368,6 +414,7 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -368,6 +414,7 @@ class Glm4MoeDecoderLayer(nn.Module):
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
) )
else: else:
if not envs.VLLM_USE_FUSED_RMS_QUANT:
self.mlp = Glm4MoeMLP( self.mlp = Glm4MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
...@@ -375,11 +422,23 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -375,11 +422,23 @@ class Glm4MoeDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
else:
self.mlp = Glm4MoeQuantedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if not envs.VLLM_USE_FUSED_RMS_QUANT:
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
else:
self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps
)
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
def forward( def forward(
...@@ -394,8 +453,13 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -394,8 +453,13 @@ class Glm4MoeDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
if not envs.VLLM_USE_FUSED_RMS_QUANT:
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
else:
hidden_states, scales, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states, scales)
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment