Unverified commit 8060bb03, authored by Jiangyun Zhu, committed by GitHub
Browse files

[vLLM IR] rework gemma_rms_norm (#39014)


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent da4c0e4d
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -162,3 +162,31 @@ def test_fused_rms_norm_quant( ...@@ -162,3 +162,31 @@ def test_fused_rms_norm_quant(
atol=1e-3, atol=1e-3,
rtol=1e-3, rtol=1e-3,
) )
@torch.inference_mode()
def test_gemma_rms_norm_mixed_input_weight_dtype(default_vllm_config) -> None:
    """Check GemmaRMSNorm with bf16 activations and an fp32 weight parameter.

    The layer output is compared against an fp32 reference computed inline:
    ``x * rsqrt(mean(x^2) + eps) * (weight + 1)``, cast back to the
    activation dtype.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")

    device = CUDA_DEVICES[0]
    torch.set_default_device(device)

    hidden_size = 1024
    num_tokens = 32
    activations = torch.randn(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )

    norm = GemmaRMSNorm(hidden_size, eps=1e-6).to(device=device)
    norm.weight.data.normal_(mean=0.0, std=0.1)
    # Gemma uses fp32 weight parameter while activations can be bf16.
    assert norm.weight.dtype == torch.float32

    actual = norm(activations)

    # Reference computed entirely in fp32, then cast to the activation dtype.
    acts_fp32 = activations.float()
    gain_fp32 = norm.weight.data.float() + 1.0
    mean_square = acts_fp32.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + norm.variance_epsilon)
    expected = (acts_fp32 * inv_rms * gain_fp32).to(activations.dtype)

    assert actual.dtype == activations.dtype
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
...@@ -12,6 +12,9 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized ...@@ -12,6 +12,9 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops import vllm.ir.ops
from vllm.compilation.passes.fusion.rms_quant_fusion import (
_rms_input_weight_dtype_match,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import Range from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
...@@ -320,7 +323,12 @@ class AllReduceRMSNormPattern(BasePattern): ...@@ -320,7 +323,12 @@ class AllReduceRMSNormPattern(BasePattern):
return allreduce[3], allreduce[1] return allreduce[3], allreduce[1]
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern,
replacement,
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
...@@ -459,7 +467,12 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -459,7 +467,12 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
return allreduce[4], allreduce[1] return allreduce[4], allreduce[1]
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern,
replacement,
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
...@@ -621,7 +634,12 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -621,7 +634,12 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
return allreduce[4], allreduce[1], allreduce[5] return allreduce[4], allreduce[1], allreduce[5]
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern,
replacement,
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
......
...@@ -38,6 +38,22 @@ FP8_DTYPE = current_platform.fp8_dtype() ...@@ -38,6 +38,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
# TODO: extend rmsnorm quant kernels to support mixed input/weight dtypes,
# and remove this check.
def _rms_input_weight_dtype_match(match: pm.Match) -> bool:
    """Prevent fusion when rms_norm input and weight dtypes differ.

    Scans the matched nodes for an ``rms_norm`` call whose ``x`` and
    ``weight`` arguments are both graph nodes, and decides based on the
    first such call found.

    Args:
        match: The pattern-matcher match being considered for replacement.

    Returns:
        ``False`` if that call's input and weight dtypes differ, ``True`` if
        they agree or if no such call is present in the match.
    """
    for candidate in match.nodes:
        if candidate.target != _RMS_NORM_OP:
            continue
        # rms_norm(x, weight, epsilon, variance_size)
        activation = candidate.args[0]
        scale = candidate.args[1]
        if not (isinstance(activation, fx.Node) and isinstance(scale, fx.Node)):
            # Non-node arguments carry no dtype metadata here; keep looking.
            continue
        return activation.meta["val"].dtype == scale.meta["val"].dtype
    return True
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
    """Return an uninitialized bfloat16 tensor allocated on the CUDA device.

    Positional and keyword arguments are forwarded to ``torch.empty``.
    """
    return torch.empty(*args, dtype=torch.bfloat16, device="cuda", **kwargs)
...@@ -186,7 +202,14 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -186,7 +202,14 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
] ]
pattern(*inputs) pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=_rms_input_weight_dtype_match,
)
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
...@@ -249,6 +272,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ...@@ -249,6 +272,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
inputs, inputs,
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
...@@ -350,6 +374,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -350,6 +374,7 @@ class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
self.rmsnorm_matcher.inputs() + [scale], self.rmsnorm_matcher.inputs() + [scale],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
...@@ -445,6 +470,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern): ...@@ -445,6 +470,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
], ],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
...@@ -503,6 +529,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -503,6 +529,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
], ],
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
...@@ -559,6 +586,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): ...@@ -559,6 +586,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
self.rmsnorm_matcher.inputs(), self.rmsnorm_matcher.inputs(),
pm.fwd_only, pm.fwd_only,
pm_pass, pm_pass,
extra_check=_rms_input_weight_dtype_match,
) )
......
...@@ -16,7 +16,6 @@ def rms_norm( ...@@ -16,7 +16,6 @@ def rms_norm(
x_var = x if variance_size is None else x[..., :variance_size] x_var = x if variance_size is None else x[..., :variance_size]
variance = x_var.pow(2).mean(dim=-1, keepdim=True) variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + epsilon) x = x * torch.rsqrt(variance + epsilon)
x = x.to(orig_dtype)
if weight is not None: if weight is not None:
x = x * weight x = x.to(weight.dtype) * weight
return x return x.to(orig_dtype)
...@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found() ...@@ -36,13 +36,11 @@ AITER_SUPPORTED = is_aiter_found()
# Argument predicate: True only when there is no variance_size override, the
# activations are fp16/bf16, and the weight (if any) has the same dtype as x.
rms_no_var_16bit_only = (
    lambda x, weight, epsilon, variance_size=None: variance_size is None
    and x.dtype in (torch.float16, torch.bfloat16)
    and (weight is None or weight.dtype == x.dtype)
)
"""AITER rms_norm only supports float16 and bfloat16 acts, no var_size override,
and requires weight dtype to match x dtype."""
@ir.ops.rms_norm.register_impl( @ir.ops.rms_norm.register_impl(
......
...@@ -11,8 +11,11 @@ current_platform.import_kernels() ...@@ -11,8 +11,11 @@ current_platform.import_kernels()
CUDA_ALIKE = current_platform.is_cuda_alike() CUDA_ALIKE = current_platform.is_cuda_alike()
"""Most kernels in this file are supported on all CUDA-alike platforms.""" """Most kernels in this file are supported on all CUDA-alike platforms."""
# Argument predicate: True only when there is no variance_size override and
# the weight (if any) has the same dtype as x.
rms_no_var_size = (
    lambda x, weight, epsilon, variance_size=None: variance_size is None
    and (weight is None or weight.dtype == x.dtype)
)
"""vLLM kernel requires no variance_size override and matching input/weight dtype."""
@ir.ops.rms_norm.register_impl( @ir.ops.rms_norm.register_impl(
......
...@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool: ...@@ -18,7 +18,9 @@ def is_xpu_kernels_found() -> bool:
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found() XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
"""Kernels in this file are supported if vLLM XPU kernels are installed.""" """Kernels in this file are supported if vLLM XPU kernels are installed."""
# Argument predicate: True only when there is no variance_size override and
# the weight (if any) has the same dtype as x.
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None and (
    weight is None or weight.dtype == x.dtype
)
@ir.ops.rms_norm.register_impl( @ir.ops.rms_norm.register_impl(
......
...@@ -376,77 +376,32 @@ class GemmaRMSNorm(CustomOp): ...@@ -376,77 +376,32 @@ class GemmaRMSNorm(CustomOp):
self.weight = nn.Parameter(torch.zeros(hidden_size)) self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
@staticmethod
def _forward_static_no_residual(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward() without residual."""
orig_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x
@staticmethod
def _forward_static_with_residual(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward() with residual."""
orig_dtype = x.dtype
x = (
x.float() + residual.float()
if orig_dtype == torch.float16
else x + residual
)
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x, residual
def forward_native(
    self,
    x: torch.Tensor,
    residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation equivalent to forward().

    Args:
        x: Activations to normalize.
        residual: Optional tensor added to ``x`` before normalization.

    Returns:
        The normalized tensor cast to the dtype ``x`` came in with. When
        ``residual`` is given, a ``(normalized, summed)`` tuple where
        ``summed`` is the pre-normalization ``x + residual``.
    """
    orig_dtype = x.dtype
    # Gemma scales by (1 + weight); the effective weight is built in fp32.
    weight = self.weight.data.float() + 1.0
    if residual is not None:
        # float16 inputs are added in fp32; other dtypes add as-is.
        x = (
            x.float() + residual.float()
            if orig_dtype == torch.float16
            else x + residual
        )
        residual = x
    # ir.ops.rms_norm handles fp32 upcast internally
    out = ir.ops.rms_norm(x, weight, self.variance_epsilon).to(orig_dtype)
    if residual is None:
        return out
    return out, residual
def forward_cuda( def forward_cuda(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if torch.compiler.is_compiling():
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self._forward_static_no_residual = torch.compile( # type: ignore
self._forward_static_no_residual
)
self._forward_static_with_residual = torch.compile( # type: ignore
self._forward_static_with_residual
)
self._is_compiled = True
return self.forward_native(x, residual) return self.forward_native(x, residual)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment