Unverified Commit 156405d2 authored by Xiaoshuang Wang's avatar Xiaoshuang Wang Committed by GitHub
Browse files

[vLLM IR] gemma_rms_norm (#38780)


Signed-off-by: default avatarIcey <1790571317@qq.com>
parent 99e5539a
...@@ -16,7 +16,6 @@ def rms_norm( ...@@ -16,7 +16,6 @@ def rms_norm(
x_var = x if variance_size is None else x[..., :variance_size] x_var = x if variance_size is None else x[..., :variance_size]
variance = x_var.pow(2).mean(dim=-1, keepdim=True) variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + epsilon) x = x * torch.rsqrt(variance + epsilon)
x = x.to(orig_dtype)
if weight is not None: if weight is not None:
x = x * weight x = x.to(weight.dtype) * weight
return x return x.to(orig_dtype)
...@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found() ...@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found()
rms_no_var_16bit_only = ( rms_no_var_16bit_only = (
lambda x, weight, epsilon, variance_size=None: variance_size is None lambda x, weight, epsilon, variance_size=None: variance_size is None
and x.dtype and x.dtype in (torch.float16, torch.bfloat16)
in ( and (weight is None or weight.dtype == x.dtype)
torch.float16,
torch.bfloat16,
)
) )
"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override.""" """AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
and requires weight dtype to match x dtype."""
@ir.ops.rms_norm.register_impl( @ir.ops.rms_norm.register_impl(
......
...@@ -11,8 +11,11 @@ current_platform.import_kernels() ...@@ -11,8 +11,11 @@ current_platform.import_kernels()
CUDA_ALIKE = current_platform.is_cuda_alike() CUDA_ALIKE = current_platform.is_cuda_alike()
"""Most kernels in this file are supported on all CUDA-alike platforms.""" """Most kernels in this file are supported on all CUDA-alike platforms."""
rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None rms_no_var_size = (
"""vLLM kernel does not support variance_size parameter.""" lambda x, weight, epsilon, variance_size=None: variance_size is None
and (weight is None or weight.dtype == x.dtype)
)
"""vLLM kernel does not support variance_size parameter or mismatched weight dtype."""
@ir.ops.rms_norm.register_impl( @ir.ops.rms_norm.register_impl(
......
...@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool: ...@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool:
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found() XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
"""Kernels in this file are supported if vLLM XPU kernels are installed.""" """Kernels in this file are supported if vLLM XPU kernels are installed."""
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None and (
weight is None or weight.dtype == x.dtype
)
@ir.ops.rms_norm.register_impl( @ir.ops.rms_norm.register_impl(
......
...@@ -376,46 +376,6 @@ class GemmaRMSNorm(CustomOp): ...@@ -376,46 +376,6 @@ class GemmaRMSNorm(CustomOp):
self.weight = nn.Parameter(torch.zeros(hidden_size)) self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
@staticmethod
def _forward_static_no_residual(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward() without residual."""
orig_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x
@staticmethod
def _forward_static_with_residual(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward() with residual."""
orig_dtype = x.dtype
x = (
x.float() + residual.float()
if orig_dtype == torch.float16
else x + residual
)
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x, residual
def forward_native( def forward_native(
self, self,
x: torch.Tensor, x: torch.Tensor,
...@@ -423,30 +383,26 @@ class GemmaRMSNorm(CustomOp): ...@@ -423,30 +383,26 @@ class GemmaRMSNorm(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if residual is None: if residual is None:
return self._forward_static_no_residual( return ir.ops.rms_norm(
self.weight.data, self.variance_epsilon, x x, self.weight.data.float() + 1.0, self.variance_epsilon
) )
else: else:
return self._forward_static_with_residual( orig_dtype = x.dtype
self.weight.data, self.variance_epsilon, x, residual x = (
x.float() + residual.float()
if orig_dtype == torch.float16
else x + residual
) )
residual = x
return ir.ops.rms_norm(
x, self.weight.data.float() + 1.0, self.variance_epsilon
).to(orig_dtype), residual
def forward_cuda( def forward_cuda(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self._forward_static_no_residual = torch.compile( # type: ignore
self._forward_static_no_residual
)
self._forward_static_with_residual = torch.compile( # type: ignore
self._forward_static_with_residual
)
self._is_compiled = True
return self.forward_native(x, residual) return self.forward_native(x, residual)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment