Commit 3e0a595d authored by zhuwenwen's avatar zhuwenwen
Browse files

add apex rmsnorm

parent a495fc3b
......@@ -151,6 +151,7 @@ if TYPE_CHECKING:
VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
def get_default_cache_root():
......@@ -996,6 +997,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in
("true", "1")),
# vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers."""
from typing import Optional, Union
from typing import Optional, Union, Tuple
import torch
import torch.nn as nn
......@@ -188,6 +188,21 @@ class RMSNorm(CustomOp):
else:
return norm_func(x, self.weight.data, self.variance_epsilon)
def forward_apex(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if add_residual:
return norm_func(x, residual, self.weight.data,
self.variance_epsilon)
else:
return fused_rms_norm_affine(x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon)
def forward_hpu(
self,
x: torch.Tensor,
......
......@@ -50,6 +50,7 @@ from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -137,10 +138,16 @@ class Qwen3Attention(nn.Module):
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
......
......@@ -57,6 +57,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
......@@ -230,11 +231,17 @@ class Qwen3MoeAttention(nn.Module):
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment