Unverified Commit 40bb1750 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)


Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: default avatarchzhang <chaojun.zhang@intel.com>
Signed-off-by: default avatarLuka Govedic <luka.govedic@gmail.com>
Co-authored-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: default avatarChaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
parent 0fab52f0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from vllm import ir
from vllm.platforms import current_platform
current_platform.import_kernels()
CUDA_ALIKE = current_platform.is_cuda_alike()
"""Most kernels in this file are supported on all CUDA-alike platforms."""
rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None
"""vLLM kernel does not support variance_size parameter."""
@ir.ops.rms_norm.register_impl(
"vllm_c", supports_args=rms_no_var_size, supported=CUDA_ALIKE
)
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
if weight is None:
# Kernel requires weight tensor, pass ones
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
assert variance_size is None
output = torch.empty(x.shape, device=x.device, dtype=x.dtype)
torch.ops._C.rms_norm(output, x, weight, epsilon)
return output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from vllm import ir
from vllm.platforms import current_platform
current_platform.import_kernels()
def is_xpu_kernels_found() -> bool:
from importlib.util import find_spec
return find_spec("vllm_xpu_kernels") is not None
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None
@ir.ops.rms_norm.register_impl(
"xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED
)
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
if weight is None:
# Kernel requires weight tensor, pass ones
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
assert variance_size is None
output = torch.empty(x.shape, device=x.device, dtype=x.dtype)
torch.ops._C.rms_norm(output, x, weight, epsilon)
return output
......@@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import (
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime
from vllm.logging_utils.torch_tensor import tensors_str_no_data
__all__ = [
"NewLineFormatter",
......@@ -16,4 +17,5 @@ __all__ = [
"create_uvicorn_log_config",
"lazy",
"logtime",
"tensors_str_no_data",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
def tensors_str_no_data(arg: Any):
from torch._tensor_str import printoptions
with printoptions(threshold=1, edgeitems=0):
return str(arg)
......@@ -6,7 +6,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import _oink_ops, envs
# Import kernels
import vllm.kernels # noqa: F401
from vllm import _oink_ops, envs, ir
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
......@@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
return (x_2d.stride(0) % divby) == 0
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from vllm import _custom_ops as ops
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
def fused_add_rms_norm(
x: torch.Tensor,
residual: torch.Tensor,
......@@ -105,23 +90,16 @@ def poly_norm(
return out
def dispatch_rocm_rmsnorm_func(
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
use_aiter = use_aiter and dtype in [
torch.float16,
torch.bfloat16,
]
if use_aiter and with_fused_add:
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter:
return rocm_aiter_ops.rms_norm
# fall back to CUDA implementation
if with_fused_add:
return rocm_aiter_ops.rms_norm2d_with_add
else:
return fused_add_rms_norm
return rms_norm
# --8<-- [start:rms_norm]
......@@ -158,20 +136,14 @@ class RMSNorm(CustomOp):
if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on
# compatible CUDA devices (e.g., SM100) when the external Oink
# package is available. This is detected once at construction time
# to avoid per-call device queries in the hot path.
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False
if (
not current_platform.is_rocm()
......@@ -203,7 +175,6 @@ class RMSNorm(CustomOp):
try:
device_index = torch.accelerator.current_device_index()
if _oink_ops.is_oink_available_for_device(device_index):
self._use_oink_rmsnorm = True
self._use_oink_fused_add_rmsnorm = (
_oink_ops.has_fused_add_rms_norm()
)
......@@ -215,7 +186,6 @@ class RMSNorm(CustomOp):
"RMSNorm; falling back to vLLM RMSNorm. Error: %s",
e,
)
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False
@staticmethod
......@@ -270,6 +240,10 @@ class RMSNorm(CustomOp):
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
if residual is None:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
return self.forward_static(
x,
......@@ -286,35 +260,14 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None and not envs.VLLM_BATCH_INVARIANT:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
if self.variance_size_override is not None:
return self.forward_native(x, residual)
# Optional Oink SM100 fast path (no residual). This path is
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
# 2D layouts (including padded rows) when using the Oink
# pointer-based kernel.
if (
residual is None
and getattr(self, "_use_oink_rmsnorm", False)
and x.is_cuda
and x.dim() >= 2
and self.has_weight
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
orig_shape = x.shape
hidden_size = orig_shape[-1]
if _can_view_as_2d(x):
x_2d = x.view(-1, hidden_size)
if _is_oink_stride_compatible_2d(x_2d):
y_2d = _oink_ops.rmsnorm(
x_2d,
self.weight.data,
self.variance_epsilon,
)
return y_2d.view(orig_shape)
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both
# `x` (normalized output) and `residual` (residual-out buffer).
......@@ -356,29 +309,34 @@ class RMSNorm(CustomOp):
)
return x, residual
add_residual = residual is not None
if add_residual:
if residual is not None:
return fused_add_rms_norm(
x, residual, self.weight.data, self.variance_epsilon
)
else:
return rms_norm(x, self.weight.data, self.variance_epsilon)
assert envs.VLLM_BATCH_INVARIANT
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
def forward_hip(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None and not envs.VLLM_BATCH_INVARIANT:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
if self.variance_size_override is not None:
return self.forward_native(x, residual)
add_residual = residual is not None
if add_residual:
if residual is not None:
return self.rocm_norm_func_with_add(
x, residual, self.weight.data, self.variance_epsilon
)
else:
return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)
assert envs.VLLM_BATCH_INVARIANT
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
def forward_xpu(
self,
......
......@@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
else:
VllmConfig = None
......@@ -550,6 +551,26 @@ class CudaPlatformBase(Platform):
def use_custom_op_collectives(cls) -> bool:
return True
@classmethod
def get_default_ir_op_priority(cls, vllm_config: VllmConfig) -> IrOpPriorityConfig:
from vllm.config.compilation import CompilationMode
from vllm.config.kernel import IrOpPriorityConfig
# Native used by default when compiling,
# use vllm_c kernels where available when no codegen
cc = vllm_config.compilation_config
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
default = ["native"] if using_inductor else ["vllm_c", "native"]
# Use oink if enabled for rms_norm
# TODO(Laurawly/luka): remove this env var,
# users can just use IR op priority directly
rms_norm = default
if envs.VLLM_USE_OINK_OPS:
rms_norm = ["oink"] + default
return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
......@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.inputs import EngineInput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
......@@ -931,6 +932,16 @@ class Platform:
"num_compute_units is not implemented for the current platform."
)
@classmethod
def get_default_ir_op_priority(
cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
"""Get the default IR op priority for the current platform."""
from vllm.config.kernel import IrOpPriorityConfig
# Native always used by default. Platforms can override this behavior.
return IrOpPriorityConfig.with_default(["native"])
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
......@@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
logger = init_logger(__name__)
......@@ -903,3 +904,32 @@ class RocmPlatform(Platform):
@classmethod
def use_custom_op_collectives(cls) -> bool:
return True
@classmethod
def get_default_ir_op_priority(
cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
from vllm.config.compilation import CompilationMode
from vllm.config.kernel import IrOpPriorityConfig
# Native used by default when compiling,
# use vllm_c kernels where available when no codegen
# TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm
cc = vllm_config.compilation_config
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
default = ["native"] if using_inductor else ["vllm_c", "native"]
# This (mostly) preserves previous CustomOp behavior
# Necessary on ROCm because it's common that users
# enable rms_norm to use the aiter kernel.
# TODO(luka/TJ) remove env vars completely
if (
cc.is_custom_op_enabled("rms_norm")
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_RMSNORM
):
rms_norm = ["aiter"] + default
else:
rms_norm = default
return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
......@@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
else:
VllmConfig = None
......@@ -257,6 +258,21 @@ class XPUPlatform(Platform):
)
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
@classmethod
def get_default_ir_op_priority(
cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
from vllm.config.compilation import CompilationMode
from vllm.config.kernel import IrOpPriorityConfig
# Native used by default when compiling,
# use fused kernels where available when no codegen
cc = vllm_config.compilation_config
using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
default = ["native"] if using_inductor else ["xpu_kernels", "native"]
return IrOpPriorityConfig.with_default(default)
@classmethod
def device_count(cls) -> int:
return torch.xpu.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment