Unverified Commit 40bb1750 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)


Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: default avatarchzhang <chaojun.zhang@intel.com>
Signed-off-by: default avatarLuka Govedic <luka.govedic@gmail.com>
Co-authored-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: default avatarChaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
parent 0fab52f0
......@@ -9,6 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.ir.ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......@@ -30,7 +31,6 @@ from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
)
logger = init_logger(__name__)
......@@ -54,7 +54,6 @@ def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
......@@ -131,11 +130,9 @@ class RMSNormQuantPattern:
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
if key.fused_add:
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(
key.quant,
has_col_major_scales=has_col_major_scales,
......@@ -161,16 +158,12 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
result_rms = self.rmsnorm_matcher(input, weight)
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
return self.quant_matcher(result_rms, scale)[0]
def replacement(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype
)
......@@ -187,8 +180,8 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1]
inputs = [
# input, weight
*self.rmsnorm_matcher.inputs(),
empty_bf16(5, 16), # input
empty_bf16(16), # weight
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
......@@ -391,7 +384,7 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
def pattern(
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
result = torch.empty(
result_rms.shape,
device=result_rms.device,
......@@ -442,12 +435,14 @@ class RMSNormGroupQuantPattern(RMSNormQuantPattern):
# result, scale
return at[1], at[2]
scale = self.quant_matcher.empty_f32(1, 1)
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs() + [scale],
[
empty_bf16(5, 16), # input
empty_bf16(16), # weight
self.quant_matcher.empty_f32(1, 1), # scale
],
pm.fwd_only,
pm_pass,
)
......@@ -472,7 +467,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result_rms = vllm.ir.ops.rms_norm(input, weight, self.epsilon)
# result, scale
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
......@@ -502,7 +497,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
[
empty_bf16(5, 16), # input
empty_bf16(16), # weight
],
pm.fwd_only,
pm_pass,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
......@@ -24,7 +25,6 @@ from .act_quant_fusion import ActivationQuantPattern
from .matcher_utils import (
MatcherFusedAddRMSNorm,
MatcherQuantFP8,
MatcherRMSNorm,
MatcherSiluAndMul,
)
from .rms_quant_fusion import (
......@@ -41,17 +41,23 @@ class AiterRMSNormQuantPattern:
):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
self.device = torch.device("cuda")
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
)
if key.fused_add:
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(
epsilon, match_rocm_aiter=True
)
self.quant_matcher = MatcherQuantFP8(
key.quant,
match_rocm_aiter=match_aiter_quant,
)
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
"""AITER RMSNorm + Dynamic Quantization pattern."""
......@@ -79,7 +85,7 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
result, scale = self.quant_matcher(result_rms)
return result, scale
......@@ -99,7 +105,8 @@ class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
pm.register_replacement(
pattern,
replacement,
self.rmsnorm_matcher.inputs(),
# input, weight
[self.empty(5, 16), self.empty(16)],
pm.fwd_only,
pm_pass,
)
......@@ -188,7 +195,7 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
input: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_rms = self.rmsnorm_matcher(input, weight)
result_rms = torch.ops.vllm_ir.rms_norm(input, weight, self.epsilon)
result, scale = self.quant_matcher(result_rms)
return result, scale
......@@ -206,7 +213,12 @@ class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
return at[0], at[1]
pm.register_replacement(
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
pattern,
replacement,
# input, weight
[self.empty(5, 16), self.empty(16)],
pm.fwd_only,
pm_pass,
)
......
......@@ -10,6 +10,7 @@ import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
......@@ -22,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
logger = init_logger(__name__)
......@@ -122,35 +123,38 @@ class _SequenceParallelPatternHelper:
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
)
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, arg3_1]
# input, weight
return [self.empty([1, 8, 4]), self.empty([4])]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
arg3_1: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
rmsnorm = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
return rmsnorm, all_reduce
def replacement(
input: torch.Tensor,
arg3_1: torch.Tensor,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
rmsnorm = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
all_gather = self._all_gather(rmsnorm)
return all_gather, reduce_scatter
......@@ -222,14 +226,11 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
device: str | None,
) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, weight, scale]
# input, weight, scale
return [self.empty([1, 8, 4]), self.empty([4]), self.empty_f32([1, 1])]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
......@@ -238,7 +239,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = self._all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
......@@ -248,7 +249,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
reduce_scatter = self._reduce_scatter(input)
rms = self.rmsnorm_matcher(reduce_scatter, weight)
rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale)
all_gather = self._all_gather(quant)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from torch import fx
from torch._inductor.pattern_matcher import (
CallFunctionVarArgs,
Match,
PatternMatcherPass,
register_graph_pattern,
)
from torch._ops import OpOverload, OpOverloadPacket
from vllm.config import VllmConfig
from vllm.ir.op import IrOp
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from ..vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def get_default_overload(op: OpOverload | OpOverloadPacket) -> OpOverload:
if isinstance(op, OpOverloadPacket):
return op.default
assert isinstance(op, OpOverload), "Expected an OpOverload or OpOverloadPacket"
return op
def get_ir_op(node: fx.Node) -> IrOp | None:
if node.op != "call_function":
return None
if not isinstance(node.target, (OpOverload, OpOverloadPacket)):
return None
op_overload = get_default_overload(node.target)
if op_overload.namespace != "vllm_ir":
return None
op_name = op_overload._opname
if op_name not in IrOp.registry:
logger.warning(
"Unknown vLLM IR op %s, there's likely an issue with torch registration, "
"or a torch custom op was registered in the vllm_ir namespace by mistake.",
op_name,
)
return None
ir_op = IrOp.registry[op_name]
return ir_op
class VllmIRLoweringPass(VllmInductorPass):
"""
This pass lowers vLLM IR ops to their implementations the priority list.
"""
def __init__(self, vllm_config: VllmConfig) -> None:
super().__init__(vllm_config)
self.patterns = PatternMatcherPass(self.pass_name)
self.selected_impls: dict[str, dict[str, str]] = defaultdict(lambda: {})
self.ops = [ir_op.torch_op for ir_op in IrOp.registry.values()]
# Look for any call_function node where the target is a vLLM IR op.
# Then, lower_matched_op will select, trace, and insert the implementation.
register_graph_pattern(
CallFunctionVarArgs(self.ops),
pass_dict=self.patterns,
)(self.lower_matched_op)
def lower_matched_op(self, match: Match, *args, **kwargs):
# TODO(luka) I think args and kwargs are for the match, but just use the node?
assert len(match.nodes) == 1, "Expected single node match"
node = match.nodes[0]
ir_op = get_ir_op(node)
assert ir_op is not None, "Expected vLLM IR op"
assert not node.kwargs # I think there should never be kwargs here
# Select and record the implementation, using fake args
fake_args = fx.map_arg(node.args, lambda arg: arg.meta["val"])
ir_op_impl = ir_op.dispatch(*fake_args)
self.selected_impls[ir_op.name][node.name] = ir_op_impl.provider
# replace_by_example wants node args, not the fake tensors
# TODO(luka): Use aot_export_module to get functionalized graph
# TODO(luka): Cache the fx_replacement to avoid re-tracing the same impl
# Defaults not present on node.args but required for replacement tracing
bound_args = ir_op._py_signature.bind(*node.args)
bound_args.apply_defaults()
match.replace_by_example(ir_op_impl.impl_fn, bound_args.args)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
# clear at the beginning instead of end, so that tests can inspect
self.selected_impls.clear()
count = self.patterns.apply(graph)
logger.debug("VllmIRLoweringPass lowered %d vLLM IR nodes", count)
# TODO write self.selected_impls to depyf/tlparse dir
def count_items(impls: Iterable[str]) -> dict[str, int]:
counts: dict[str, int] = defaultdict(lambda: 0)
for impl in impls:
counts[impl] += 1
return counts
def print_count(counts: dict[str, int]) -> str:
# e.g., "impl1*3,impl2"
impl_count = lambda i, c: f"{i}" if c == 1 else f"{i}*{c}"
return ",".join(impl_count(impl, count) for impl, count in counts.items())
logger.debug(
"Selected implementations: %s",
lazy(
lambda: ", ".join(
f"{op}={print_count(count_items(impls_by_node.values()))}"
for op, impls_by_node in self.selected_impls.items()
)
),
)
failed_nodes: list[fx.Node] = []
failed_ops: set[str] = set()
# Check no vllm_ir nodes were left in the graph
for node in graph.nodes:
if (ir_op := get_ir_op(node)) is None:
continue
failed_nodes.append(node)
failed_ops.add(ir_op.name)
if failed_nodes or failed_ops:
logger.warning("Failed to lower vLLM IR ops: %s", ",".join(failed_ops))
logger.warning("Full node list: %s", failed_nodes)
def uuid(self) -> str:
"""
IR op priority & impl sources affect lowering pass output,
so we include them in the cache key.
"""
priorities = {name: op.get_priority() for name, op in IrOp.registry.items()}
priorities_str = ";".join(
f"{name}={','.join(p)}" for name, p in priorities.items()
)
impl_uuids_str = ";".join(
f"{name}={
','.join(IrOp.registry[name].impls[provider].uuid() for provider in p)
}"
for name, p in priorities.items()
)
return f"{super().uuid()}|{priorities_str}|{impl_uuids_str}"
......@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
from .ir.lowering_pass import VllmIRLoweringPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
if rocm_aiter_ops.is_enabled():
......@@ -99,8 +100,17 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
else:
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
# perform the first post-cleanup before IR lowering to clean up fusion artifacts
# and make sure no dead IR ops are lowered.
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# lowering before cleanup so DCE can clean up lowered ops.
# DCE handles mutating ops correctly as well.
self.ir_lowering(graph)
VllmInductorPass.dump_prefix += 1
# clean up after lowering again
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
......@@ -152,7 +162,7 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
self.passes += [SplitCoalescingPass(config)]
self.passes += [QKNormRoPEFusionPass(config)]
# needs a functional graph
self.ir_lowering = VllmIRLoweringPass(config)
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
......@@ -171,6 +181,10 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
for pass_ in self.passes:
passes.append(pass_.uuid())
passes.append(self.post_cleanup.uuid())
passes.append(self.ir_lowering.uuid())
passes.append(self.post_cleanup.uuid())
passes.append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
......
......@@ -152,6 +152,7 @@ class VllmPatternMatcherPass(VllmInductorPass):
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
"vllm_ir = torch.ops.vllm_ir",
file=f,
)
......
......@@ -466,6 +466,15 @@ class CompilationConfig:
disabled when running with Inductor: mode>CompilationMode.NONE and
backend="inductor".
Inductor generates (fused) Triton kernels for disabled custom ops."""
ir_enable_torch_wrap: bool = None # type: ignore[assignment]
"""If True, enable vllm_ir torch custom op wrapping during the forward pass.
When False, torch custom op wrapping is disabled, allowing Dynamo to trace the
selected implementation directly or avoiding torch custom op overhead in eager mode.
Defaults to True when using Inductor with vllm-compile
(backend=="inductor" and mode == VLLM_COMPILE), False otherwise.
"""
splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
......@@ -830,6 +839,7 @@ class CompilationConfig:
"cudagraph_mode",
"max_cudagraph_capture_size",
"use_inductor_graph_partition",
"ir_enable_torch_wrap",
mode="wrap",
)
@classmethod
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Callable
from typing import Any, Literal
from dataclasses import asdict, fields
from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, field_validator
from vllm.config.utils import config, get_hash_factors, hash_factors
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@config
class IrOpPriorityConfig:
"""
Configuration for vLLM IR op priority for dispatching/lowering during the
forward pass. Each member is a list of strings, which will be passed to
vllm.ir.ops.<op_name>.set_priority() for the duration of the forward pass.
A single comma-separated string is accepted as well,
If specified manually, platform defaults will be appended to the lists.
See KernelConfig.set_platform_defaults().
"""
rms_norm: list[str] = Field(default_factory=list)
"""Priority list for vllm.ir.ops.rms_norm"""
def compute_hash(self) -> str:
"""
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded.
Also, manually add IR op impl UUIDs to make sure they affect the compile cache.
"""
factors = get_hash_factors(self, set())
# Implementations are hidden from Dynamo,
# so they don't show up in the traced files list.
from vllm.ir.op import IrOp
assert "_impls" not in factors
factors["_impls"] = {
name: {
provider: IrOp.registry[name].impls[provider].uuid() for provider in p
}
for name, p in asdict(self).items()
}
return hash_factors(factors)
from pydantic import field_validator
@field_validator("*", mode="before")
@classmethod
def _to_list_str(cls, value: str | list[str]):
if isinstance(value, str):
value = value.replace(" ", "").split(",")
assert all(isinstance(v, str) for v in value)
return value
@contextlib.contextmanager
def set_priority(self):
"""
Context manager to set the IR op priority for all op members.
It also imports vllm.kernels to ensure all implementations are made available.
"""
import vllm.kernels # noqa: F401, registers IR op implementations
from vllm.ir.op import IrOp
with contextlib.ExitStack() as stack:
for field in fields(self):
op_priority = getattr(self, field.name)
assert op_priority is not None, (
f"IR op priority for {field.name} must be set"
)
logger.debug(
"Setting IR op priority for %s to %s", field.name, op_priority
)
ir_op = IrOp.registry[field.name]
stack.enter_context(ir_op.set_priority(op_priority))
yield
@classmethod
def with_default(
cls, default: list[str], /, **kwargs: list[str]
) -> "IrOpPriorityConfig":
"""
A helper to create an IrOpPriorityConfig where fields not specified in kwargs
use the given default list.
"""
for field in fields(cls):
if field.name not in kwargs:
kwargs[field.name] = list(default)
return cls(**kwargs)
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
MoEBackend = Literal[
"auto",
......@@ -26,6 +119,12 @@ MoEBackend = Literal[
class KernelConfig:
"""Configuration for kernel selection and warmup behavior."""
ir_op_priority: IrOpPriorityConfig = Field(default_factory=IrOpPriorityConfig)
"""
vLLM IR op priority for dispatching/lowering during the forward pass.
Platform defaults appended automatically during VllmConfig.__post_init__.
"""
enable_flashinfer_autotune: bool = None # type: ignore[assignment]
"""If True, run FlashInfer autotuning during kernel warmup."""
......@@ -51,21 +150,17 @@ class KernelConfig:
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
Produces a hash unique to the pass configuration.
Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
ignored_factors = {
"enable_flashinfer_autotune",
"ir_op_priority", # handled separately below
}
factors = get_hash_factors(self, ignored_factors)
factors["ir_op_priority"] = self.ir_op_priority.compute_hash()
return hash_factors(factors)
@field_validator("enable_flashinfer_autotune", mode="wrap")
@classmethod
......@@ -74,3 +169,31 @@ class KernelConfig:
if value is None:
return value
return handler(value)
def set_platform_defaults(self, vllm_config: "VllmConfig") -> None:
"""Set platform-specific defaults for the kernel config."""
from vllm.platforms import current_platform
platform_op_priority = current_platform.get_default_ir_op_priority(vllm_config)
logger.debug(
"Setting platform-specific IR op priority defaults: %s, user-defined: %s",
platform_op_priority,
self.ir_op_priority,
)
for op_name, op_priority in asdict(platform_op_priority).items():
current_op_priority: list[str] = getattr(self.ir_op_priority, op_name)
if current_op_priority is None:
setattr(self.ir_op_priority, op_name, op_priority)
else:
# Append platform-specific priorities
# Must be idempotent because vllm_config.set_platform_defaults() may be
# called multiple times (due to VllmConfig.__post_init__ manual call).
unique_op_priority = [
op for op in op_priority if op not in current_op_priority
]
current_op_priority.extend(unique_op_priority)
logger.info(
"Final IR op priority after setting platform defaults: %s",
self.ir_op_priority,
)
......@@ -95,9 +95,11 @@ def enable_norm_fusion(cfg: "VllmConfig") -> bool:
"""Enable if either RMS norm or quant FP8 custom op is active;
otherwise Inductor handles fusion."""
return cfg.compilation_config.is_custom_op_enabled(
"rms_norm"
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
return (
cfg.compilation_config.is_custom_op_enabled("rms_norm")
or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
or cfg.kernel_config.ir_op_priority.rms_norm[0] != "native"
)
def enable_act_fusion(cfg: "VllmConfig") -> bool:
......@@ -417,6 +419,10 @@ class VllmConfig:
vllm_factors.append(self.compilation_config.compute_hash())
else:
vllm_factors.append("None")
if self.kernel_config:
vllm_factors.append(self.kernel_config.compute_hash())
else:
vllm_factors.append(None)
if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
......@@ -890,6 +896,13 @@ class VllmConfig:
else:
self.compilation_config.mode = CompilationMode.NONE
# By default, enable torch wrapping only when using custom Inductor lowering
if self.compilation_config.ir_enable_torch_wrap is None:
self.compilation_config.ir_enable_torch_wrap = (
self.compilation_config.mode == CompilationMode.VLLM_COMPILE
and self.compilation_config.backend == "inductor"
)
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
......@@ -899,6 +912,11 @@ class VllmConfig:
else:
self.compilation_config.custom_ops.append("all")
# This populates IR op priorities,
# must happen after compilation mode and backend are decided,
# but before fusion defaults are applied as those may depend on op priority.
self.kernel_config.set_platform_defaults(self)
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)
if self.kernel_config.enable_flashinfer_autotune is None:
......@@ -1706,7 +1724,8 @@ class VllmConfig:
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}"
f"compilation_config={self.compilation_config!r}, "
f"kernel_config={self.kernel_config!r}"
)
def validate_block_size(self) -> None:
......
......@@ -8,7 +8,7 @@ import functools
import json
import sys
from collections.abc import Callable
from dataclasses import MISSING, dataclass, fields, is_dataclass
from dataclasses import MISSING, asdict, dataclass, fields, is_dataclass
from itertools import permutations
from types import UnionType
from typing import (
......@@ -70,7 +70,7 @@ from vllm.config.cache import (
PrefixCachingHashAlgo,
)
from vllm.config.device import Device
from vllm.config.kernel import MoEBackend
from vllm.config.kernel import IrOpPriorityConfig, MoEBackend
from vllm.config.lora import MaxLoRARanks
from vllm.config.model import (
ConvertOption,
......@@ -401,6 +401,7 @@ class EngineArgs:
max_cudagraph_capture_size: int | None = get_field(
CompilationConfig, "max_cudagraph_capture_size"
)
ir_op_priority: IrOpPriorityConfig = get_field(KernelConfig, "ir_op_priority")
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
......@@ -657,6 +658,9 @@ class EngineArgs:
self.weight_transfer_config = WeightTransferConfig(
**self.weight_transfer_config
)
if isinstance(self.ir_op_priority, dict):
self.ir_op_priority = IrOpPriorityConfig(**self.ir_op_priority)
# Setup plugins
from vllm.plugins import load_general_plugins
......@@ -1293,6 +1297,7 @@ class EngineArgs:
title="KernelConfig",
description=KernelConfig.__doc__,
)
kernel_group.add_argument("--ir-op-priority", **kernel_kwargs["ir_op_priority"])
kernel_group.add_argument(
"--enable-flashinfer-autotune",
**kernel_kwargs["enable_flashinfer_autotune"],
......@@ -1917,6 +1922,22 @@ class EngineArgs:
if self.moe_backend != "auto":
kernel_config.moe_backend = self.moe_backend
# Transfer top-level ir_op_priority into KernelConfig.ir_op_priority
for op_name, op_priority in asdict(self.ir_op_priority).items():
# Empty means unset
if not op_priority:
continue
# Priority cannot be set 2x for the same op
if getattr(kernel_config.ir_op_priority, op_name):
raise ValueError(
f"Op priority for {op_name} specified via both ir_op_priority "
f"and KernelConfig.ir_op_priority, only one allowed at a time."
)
# Set the attribute
setattr(kernel_config.ir_op_priority, op_name, op_priority)
load_config = self.create_load_config()
# Pass reasoning_parser into StructuredOutputsConfig
......
......@@ -10,6 +10,7 @@ from typing import Any
import torch
import vllm.envs as envs
import vllm.ir
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -378,7 +379,13 @@ def set_forward_context(
)
try:
with override_forward_context(forward_context):
with (
override_forward_context(forward_context),
vllm_config.kernel_config.ir_op_priority.set_priority(),
vllm.ir.enable_torch_wrap(
vllm_config.compilation_config.ir_enable_torch_wrap
),
):
yield
finally:
global last_logging_time, batchsize_logging_interval
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from . import ops
from .op import enable_torch_wrap, register_op
__all__ = ["enable_torch_wrap", "register_op", "ops"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import inspect
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, overload
import torch
from torch.library import Library, infer_schema
from vllm.ir.util import hash_source, weak_cache
from vllm.logger import init_logger
from vllm.logging_utils import lazy, tensors_str_no_data
vllm_ir_lib = Library("vllm_ir", "FRAGMENT")
logger = init_logger(__name__)
RESERVED_PROVIDERS = ["native", "unfused"]
"""Providers that are reserved and cannot be used for custom implementations."""
_ENABLE_TORCH_WRAP: bool = True
"""Global override flag to control torch op layer wrapping."""
@contextlib.contextmanager
def enable_torch_wrap(enable: bool = True):
"""
Context manager to enable/disable torch custom op wrapping for vLLM IR ops.
When torch wrapping is disabled, the torch custom op layer is skipped
and IR ops dispatch directly to the implementation.
Helpful for avoiding torch dispatch overhead in eager mode
and avoiding the need for lowering for platforms not using Inductor.
"""
global _ENABLE_TORCH_WRAP
old = _ENABLE_TORCH_WRAP
try:
_ENABLE_TORCH_WRAP = enable
yield
finally:
_ENABLE_TORCH_WRAP = old
# 0-param decorator overload
@overload
def register_op(f: Callable[..., Any]) -> "IrOp": ...
# parametrized decorator overload
@overload
def register_op(
*,
name: str | None = None,
) -> Callable[[Callable[..., Any]], "IrOp"]: ...
def register_op(
f: Callable | None = None,
*,
name: str | None = None,
) -> "IrOp | Callable[[Callable], IrOp]":
"""
Register a new vLLM IR op.
:param f: the native implementation of the op
:param name: the name of the op, defaults to the function name
:return: the IrOp object if f is provided, otherwise a decorator
Example usage:
```python
@vllm.ir.register_op
def my_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
@vllm.ir.register_op(name="custom_mul")
def multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x * y"""
def decorator(_f: Callable):
op_name: str = _f.__name__ if name is None else name
assert op_name not in IrOp.registry
op = IrOp(op_name, _f)
IrOp.registry[op_name] = op
return op
if f is not None:
return decorator(f)
return decorator
class IrOp:
registry: ClassVar[dict[str, "IrOp"]] = {}
name: str
impls: dict[str, "IrOpImpl"]
def __init__(self, name: str, native_impl: Callable):
self._py_signature = inspect.signature(native_impl)
if any(
p.kind == inspect.Parameter.KEYWORD_ONLY
for p in self._py_signature.parameters.values()
):
raise ValueError(
f"Op {name} has keyword-only arguments which are not currently "
f"supported. That's because kwargs are not allowed during lowering."
)
self.name = name
self.impls: dict[str, IrOpImpl] = {}
self._priority_impls: list[IrOpImpl] = []
self._schema_str = infer_schema(native_impl, mutates_args=[])
# native implementation
self.impls["native"] = IrOpImpl(
self, "native", native_impl, supported=True, supports_args=None
)
# By default, fake routes directly to native,
# can be overridden by register_fake
self._fake_fn = native_impl
# torch registration
vllm_ir_lib.define(self.name + self._schema_str)
# CompositeExplicitAutograd is not decomposed
# by ATen IR normalization in AOTAutograd
vllm_ir_lib.impl(
self.name, self._inner_call, dispatch_key="CompositeExplicitAutograd"
)
vllm_ir_lib._register_fake(self.name, self._fake_call)
assert hasattr(torch.ops.vllm_ir, name)
self.torch_op: torch._ops.OpOverload = getattr(torch.ops.vllm_ir, name).default
def register_fake(self, fn: Callable) -> Callable:
"""
Register a fake impl for the torch custom op. If this method is not called,
the native implementation is used directly for the fake implementation.
"""
self._fake_fn = fn
return fn
def _fake_call(self, *args, **kwargs) -> Any:
"""
Call to the fake implementation of the op. We use indirection because we want
users to be able to register fake later but also want it to fall back to native
directly by default, instead of going through the dispatching mechanism.
"""
return self._fake_fn(*args, **kwargs)
def register_impl(
self,
provider: str,
*,
supported: bool = True,
supports_args: Callable[..., bool] | None = None,
):
"""
Register an implementation for this custom op.
:param provider: The name of the provider, must be unique.
:param supported: Static support check, use this to check platform support.
:param supports_args: Dynamic arg support check, used for types and shapes.
:return: A decorator that registers the implementation.
The decorated function must have the same semantics and signature as
the native implementation.
The provider name must be unique and not one of the RESERVED_PROVIDERS.
The supported and supports_args parameters should not be used to implement
custom enablement logic based on global state (e.g. environment variables).
Instead, supported param should only be used to check for platform support
(e.g. whether a specific hardware or library is available).
supports_args should be used to check whether the provided arguments are
compatible with the implementation.
For custom enablement logic, set op impl priority.
Example:
```python
@my_op.register_impl("my_provider", supported=torch.cuda.is_available())
def my_provider_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
```
"""
assert provider not in RESERVED_PROVIDERS, (
f"Provider name {provider} is reserved."
)
def _register_impl(f: Callable):
impl = IrOpImpl(self, provider, f, supported, supports_args)
self.impls[provider] = impl
if self.get_priority():
logger.warning(
"Warning: registering new impl %s for op %s while priority is set.",
provider,
self.name,
)
return impl
return _register_impl
def _inner_call(self, *args, **kwargs) -> Any:
"""
Eager call to torch op lands here. When torch wrapping is disabled,
__call__ routes straight here instead of going through torch op dispatching.
"""
impl = self.dispatch(*args, **kwargs)
return impl.impl_fn(*args, **kwargs)
def apply_arg_defaults(self, args) -> tuple:
"""
Return args with default values applied.
Defaults are taken from the native implementation signature.
SHOULD NOT BE USED IN THE DISPATCH PATH (SLOW).
Only for Inductor lowering.
"""
bound_args = self._py_signature.bind(*args)
bound_args.apply_defaults()
return bound_args.args
def dispatch(self, *args, **kwargs) -> "IrOpImpl":
"""
Dispatch to the appropriate implementation based on current priority
and argument support checks. Returns the selected IrOpImpl.
THIS FUNCTION IS ON THE HOT PATH (OP DISPATCH), MUST BE FAST.
"""
if not self._priority_impls:
if not torch.compiler.is_compiling():
# Logging not compatible with Dynamo tracing
# (this code is exposed when torch wrapping is disabled)
logger.warning_once(
"Priority not set for op %s, using native implementation.",
self.name,
)
return self.impls["native"]
for impl in self._priority_impls:
if not impl.supported:
raise ValueError(
f"Implementation {impl.provider} for op {self.name} not supported. "
f"All implementations in priority list must be supported."
)
if impl.supports_args(*args, **kwargs):
return impl
if not torch.compiler.is_compiling():
logger.debug(
"Skipping provider %s because it does not support "
"%s with args=%s kwargs=%s",
impl.provider,
self.name,
lazy(lambda: tensors_str_no_data(args)),
lazy(lambda: tensors_str_no_data(kwargs)),
)
raise RuntimeError(
"Priority set incorrectly: the last implementation must "
"support all args (can be native). This is likely an internal bug"
)
def __call__(self, *args, **kwargs) -> Any:
if not _ENABLE_TORCH_WRAP:
return self._inner_call(*args, **kwargs)
return self.torch_op(*args, **kwargs)
def get_priority(self) -> list[str]:
"""Get the current dispatch priority for implementations for this op."""
return [p.provider for p in self._priority_impls]
@contextlib.contextmanager
def set_priority(self, priority: list[str]):
"""
Context manager to set the dispatch priority for implementations for this op.
"""
assert all(p in self.impls for p in priority), (
"All providers in priority must be registered implementations."
)
def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]:
filtered_impls = []
for p in p_list:
impl = self.impls[p]
if not impl.supported:
# Skip unsupported implementations
continue
filtered_impls.append(impl)
# If all args are supported, skip other implementations
if impl.supports_all_args:
return filtered_impls
logger.warning_once(
"Op %s: No implementation in priority list supports all args, "
"execution fallback to native is possible. To silence this warning, "
"explicitly add 'native' to the end of the priority list",
self.name,
)
filtered_impls.append(self.impls["native"])
return filtered_impls
# Temporarily set priority
old_priority_impls = self._priority_impls
try:
self._priority_impls = filter_priority_impls(priority)
yield
finally:
self._priority_impls = old_priority_impls
def supported_providers(self) -> list[str]:
return [p.provider for p in self.impls.values() if p.supported]
class IrOpImpl:
def __init__(
self,
op: IrOp,
provider: str,
impl_fn: Callable,
supported: bool,
supports_args: Callable[..., bool] | None,
):
assert provider not in op.impls, (
f"Implementation for provider {provider} already registered."
)
# Native also uses this path, so we allow it here.
assert provider == "native" or provider not in RESERVED_PROVIDERS
# Enforce the exact same schema as the native implementation.
# This takes care of names, types, and defaults.
schema = infer_schema(impl_fn, mutates_args=[])
if schema != op._schema_str:
raise ValueError(
f"Implementation for provider {provider} has schema '{schema}' which "
f"does not match native schema '{op._schema_str}' for op {op.name}."
)
if supports_args is not None:
if not callable(supports_args):
raise ValueError(
f"supports_args for provider {provider} must be a callable"
)
# We also manually validate the supports_args signature.
# Matching signatures allow faster dispatch on the hotpath.
# Check that supports_args does not have keyword-only parameters
supports_args_signature = inspect.signature(supports_args)
params = supports_args_signature.parameters
if any(p.kind == inspect.Parameter.KEYWORD_ONLY for p in params.values()):
raise ValueError(
f"supports_args for provider {provider} "
f"cannot have keyword-only parameters"
)
# Check that supports_args has the same total number of parameters
op_params = op._py_signature.parameters
if len(params) != len(op_params):
raise ValueError(
f"supports_args for provider {provider} must have the same number "
f"of parameters ({len(params)}) as the native implementation "
f"({len(op_params)})"
)
# Check that names and defaults match for supports_args
for p, op_p in zip(params.values(), op_params.values()):
if p.name != op_p.name:
raise ValueError(
f"supports_args for provider {provider} has parameter "
f"'{p.name}' which does not match native parameter "
f"'{op_p.name}'"
)
if p.default != op_p.default:
raise ValueError(
f"supports_args for provider {provider} has parameter "
f"'{p.name}' with default {p.default} which does not match "
f"native default {op_p.default}'"
)
self.op = op
self.provider = provider
self.impl_fn = impl_fn
self.supported = supported
self._supports_args = supports_args
@property
def supports_all_args(self) -> bool:
"""Check if this implementation supports all args unconditionally."""
return self._supports_args is None
def supports_args(self, *args, **kwargs) -> bool:
if self._supports_args is None:
return True
return self._supports_args(*args, **kwargs)
@weak_cache
def uuid(self):
"""
Compile-time hash to uniquely determine whether the implementation has changed.
Used by vllm-compile hash mechanism and torch.compile lowering pass uuid to
control the vLLM compile cache and AOTAutograd/Inductor caches respectively.
Source file contents do not change so we cache uuid.
TODO(luka): Cache the file hash as multiple impls are likely in the same file.
"""
sources = [Path(inspect.getfile(self.impl_fn))]
return hash_source(*sources)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .layernorm import rms_norm
__all__ = ["rms_norm"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import Tensor
from ..op import register_op
@register_op
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
"""Weighted root-mean-square layer normalization"""
orig_dtype = x.dtype
x = x.to(torch.float32)
x_var = x if variance_size is None else x[..., :variance_size]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + epsilon)
x = x.to(orig_dtype)
if weight is not None:
x = x * weight
return x
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib
import inspect
import types
import weakref
from pathlib import Path
from typing import Any
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
:return:
"""
hasher = hashlib.sha256()
for src in srcs:
if src is None:
src_str = "None"
elif isinstance(src, str):
src_str = src
elif isinstance(src, Path):
src_str = src.read_text()
elif isinstance(src, (types.FunctionType, type)):
src_str = inspect.getsource(src)
else:
# object instance
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.hexdigest()
def weak_lru_cache(maxsize: int | None = 128, typed: bool = False):
"""
LRU Cache decorator that keeps a weak reference to 'self'.
This avoids memory leakage, which happens when functools.lru_cache
stores a reference to self in the global cache.
Taken from: https://stackoverflow.com/a/68052994/5082708
"""
def wrapper(func):
@functools.lru_cache(maxsize, typed)
def _func(_self, *args, **kwargs):
return func(_self(), *args, **kwargs)
@functools.wraps(func)
def inner(self, *args, **kwargs):
return _func(weakref.ref(self), *args, **kwargs)
return inner
return wrapper
def weak_cache(user_function, /):
"""Simple weak equivalent to functools.cache"""
return weak_lru_cache(maxsize=None)(user_function)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kernel implementations for vLLM."""
from . import aiter_ops, oink_ops, vllm_c, xpu_ops
__all__ = ["vllm_c", "aiter_ops", "oink_ops", "xpu_ops"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from torch import Tensor
from torch.library import Library
from vllm import ir
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
current_platform.import_kernels()
def is_aiter_found() -> bool:
from importlib.util import find_spec
return find_spec("aiter") is not None
aiter_lib = Library("vllm_aiter", "FRAGMENT")
"""
This library holds torch custom ops for wrapped AITER ops.
Many AITER ops want to remain invisible to torch.compile even after lowering.
They are thus wrapped into torch custom ops inside the IR op implementations.
"""
direct_register_aiter_op = functools.partial(
direct_register_custom_op, target_lib=aiter_lib
)
"""Syntactic sugar for registering AITER custom ops."""
AITER_SUPPORTED = is_aiter_found()
"""Most kernels in this file are supported if AITER is installed."""
rms_no_var_16bit_only = (
lambda x, weight, epsilon, variance_size=None: variance_size is None
and x.dtype
in (
torch.float16,
torch.bfloat16,
)
)
"""AITER rms_norm only supports float16 and bfloat16 acts and no var_size override."""
@ir.ops.rms_norm.register_impl(
"aiter", supports_args=rms_no_var_16bit_only, supported=AITER_SUPPORTED
)
def rms_norm(
x: Tensor, weight: Tensor | None, epsilon: float, variance_size: int | None = None
) -> Tensor:
assert variance_size is None
assert x.dtype in (torch.float16, torch.bfloat16)
if weight is None:
weight = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
return torch.ops.vllm_aiter.rms_norm(x, weight, epsilon)
def _rms_norm_impl(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
from aiter import rms_norm
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rms_norm(x, weight, variance_epsilon)
def _rms_norm_fake(x: Tensor, weight: Tensor, variance_epsilon: float) -> Tensor:
return torch.empty_like(x)
direct_register_aiter_op(
op_name="rms_norm", op_func=_rms_norm_impl, fake_impl=_rms_norm_fake
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import ir
from vllm.platforms import current_platform
OINK_AVAILABLE = current_platform.has_device_capability(100) and hasattr(
torch.ops, "oink"
)
def has_oink_op(name: str) -> bool:
"""Check if a specific oink op is registered."""
return OINK_AVAILABLE and hasattr(torch.ops.oink, name)
def _can_view_as_2d(x: torch.Tensor) -> bool:
"""Return True if x.view(-1, x.shape[-1]) is viewable (no copy)."""
if x.dim() < 2:
return False
if x.dim() == 2:
return True
# For a view(-1, N) to be valid, all leading dims must be contiguous with
# respect to each other (size-1 dims are ignored).
for dim in range(x.dim() - 1):
# Strides for size-1 dims are irrelevant and can be arbitrary.
if x.size(dim + 1) != 1 and x.stride(dim) != x.stride(dim + 1) * x.size(
dim + 1
):
return False
return True
def _is_oink_stride_compatible_2d(x_2d: torch.Tensor) -> bool:
"""Return True if x_2d meets Oink's pointer-path stride constraints."""
if x_2d.dim() != 2:
return False
if x_2d.stride(1) != 1:
return False
# Match Oink's vectorization constraint: stride(0) divisible by 256b.
if x_2d.dtype in (torch.float16, torch.bfloat16):
divby = 16
elif x_2d.dtype == torch.float32:
divby = 8
else:
return False
return (x_2d.stride(0) % divby) == 0
oink_rms_supported = (
lambda x, weight, epsilon, variance_size=None: variance_size is None
and weight is not None
and x.dim() >= 2
and x.dtype == weight.dtype
and weight.is_contiguous()
and _can_view_as_2d(x)
and _is_oink_stride_compatible_2d(x.view(-1, x.shape[-1]))
)
"""
Oink rms only supports 2d-like inputs with contiguous weight
and no variance_size override.
"""
@ir.ops.rms_norm.register_impl(
"oink", supports_args=oink_rms_supported, supported=has_oink_op("rmsnorm")
)
def rms_norm(
x: torch.Tensor,
weight: torch.Tensor | None,
epsilon: float,
variance_size: int | None = None,
) -> torch.Tensor:
assert variance_size is None
x_2d = x.view(-1, x.shape[-1])
return torch.ops.oink.rmsnorm(x_2d, weight, epsilon).view_as(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment