Unverified commit 40bb1750, authored by Luka Govedič and committed by GitHub
Browse files

[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
parent 0fab52f0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from vllm import ir
from vllm.platforms import current_platform
# NOTE(review): presumably makes the platform's compiled ops available so the
# torch.ops._C.rms_norm call in rms_norm() below resolves - confirm against
# current_platform.import_kernels().
current_platform.import_kernels()
# Passed as `supported=` when the "vllm_c" rms_norm implementation is registered.
CUDA_ALIKE = current_platform.is_cuda_alike()
"""Most kernels in this file are supported on all CUDA-alike platforms."""
# Argument predicate passed as `supports_args=` for the "vllm_c" rms_norm
# implementation. It takes the same parameters as rms_norm() and is true only
# when no variance_size is given.
rms_no_var_size = lambda x, weight, epsilon, variance_size=None: variance_size is None
"""vLLM kernel does not support variance_size parameter."""
@ir.ops.rms_norm.register_impl(
    "vllm_c", supports_args=rms_no_var_size, supported=CUDA_ALIKE
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """Run the ``torch.ops._C.rms_norm`` kernel on ``x`` and return the result.

    Args:
        x: Input tensor; the output takes its shape, device and dtype.
        weight: Weight tensor handed to the kernel. When ``None``, a tensor of
            ones of length ``x.shape[-1]`` (same device and dtype as ``x``) is
            used in its place.
        epsilon: Forwarded to the kernel unchanged.
        variance_size: Must be ``None``; this implementation asserts it.

    Returns:
        A newly allocated tensor that the kernel has written into.
    """
    device, dtype = x.device, x.dtype
    if weight is None:
        # The kernel always takes a weight tensor, so substitute ones.
        weight = torch.ones(x.shape[-1], device=device, dtype=dtype)

    # Registration filters on `rms_no_var_size`; this is a last-line guard.
    assert variance_size is None

    result = torch.empty(x.shape, device=device, dtype=dtype)
    torch.ops._C.rms_norm(result, x, weight, epsilon)
    return result
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from vllm import ir
from vllm.platforms import current_platform
# NOTE(review): presumably makes the platform's compiled ops available so the
# torch.ops._C.rms_norm call later in this file resolves - confirm against
# current_platform.import_kernels().
current_platform.import_kernels()
def is_xpu_kernels_found() -> bool:
    """Report whether a ``vllm_xpu_kernels`` module can be located.

    The lookup goes through ``importlib.util.find_spec``.

    Returns:
        ``True`` when ``find_spec`` yields a spec for ``vllm_xpu_kernels``,
        ``False`` when it yields ``None``.
    """
    import importlib.util

    spec = importlib.util.find_spec("vllm_xpu_kernels")
    return spec is not None
# Computed once when this module is imported; passed as `supported=` when the
# "xpu_kernels" rms_norm implementation is registered.
XPU_KERNELS_SUPPORTED = is_xpu_kernels_found()
"""Kernels in this file are supported if vLLM XPU kernels are installed."""
# Argument predicate passed as `supports_args=` for the "xpu_kernels" rms_norm
# implementation. It takes the same parameters as rms_norm() and is true only
# when no variance_size is given.
rms_no_var = lambda x, weight, epsilon, variance_size=None: variance_size is None
@ir.ops.rms_norm.register_impl(
    "xpu_kernels", supports_args=rms_no_var, supported=XPU_KERNELS_SUPPORTED
)
def rms_norm(
    x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
    """Run the ``torch.ops._C.rms_norm`` kernel on ``x`` and return the result.

    Args:
        x: Input tensor; the output takes its shape, device and dtype.
        weight: Weight tensor handed to the kernel. When ``None``, a tensor of
            ones of length ``x.shape[-1]`` (same device and dtype as ``x``) is
            used in its place.
        epsilon: Forwarded to the kernel unchanged.
        variance_size: Must be ``None``; this implementation asserts it.

    Returns:
        A newly allocated tensor that the kernel has written into.
    """
    placement = x.device
    element_type = x.dtype
    if weight is None:
        # No weight supplied: the kernel still needs one, so pass ones.
        weight = torch.ones(x.shape[-1], device=placement, dtype=element_type)

    # Registration filters on `rms_no_var`; this is a last-line guard.
    assert variance_size is None

    normalized = torch.empty(x.shape, device=placement, dtype=element_type)
    torch.ops._C.rms_norm(normalized, x, weight, epsilon)
    return normalized
...@@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import ( ...@@ -8,6 +8,7 @@ from vllm.logging_utils.access_log_filter import (
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime from vllm.logging_utils.log_time import logtime
from vllm.logging_utils.torch_tensor import tensors_str_no_data
__all__ = [ __all__ = [
"NewLineFormatter", "NewLineFormatter",
...@@ -16,4 +17,5 @@ __all__ = [ ...@@ -16,4 +17,5 @@ __all__ = [
"create_uvicorn_log_config", "create_uvicorn_log_config",
"lazy", "lazy",
"logtime", "logtime",
"tensors_str_no_data",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
def tensors_str_no_data(arg: Any):
from torch._tensor_str import printoptions
with printoptions(threshold=1, edgeitems=0):
return str(arg)
...@@ -6,7 +6,9 @@ import torch ...@@ -6,7 +6,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _oink_ops, envs # Import kernels
import vllm.kernels # noqa: F401
from vllm import _oink_ops, envs, ir
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool: ...@@ -51,23 +53,6 @@ def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
return (x_2d.stride(0) % divby) == 0 return (x_2d.stride(0) % divby) == 0
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from vllm import _custom_ops as ops
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out
def fused_add_rms_norm( def fused_add_rms_norm(
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
...@@ -105,23 +90,16 @@ def poly_norm( ...@@ -105,23 +90,16 @@ def poly_norm(
return out return out
def dispatch_rocm_rmsnorm_func( def dispatch_rocm_rmsnorm_func(dtype: torch.dtype, use_aiter: bool = False):
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [ use_aiter = use_aiter and dtype in [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
] ]
if use_aiter and with_fused_add:
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter: if use_aiter:
return rocm_aiter_ops.rms_norm return rocm_aiter_ops.rms_norm2d_with_add
else:
# fall back to CUDA implementation
if with_fused_add:
return fused_add_rms_norm return fused_add_rms_norm
return rms_norm
# --8<-- [start:rms_norm] # --8<-- [start:rms_norm]
...@@ -158,20 +136,14 @@ class RMSNorm(CustomOp): ...@@ -158,20 +136,14 @@ class RMSNorm(CustomOp):
if current_platform.is_rocm(): if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled() aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
) )
# Optional: enable Oink Blackwell RMSNorm custom-op fast path on # Optional: enable Oink Blackwell RMSNorm custom-op fast path on
# compatible CUDA devices (e.g., SM100) when the external Oink # compatible CUDA devices (e.g., SM100) when the external Oink
# package is available. This is detected once at construction time # package is available. This is detected once at construction time
# to avoid per-call device queries in the hot path. # to avoid per-call device queries in the hot path.
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False self._use_oink_fused_add_rmsnorm = False
if ( if (
not current_platform.is_rocm() not current_platform.is_rocm()
...@@ -203,7 +175,6 @@ class RMSNorm(CustomOp): ...@@ -203,7 +175,6 @@ class RMSNorm(CustomOp):
try: try:
device_index = torch.accelerator.current_device_index() device_index = torch.accelerator.current_device_index()
if _oink_ops.is_oink_available_for_device(device_index): if _oink_ops.is_oink_available_for_device(device_index):
self._use_oink_rmsnorm = True
self._use_oink_fused_add_rmsnorm = ( self._use_oink_fused_add_rmsnorm = (
_oink_ops.has_fused_add_rms_norm() _oink_ops.has_fused_add_rms_norm()
) )
...@@ -215,7 +186,6 @@ class RMSNorm(CustomOp): ...@@ -215,7 +186,6 @@ class RMSNorm(CustomOp):
"RMSNorm; falling back to vLLM RMSNorm. Error: %s", "RMSNorm; falling back to vLLM RMSNorm. Error: %s",
e, e,
) )
self._use_oink_rmsnorm = False
self._use_oink_fused_add_rmsnorm = False self._use_oink_fused_add_rmsnorm = False
@staticmethod @staticmethod
...@@ -270,6 +240,10 @@ class RMSNorm(CustomOp): ...@@ -270,6 +240,10 @@ class RMSNorm(CustomOp):
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if residual is None:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
return self.forward_static( return self.forward_static(
x, x,
...@@ -286,35 +260,14 @@ class RMSNorm(CustomOp): ...@@ -286,35 +260,14 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None and not envs.VLLM_BATCH_INVARIANT:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
# Optional Oink SM100 fast path (no residual). This path is
# torch.compile-friendly via torch.ops.oink.rmsnorm and preserves
# 2D layouts (including padded rows) when using the Oink
# pointer-based kernel.
if (
residual is None
and getattr(self, "_use_oink_rmsnorm", False)
and x.is_cuda
and x.dim() >= 2
and self.has_weight
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
orig_shape = x.shape
hidden_size = orig_shape[-1]
if _can_view_as_2d(x):
x_2d = x.view(-1, hidden_size)
if _is_oink_stride_compatible_2d(x_2d):
y_2d = _oink_ops.rmsnorm(
x_2d,
self.weight.data,
self.variance_epsilon,
)
return y_2d.view(orig_shape)
# Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place). # Optional Oink SM100 fast path (fused residual-add + RMSNorm, in-place).
# This mirrors vLLM's fused_add_rms_norm semantics by mutating both # This mirrors vLLM's fused_add_rms_norm semantics by mutating both
# `x` (normalized output) and `residual` (residual-out buffer). # `x` (normalized output) and `residual` (residual-out buffer).
...@@ -356,29 +309,34 @@ class RMSNorm(CustomOp): ...@@ -356,29 +309,34 @@ class RMSNorm(CustomOp):
) )
return x, residual return x, residual
add_residual = residual is not None if residual is not None:
if add_residual:
return fused_add_rms_norm( return fused_add_rms_norm(
x, residual, self.weight.data, self.variance_epsilon x, residual, self.weight.data, self.variance_epsilon
) )
else: else:
return rms_norm(x, self.weight.data, self.variance_epsilon) assert envs.VLLM_BATCH_INVARIANT
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
def forward_hip( def forward_hip(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None and not envs.VLLM_BATCH_INVARIANT:
return ir.ops.rms_norm(
x, self.weight.data, self.variance_epsilon, self.variance_size_override
)
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
add_residual = residual is not None if residual is not None:
if add_residual:
return self.rocm_norm_func_with_add( return self.rocm_norm_func_with_add(
x, residual, self.weight.data, self.variance_epsilon x, residual, self.weight.data, self.variance_epsilon
) )
else: else:
return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon) assert envs.VLLM_BATCH_INVARIANT
return rms_norm_batch_invariant(x, self.weight.data, self.variance_epsilon)
def forward_xpu( def forward_xpu(
self, self,
......
...@@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum ...@@ -30,6 +30,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
else: else:
VllmConfig = None VllmConfig = None
...@@ -550,6 +551,26 @@ class CudaPlatformBase(Platform): ...@@ -550,6 +551,26 @@ class CudaPlatformBase(Platform):
def use_custom_op_collectives(cls) -> bool: def use_custom_op_collectives(cls) -> bool:
return True return True
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Build the default IR op priority configuration for CUDA platforms.

    The general priority list is ``["native"]`` when the compilation config
    uses the ``"inductor"`` backend with a mode other than
    ``CompilationMode.NONE``, and ``["vllm_c", "native"]`` otherwise. For
    ``rms_norm``, ``"oink"`` is placed ahead of that list when
    ``envs.VLLM_USE_OINK_OPS`` is truthy.

    Args:
        vllm_config: Config whose ``compilation_config`` (``backend`` and
            ``mode``) is inspected.

    Returns:
        The result of ``IrOpPriorityConfig.with_default`` for those lists.
    """
    # Annotations are quoted, as in the other platforms' overrides, so this
    # signature does not need the names to be bound at runtime.
    from vllm.config.compilation import CompilationMode
    from vllm.config.kernel import IrOpPriorityConfig

    # Native used by default when compiling,
    # use vllm_c kernels where available when no codegen
    cc = vllm_config.compilation_config
    using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE
    default = ["native"] if using_inductor else ["vllm_c", "native"]

    # Use oink if enabled for rms_norm
    # TODO(Laurawly/luka): remove this env var,
    # users can just use IR op priority directly
    rms_norm = default
    if envs.VLLM_USE_OINK_OPS:
        rms_norm = ["oink"] + default

    return IrOpPriorityConfig.with_default(default, rms_norm=rms_norm)
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
...@@ -17,6 +17,7 @@ if TYPE_CHECKING: ...@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.inputs import EngineInput from vllm.inputs import EngineInput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -931,6 +932,16 @@ class Platform: ...@@ -931,6 +932,16 @@ class Platform:
"num_compute_units is not implemented for the current platform." "num_compute_units is not implemented for the current platform."
) )
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Return the platform's default IR op priority configuration.

    This base version ignores ``vllm_config`` and always prioritizes only
    ``"native"``; platform subclasses can override it.
    """
    from vllm.config.kernel import IrOpPriorityConfig

    native_only = ["native"]
    return IrOpPriorityConfig.with_default(native_only)
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
...@@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum ...@@ -19,6 +19,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -903,3 +904,32 @@ class RocmPlatform(Platform): ...@@ -903,3 +904,32 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def use_custom_op_collectives(cls) -> bool: def use_custom_op_collectives(cls) -> bool:
return True return True
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Build the default IR op priority configuration for ROCm.

    The general priority list is ``["native"]`` when the compilation config
    uses the ``"inductor"`` backend with a mode other than
    ``CompilationMode.NONE``, and ``["vllm_c", "native"]`` otherwise. For
    ``rms_norm``, ``"aiter"`` is placed ahead of that list when the
    ``rms_norm`` custom op is enabled and both ``envs.VLLM_ROCM_USE_AITER``
    and ``envs.VLLM_ROCM_USE_AITER_RMSNORM`` are truthy.
    """
    from vllm.config.compilation import CompilationMode
    from vllm.config.kernel import IrOpPriorityConfig

    # TODO(luka/TJ) use aiter, vllm_c, native by default on ROCm
    cfg = vllm_config.compilation_config
    if cfg.backend == "inductor" and cfg.mode != CompilationMode.NONE:
        # Compiling: stay native by default.
        fallback = ["native"]
    else:
        # No codegen: use vllm_c kernels where available.
        fallback = ["vllm_c", "native"]

    # This (mostly) preserves previous CustomOp behavior; needed on ROCm
    # because users commonly enable rms_norm to get the aiter kernel.
    # TODO(luka/TJ) remove env vars completely
    wants_aiter = (
        cfg.is_custom_op_enabled("rms_norm")
        and envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_RMSNORM
    )
    rms_norm_priority = ["aiter", *fallback] if wants_aiter else fallback

    return IrOpPriorityConfig.with_default(fallback, rms_norm=rms_norm_priority)
...@@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum ...@@ -21,6 +21,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kernel import IrOpPriorityConfig
from vllm.v1.attention.selector import AttentionSelectorConfig from vllm.v1.attention.selector import AttentionSelectorConfig
else: else:
VllmConfig = None VllmConfig = None
...@@ -257,6 +258,21 @@ class XPUPlatform(Platform): ...@@ -257,6 +258,21 @@ class XPUPlatform(Platform):
) )
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
@classmethod
def get_default_ir_op_priority(
    cls, vllm_config: "VllmConfig"
) -> "IrOpPriorityConfig":
    """Build the default IR op priority configuration for XPU.

    The priority list is ``["native"]`` when the compilation config uses the
    ``"inductor"`` backend with a mode other than ``CompilationMode.NONE``,
    and ``["xpu_kernels", "native"]`` otherwise.
    """
    from vllm.config.compilation import CompilationMode
    from vllm.config.kernel import IrOpPriorityConfig

    cfg = vllm_config.compilation_config
    if cfg.backend == "inductor" and cfg.mode != CompilationMode.NONE:
        # Compiling: stay native by default.
        priority = ["native"]
    else:
        # No codegen: use fused kernels where available.
        priority = ["xpu_kernels", "native"]

    return IrOpPriorityConfig.with_default(priority)
@classmethod @classmethod
def device_count(cls) -> int: def device_count(cls) -> int:
return torch.xpu.device_count() return torch.xpu.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment