Commit 3e0a595d authored by zhuwenwen's avatar zhuwenwen
Browse files

add apex rmsnorm

parent a495fc3b
...@@ -151,6 +151,7 @@ if TYPE_CHECKING: ...@@ -151,6 +151,7 @@ if TYPE_CHECKING:
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -996,6 +997,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -996,6 +997,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in
("true", "1")), ("true", "1")),
# If set to "true" or "1", vLLM uses Apex's fused kernel for RMSNorm
"VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom normalization layers.""" """Custom normalization layers."""
from typing import Optional, Union from typing import Optional, Union, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -188,6 +188,21 @@ class RMSNorm(CustomOp): ...@@ -188,6 +188,21 @@ class RMSNorm(CustomOp):
else: else:
return norm_func(x, self.weight.data, self.variance_epsilon) return norm_func(x, self.weight.data, self.variance_epsilon)
def forward_apex(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Apply RMSNorm, using Apex's fused kernel when no residual is given.

    Without a residual, ``x`` is normalized over its last dimension by
    Apex's ``fused_rms_norm_affine`` using this layer's weight and
    epsilon. With a residual, the call is delegated to the function
    returned by ``dispatch_cuda_rmsnorm_func(True)``; Apex is not
    involved on that path.

    Args:
        x: Input tensor; the Apex path normalizes over ``x.shape[-1]``.
        residual: Optional residual tensor. When provided it is passed,
            together with ``x``, to the dispatched residual-aware
            function.

    Returns:
        The result of the selected kernel (per the annotation, a tensor
        or a pair of tensors when a residual is supplied).

    Raises:
        ImportError: If ``residual`` is ``None`` and Apex is not
            installed.
    """
    if residual is not None:
        # Residual path: only here is the dispatched function needed, so
        # the lookup is not performed for plain normalization.
        norm_func = dispatch_cuda_rmsnorm_func(True)
        return norm_func(x, residual, self.weight.data,
                         self.variance_epsilon)

    # Imported lazily so that Apex is required only on the path that
    # actually calls into it.
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine

    # Apex expects the normalized shape explicitly: the last dimension.
    normalized_shape = torch.Size((x.shape[-1],))
    return fused_rms_norm_affine(x, self.weight.data, normalized_shape,
                                 self.variance_epsilon)
def forward_hpu( def forward_hpu(
self, self,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -50,6 +50,7 @@ from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP ...@@ -50,6 +50,7 @@ from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -137,10 +138,16 @@ class Qwen3Attention(nn.Module): ...@@ -137,10 +138,16 @@ class Qwen3Attention(nn.Module):
# Add qk-norm # Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim) self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm(q_by_head) q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape) q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim) self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm(k_by_head) k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape) k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
......
...@@ -57,6 +57,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index, ...@@ -57,6 +57,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
...@@ -230,11 +231,17 @@ class Qwen3MoeAttention(nn.Module): ...@@ -230,11 +231,17 @@ class Qwen3MoeAttention(nn.Module):
# Add qk-norm # Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim) self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm(q_by_head) q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape) q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim) self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm(k_by_head) k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape) k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment