"vscode:/vscode.git/clone" did not exist on "a3205beffb6b3d2923fd9ad8e1ef8b4fd5f7ed29"
Unverified commit 40bb1750, authored by Luka Govedič, committed by GitHub
Browse files

[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
parent 0fab52f0
......@@ -2,6 +2,16 @@ group: Kernels
depends_on:
- image-build
steps:
- label: vLLM IR Tests
timeout_in_minutes: 10
working_dir: "/vllm-workspace/"
source_file_dependencies:
- vllm/ir
- vllm/kernels
commands:
- pytest -v -s tests/ir
- pytest -v -s tests/kernels/ir
- label: Kernels Core Operation Test
timeout_in_minutes: 75
source_file_dependencies:
......
......@@ -13,6 +13,9 @@
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
......@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin @vadiklyutiy
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
......
......@@ -8,7 +8,7 @@ from copy import deepcopy
import depyf
from torch import fx
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.passes.fx_utils import find_op_nodes
......@@ -90,7 +90,9 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
def check_before_ops(
self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
):
for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
......@@ -99,13 +101,19 @@ class TestBackend:
if fully_replaced:
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]):
def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
def op_count(self, op: OpOverload, before=False) -> int:
def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
graph = self.graph_pre_pass if before else self.graph_post_pass
return len(list(find_op_nodes(op, graph)))
def print_graphs(self):
    """Print the generated Python source of the pre-pass and post-pass graphs.

    Each graph is preceded by a banner line so the two dumps can be told apart.
    """
    sections = (
        ("=== Graph before custom passes ===", "graph_pre_pass"),
        ("=== Graph after custom passes ===", "graph_post_pass"),
    )
    for banner, attr_name in sections:
        print(banner)
        # Attribute is looked up only after its banner is printed, one at a time.
        graph = getattr(self, attr_name)
        print(graph.python_code(root_module="self", verbose=True).src)
......@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
......@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
......
......@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
......@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
......@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
......
......@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
......@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
......
......@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.fx_utils import find_auto_fn
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
......@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
]
def ops_in_model(self):
if RMSNorm.enabled():
return [
torch.ops._C.rms_norm.default,
return (
[torch.ops.vllm_ir.rms_norm]
+ [
torch.ops._C.fused_add_rms_norm.default,
]
else:
return []
if RMSNorm.enabled()
else []
)
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
......@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
assert backend.op_count(op, before=False) == 4
for op in model.ops_in_model():
find_auto_fn(backend.graph_post_pass.nodes, op)
assert backend.op_count(op, before=False) > 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import nn
import vllm.kernels # noqa: F401 to register kernels
from vllm import ir
from vllm.compilation.passes.ir.lowering_pass import (
VllmIRLoweringPass,
)
from vllm.config import get_current_vllm_config
from vllm.ir import ops
from vllm.platforms import current_platform
from ...backend import TestBackend
class Model(nn.Module):
    """Small module that chains three ``ops.rms_norm`` calls with elementwise math.

    The three calls differ in their arguments: with a weight, without a weight,
    and with a weight plus an explicit ``variance_size`` override.
    """

    def __init__(self, hidden_size=16, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size
        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)

    def forward(self, x):
        eps = 1e-5

        # weighted norm
        normed = ops.rms_norm(x + 4.0, self.weight, eps)

        # no weight
        normed_no_weight = ops.rms_norm(normed * 5.0, None, eps)

        # dispatch to native due to variance_size parameter
        normed_partial = ops.rms_norm(
            normed_no_weight / 2.0, self.weight, eps, self.hidden_size // 2
        )
        return normed_partial + 3.0
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
    """Compile ``Model`` with and without the IR lowering pass and compare.

    Checks which implementation the lowering pass selected for each of the three
    ``rms_norm`` nodes and that lowered and unlowered outputs agree.
    """
    torch.set_default_device(current_platform.device_type)

    lowering_pass = VllmIRLoweringPass(get_current_vllm_config())
    lowered_backend = TestBackend(lowering_pass)
    unlowered_backend = TestBackend()

    model = Model()
    x = torch.randn(8, 16, dtype=torch.bfloat16)

    with (
        ops.rms_norm.set_priority([rms_provider, "native"]),
        ir.enable_torch_wrap(True),
    ):
        lowered_model = torch.compile(model, backend=lowered_backend, fullgraph=True)
        unlowered_model = torch.compile(
            model, backend=unlowered_backend, fullgraph=True
        )
        output = lowered_model(x)
        output_unlowered = unlowered_model(x)

    # The third call passes variance_size and is expected to land on native.
    expected_selection = {
        "rms_norm": rms_provider,
        "rms_norm_1": rms_provider,
        "rms_norm_2": "native",
    }
    selected = lowering_pass.selected_impls["rms_norm"]
    assert len(selected) == 3
    for node_name, provider in expected_selection.items():
        assert selected[node_name] == provider

    # Compiled function guards on global value, avoid recompilation
    with ir.enable_torch_wrap(True):
        output2 = lowered_model(x)

    for lowered_output in (output, output2):
        torch.testing.assert_close(output_unlowered, lowered_output)
......@@ -6,6 +6,7 @@ import pytest
import torch
import vllm.config
import vllm.ir.ops
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
......@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import (
FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Kernel and group_shape combinations: (kernel, group_shape)
......@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module):
]
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
return [torch.ops.vllm_ir.rms_norm] + (
[RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
)
......@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
),
)
with vllm.config.set_current_vllm_config(vllm_config):
with (
vllm.config.set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
# Setup device before model creation
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
......@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
# rms_norm is IR, not included
# 6 = 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 6
assert n_add_nodes(backend.graph_post_pass) == 2
......
......@@ -3,11 +3,11 @@
import pytest
import torch
from torch._ops import OpOverload, OpOverloadPacket
from tests.compile.backend import TestBackend
from vllm.compilation.passes.fusion.matcher_utils import (
FLASHINFER_ROTARY_OP,
RMS_OP,
ROTARY_OP,
)
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
......@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module):
q, k = self.rotary_emb(positions, q, k)
return q, k, v
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rms_norm_custom_op:
ops.append(RMS_OP)
else:
ops.append(RSQRT_OP)
def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]:
ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm]
if self.enable_rope_custom_op:
if self.rotary_emb.use_flashinfer:
ops.append(FLASHINFER_ROTARY_OP)
......@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module):
ops.append(INDEX_SELECT_OP)
return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]:
return [FUSED_QK_ROPE_OP]
......@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion(
num_heads, num_kv_heads, head_dim = 16, 4, 128
T = 5
with set_current_vllm_config(vllm_config):
with (
set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
model = QKNormRoPETestModel(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
......
......@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(scope="function")
def disable_log_dedup(monkeypatch):
    """
    Disable log deduplication such that warning_once, info_once and debug_once
    always print.

    Each ``vllm.logger._print_*_once`` helper is replaced by the function stored
    in its ``__wrapped__`` attribute. The replacement goes through
    ``monkeypatch`` so the original helpers are restored automatically when the
    test finishes.
    """
    from vllm import logger

    # Per the original intent, __wrapped__ is the function underneath the
    # caching (lru_cache) decorator, so calling it bypasses deduplication.
    for helper_name in (
        "_print_warning_once",
        "_print_info_once",
        "_print_debug_once",
    ):
        undecorated = getattr(logger, helper_name).__wrapped__
        monkeypatch.setattr(logger, helper_name, undecorated)
......@@ -523,3 +523,20 @@ def test_human_readable_model_len():
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError):
parser.parse_args(["--max-model-len", invalid])
def test_ir_op_priority():
    """Check both ways of passing ``ir_op_priority`` to ``EngineArgs``.

    Passing it directly and passing it inside ``KernelConfig`` must produce the
    same resulting priority; passing it both ways at once must raise.
    """
    from vllm.config.kernel import IrOpPriorityConfig, KernelConfig

    ir_op_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])

    direct_args = EngineArgs(ir_op_priority=ir_op_priority)
    cfg_direct = direct_args.create_engine_config()

    nested_args = EngineArgs(
        kernel_config=KernelConfig(ir_op_priority=ir_op_priority)
    )
    cfg_nested = nested_args.create_engine_config()

    assert (
        cfg_direct.kernel_config.ir_op_priority
        == cfg_nested.kernel_config.ir_op_priority
    )

    # Supplying the priority through both channels is expected to be rejected.
    with pytest.raises(ValueError, match="rms_norm"):
        conflicting_args = EngineArgs(
            ir_op_priority=ir_op_priority,
            kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
        )
        _ = conflicting_args.create_engine_config()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import logging
from pathlib import Path
from typing import Any
import pytest
import torch
from torch import fx
from torch.fx.experimental.proxy_tensor import make_fx
import vllm.ir.op
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl
# Import-time guard: no op named "_custom_add" may already be in the registry
# before this module registers its own op under that name.
assert "_custom_add" not in IrOp.registry
class CustomError(Exception):
    """Marker exception raised on purpose inside context managers by these tests."""
@vllm.ir.register_op
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Native definition of the test op: elementwise sum of ``x`` and ``y``."""
    total = x + y
    return total
def test_registration_overloads():
    """Exercise the different ways of creating an ``IrOp``.

    Covers the called decorator form, a custom op name, direct construction
    (which must not register), and rejection of a duplicate op name.
    """
    fresh_names = ("_custom_sub", "_custom_mul", "_custom_div")
    assert not any(name in IrOp.registry for name in fresh_names)

    # Calling with decorator
    @vllm.ir.register_op()
    def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x - y

    assert _custom_sub.name == "_custom_sub"
    assert IrOp.registry["_custom_sub"] is _custom_sub

    # Custom name
    @vllm.ir.register_op(name="_custom_mul")
    def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    assert custom_mul.name == "_custom_mul"
    assert IrOp.registry["_custom_mul"] is custom_mul

    # Direct construction does not register directly
    def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x / y

    unregistered_div = IrOp("_custom_div", _custom_div)
    assert unregistered_div.name == "_custom_div"
    assert "_custom_div" not in IrOp.registry

    # Duplicate op registration not allowed
    with pytest.raises(AssertionError):

        @vllm.ir.register_op
        def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y - 100
def test_no_kw_only_args():
    """Registering an op with a keyword-only parameter must fail and not register."""
    # kw-only args not supported
    with pytest.raises(ValueError, match="keyword-only arguments"):

        @vllm.ir.register_op
        def _custom_kwarg_op(
            x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
        ) -> torch.Tensor:
            return x + y + kwarg

    # The rejected op must leave no trace in the registry.
    registered_names = IrOp.registry
    assert "_custom_kwarg_op" not in registered_names
class TestIrOpCustomAdd:
    """Registration, eager semantics and FX-tracing checks for ``_custom_add``."""

    # Registration invariants
    def test_decorated_object(self):
        """Make sure that referring directly to an op is correct"""
        registry = IrOp.registry
        assert isinstance(_custom_add, IrOp)
        assert "_custom_add" in registry
        assert registry["_custom_add"] is _custom_add

    def test_torch_op_is_registered(self):
        """The op must also be reachable through the ``torch.ops.vllm_ir`` namespace."""
        namespace = torch.ops.vllm_ir
        assert hasattr(namespace, "_custom_add")
        assert callable(namespace._custom_add.default)

    # Semantic correctness
    def test_semantics_match_native(self):
        """Calling the op without any priority set must equal plain addition."""
        lhs, rhs = torch.randn(4, 5), torch.randn(4, 5)

        # Calls native by default
        torch.testing.assert_close(_custom_add(lhs, rhs), lhs + rhs)

    # -------------------------
    # Implementation registration
    # -------------------------

    def test_register_impl_is_non_intrusive(self):
        """Registering an extra impl must not change what the op computes by default."""

        @_custom_add.register_impl("dummy_provider")
        def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 123

        registered = _custom_add.impls
        assert "dummy_provider" in registered
        assert isinstance(registered["dummy_provider"], IrOpImpl)

        lhs, rhs = torch.ones(2, 2), torch.ones(2, 2)

        # Native semantics must still hold
        torch.testing.assert_close(_custom_add(lhs, rhs), lhs + rhs)

    def test_schema_contains_tensor_signature(self):
        """The op's schema string must mention Tensor arguments and a Tensor return."""
        schema = _custom_add._schema_str
        for fragment in ("Tensor", "-> Tensor"):
            assert fragment in schema

    # -------------------------
    # FX visibility
    # -------------------------

    @pytest.mark.parametrize("enable_torch_wrap", [True, False])
    @pytest.mark.parametrize("symbolic_trace", [True, False])
    def test_trace_sees_single_custom_op(
        self, symbolic_trace: bool, enable_torch_wrap: bool
    ):
        """Trace a call to the op and count the IR nodes that show up in the graph."""

        def fn(x, y):
            return _custom_add(x, y)

        def trace() -> fx.GraphModule:
            # Same tracing choice is needed inside and outside the context below.
            if symbolic_trace:
                return torch.fx.symbolic_trace(fn)
            return make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

        def count_ir_nodes(target: Any, module: fx.GraphModule) -> int:
            return len(module.graph.find_nodes(op="call_function", target=target))

        with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
            gm = trace()

            x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
            out_fx = gm(x1, y1)
            out_eager = fn(x1, y1)

            # raise error to check enable_torch_wrap context restored correctly
            raise CustomError

        # check behavior matches eager in all cases
        torch.testing.assert_close(out_fx, out_eager)

        # check that IR nodes only appear if enable_torch_wrap=True
        expected_nodes = 1 if enable_torch_wrap else 0
        found_nodes = count_ir_nodes(torch.ops.vllm_ir._custom_add.default, gm)
        assert found_nodes == expected_nodes, gm.code

        # with torch wrapping enabled (default), IR nodes appear
        gm = trace()
        found_nodes = count_ir_nodes(torch.ops.vllm_ir._custom_add.default, gm)
        assert found_nodes == 1, gm.code
@_custom_add.register_impl("impl_a")
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test impl of ``_custom_add``; the +10 offset makes its output identifiable."""
    total = x + y
    return total + 10
@_custom_add.register_impl("impl_b")
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test impl of ``_custom_add``; the +20 offset makes its output identifiable."""
    total = x + y
    return total + 20
@_custom_add.register_impl(
    "impl_even",
    supports_args=lambda x, y: x.size(1) % 2 == 0,
)
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test impl of ``_custom_add`` with a +50 offset.

    Its ``supports_args`` predicate accepts only inputs whose ``x.size(1)`` is even.
    """
    total = x + y
    return total + 50
class TestIrOpImplDispatch:
    """Checks impl registration, priority scoping and dispatch for ``_custom_add``."""

    def test_register_impl(self):
        """A registered impl is retrievable, and its provider name cannot be reused."""
        assert "impl_a" in _custom_add.impls
        registered = _custom_add.impls["impl_a"]

        assert registered is impl_a
        assert registered.op is _custom_add
        assert registered.provider == "impl_a"
        assert callable(registered.impl_fn)

        # Test duplicate registration rejected
        with pytest.raises(AssertionError):

            @_custom_add.register_impl("impl_a")
            def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y + 30

        # Check the original impl is still intact
        assert _custom_add.impls["impl_a"] is impl_a

        # Check support all args
        for unconditional_impl in (impl_a, impl_b):
            assert unconditional_impl.supports_all_args
        assert not impl_even.supports_all_args

    def test_reserved_provider_rejected(self):
        """Every name in ``RESERVED_PROVIDERS`` must be refused as a provider."""
        for reserved_name in RESERVED_PROVIDERS:
            with pytest.raises(AssertionError):

                @_custom_add.register_impl(reserved_name)
                def bad_impl(x, y):
                    return x + y

    def test_set_priority_scoped(self):
        """``set_priority`` applies only inside its ``with`` block, even on error."""
        outer_priority = ["impl_even", "impl_b"]
        assert _custom_add.get_priority() == []

        with _custom_add.set_priority(["impl_even", "impl_b"]):
            assert _custom_add.get_priority() == outer_priority

            # Check nesting
            with _custom_add.set_priority(["impl_b"]):
                assert _custom_add.get_priority() == ["impl_b"]

            # Restored
            assert _custom_add.get_priority() == outer_priority

            # Check that exception restores priority
            with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
                assert _custom_add.get_priority() == ["impl_a"]
                raise CustomError

            # Restored again
            assert _custom_add.get_priority() == outer_priority

        # Restored to empty
        assert _custom_add.get_priority() == []

    def test_dispatch_priority_order(self):
        """The first provider in the priority list decides which impl runs."""
        x = torch.tensor(1, dtype=torch.int32)
        y = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["impl_b", "impl_a"]):
            assert _custom_add.dispatch(x, y) is impl_b
            via_op_b = _custom_add(x, y)
            via_torch_b = torch.ops.vllm_ir._custom_add(x, y)

        with _custom_add.set_priority(["impl_a"]):
            assert _custom_add.dispatch(x, y) is impl_a
            via_op_a = _custom_add(x, y)
            via_torch_a = torch.ops.vllm_ir._custom_add(x, y)

        # impl_b
        for result in (via_op_b, via_torch_b):
            assert result.item() == 1 + 2 + 20

        # impl_a
        for result in (via_op_a, via_torch_a):
            assert result.item() == 1 + 2 + 10

    def test_unsupported_impl_filtered(self):
        """An impl registered with ``supported=False`` is dropped from the priority."""

        @_custom_add.register_impl("unsupported", supported=False)
        def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 999

        x = torch.tensor(1, dtype=torch.int32)
        y = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["unsupported", "impl_a"]):
            assert _custom_add.get_priority() == ["impl_a"]
            result = _custom_add(x, y)

        # impl_bad skipped → impl_a
        assert result.item() == 1 + 2 + 10

    def test_supports_args_runtime_dispatch_and_warning(
        self, caplog_vllm: pytest.LogCaptureFixture
    ):
        """Per-call dispatch honours ``supports_args`` and warns once about fallback."""
        even_x = torch.ones((2, 2), dtype=torch.int32)
        even_y = torch.full((2, 2), 2, dtype=torch.int32)

        odd_x = torch.ones((2, 3), dtype=torch.int32)
        odd_y = torch.full((2, 3), 2, dtype=torch.int32)

        with (
            caplog_vllm.at_level(logging.WARNING),
            _custom_add.set_priority(["impl_even"]),
        ):
            # Test the warning about native fallback is logged (before even dispatching)
            assert len(caplog_vllm.records) == 1
            message = caplog_vllm.records[0].message
            for expected_fragment in ("_custom_add", "fallback to native", "priority"):
                assert expected_fragment in message

            # Check dispatching
            native_impl = _custom_add.impls["native"]
            assert _custom_add.get_priority() == ["impl_even", "native"]
            assert _custom_add.dispatch(even_x, even_y) is impl_even
            assert _custom_add.dispatch(odd_x, odd_y) is native_impl

            out_even = _custom_add(even_x, even_y)  # size(1) == 2 → impl_even
            out_odd = _custom_add(odd_x, odd_y)  # size(1) == 3 → native fallback

        # no other warnings
        assert len(caplog_vllm.records) == 1
        assert torch.all(out_even == 1 + 2 + 50)
        assert torch.all(out_odd == 1 + 2)

    def test_default_priority(
        self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
    ):
        """With no priority set, dispatch uses native and warns on each use."""
        # Make sure logs are not deduplicated to properly test the warning
        x = torch.tensor([3], dtype=torch.int32)
        y = torch.tensor([4], dtype=torch.int32)

        # No priority set → falls back to native
        assert _custom_add.get_priority() == []

        with caplog_vllm.at_level(logging.WARNING):
            # Native by default
            assert _custom_add.dispatch(x, y) is _custom_add.impls["native"]
            result = _custom_add(x, y)

        # Check dispatching to native by default
        assert result.item() == 3 + 4

        # Check warning
        records = caplog_vllm.records
        assert len(records) == 2
        message = records[0].message.lower()
        for expected_fragment in ("_custom_add", "priority not set"):
            assert expected_fragment in message
@vllm.ir.register_op
def _custom_mm(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """Native definition of the test op: ``x @ y``, plus ``bias`` when one is given."""
    product = x @ y
    if bias is None:
        return product
    return product + bias
def test_default_args():
    """Default argument values must be honoured by ``supports_args`` and dispatch."""

    # Test that default args are properly applied when dispatching and calling
    @_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
    def impl_mm(
        x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        product = x @ y
        if bias is None:
            return product + 50
        return product + bias + 100

    lhs = torch.tensor([1, 2], dtype=torch.int32)
    rhs = torch.tensor([3, 4], dtype=torch.int32)

    # Test that supports_args receives the defaulted args
    assert impl_mm.supports_args(lhs, rhs)

    with _custom_mm.set_priority(["impl_mm", "native"]):
        chosen = _custom_mm.dispatch(lhs, rhs)
        assert chosen is impl_mm
def test_bad_impl_registrations():
    """Check that malformed ``register_impl`` calls on ``_custom_mm`` are rejected.

    Two groups of failures are covered: impl functions whose signature differs
    from the native one, and ``supports_args`` values that are not a callable
    mirroring the native parameters. Each local function is named after the
    provider string it is registered under. The final assertion verifies that
    none of the rejected registrations was stored.
    """
    # Check bad schema
    # Missing the ``bias`` parameter.
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema")
        def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x @ y - 1

    # Third parameter is named ``b`` instead of ``bias``.
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_2")
        def impl_mm_bad_schema_2(
            x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + b - 2

    # ``bias`` is required and non-optional here.
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_3")
        def impl_mm_bad_schema_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            return x @ y + bias - 5

    # check supports_args with incorrect params
    # Not a callable at all.
    with pytest.raises(ValueError, match="supports_args must be a callable"):

        @_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
        def impl_mm_bad_supports_args(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    # Too few parameters.
    with pytest.raises(ValueError, match="number of parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
        )
        def impl_mm_bad_supports_args_2(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    # Keyword-only parameter.
    with pytest.raises(ValueError, match="keyword-only parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
        )
        def impl_mm_bad_supports_args_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 20

    # Parameter name differs from native (``b`` vs ``bias``).
    with pytest.raises(ValueError, match="does not match native parameter"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
        )
        def impl_mm_bad_supports_args_4(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 30

    # Default value differs from native (1 vs None).
    with pytest.raises(ValueError, match="does not match native default"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
        )
        def impl_mm_bad_supports_args_5(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 40

    # NOTE(review): this expects "impl_mm" to be registered already, which in
    # this file happens inside test_default_args — confirm the ordering is intended.
    assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"}
# Source text of an out-of-tree module, written to a temporary file by the test
# below. It relies on ``_custom_mm`` being injected into the module's globals
# before execution. The body of ``impl_mm_oot`` must be indented so that the
# written file is valid Python.
IMPL_OOT_SRC = """
import torch
@_custom_mm.register_impl("impl_mm_oot")
def impl_mm_oot(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    return x @ y - 99
"""
def load_custom_mm_module(file_path: Path):
    """Execute the Python file at ``file_path`` as module ``_custom_mm_oot``.

    ``_custom_mm`` is placed in the new module's namespace before execution, and
    the executed module object is returned.
    """
    module_spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
    assert module_spec is not None
    oot_module = importlib.util.module_from_spec(module_spec)

    # Inject the variable into the module's global namespace
    # This allows the @_custom_mm.register_impl decorator to work
    setattr(oot_module, "_custom_mm", _custom_mm)

    # Execute the file; this triggers the decorator
    loader = module_spec.loader
    assert loader is not None
    loader.exec_module(oot_module)

    return oot_module
def _oot_impl_uuid(file_path: Path, source: str):
    """Write ``source`` to ``file_path``, load it, and return the new impl's uuid.

    The ``impl_mm_oot`` entry is always removed from ``_custom_mm.impls`` before
    returning, even if an assertion fails, so a failure here cannot leak the
    registration into other tests.
    """
    file_path.write_text(source)
    assert "impl_mm_oot" not in _custom_mm.impls
    _ = load_custom_mm_module(file_path)
    try:
        assert "impl_mm_oot" in _custom_mm.impls
        return _custom_mm.impls["impl_mm_oot"].uuid()
    finally:
        _custom_mm.impls.pop("impl_mm_oot", None)


def test_uuid_and_oot(tmp_path: Path):
    """The uuid of an out-of-tree impl must track the content of its source file.

    The same file is loaded three times: with the original source, with modified
    source, and with the original source again. The modified source must give a
    different uuid, and restoring the source must give back the first uuid.
    """
    file_path = tmp_path / "_custom_mm_oot.py"

    uuid = _oot_impl_uuid(file_path, IMPL_OOT_SRC)

    # Replace file source
    uuid1 = _oot_impl_uuid(file_path, IMPL_OOT_SRC + " # added file source")
    assert uuid1 != uuid

    # Back to original
    uuid2 = _oot_impl_uuid(file_path, IMPL_OOT_SRC)
    assert uuid2 == uuid
    assert uuid2 != uuid1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
# This registers op implementations
import vllm.kernels # noqa: F401
from tests.kernels.allclose_default import get_default_rtol
from vllm import ir
from vllm.platforms import current_platform
def rms_norm_inputs(
    n_tokens: int, hidden_size: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create random inputs for an rms_norm call.

    Returns:
        A pair ``(x, weight)``: ``x`` is drawn with ``torch.randn`` and has shape
        ``(n_tokens, hidden_size)``; ``weight`` is drawn with ``torch.rand`` and
        has shape ``(hidden_size,)``. Both use ``dtype``.
    """
    # Tuple elements are evaluated left to right, so x is sampled before weight.
    return (
        torch.randn(n_tokens, hidden_size, dtype=dtype),
        torch.rand(hidden_size, dtype=dtype),
    )
# Function behind the "native" rms_norm implementation, looked up once from the
# op's impl table; the tests below use it as the comparison baseline.
rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
def test_rms_norm_registration():
    """Every rms_norm provider must report the expected ``supported`` flag."""
    on_cuda_alike = current_platform.is_cuda_alike()
    on_rocm = current_platform.is_rocm()
    on_xpu = current_platform.is_xpu()

    expected = {
        "native": True,
        "vllm_c": on_cuda_alike,
        "aiter": on_rocm,
        "oink": False,
        "xpu_kernels": on_xpu,
    }

    actual = {}
    for provider, impl in ir.ops.rms_norm.impls.items():
        actual[provider] = impl.supported

    assert actual == expected
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("n_tokens", [1, 8, 17])
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
class TestRMSNorm:
    """Numerical and registration checks for the rms_norm IR op implementations."""

    @classmethod
    def setup_class(cls, **kwargs):
        # Tensors created by the tests land on the current platform's device.
        torch.set_default_device(current_platform.device_type)

    def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
        """Sanity-check basic properties of the native implementation."""
        # NOTE(review): n_tokens and hidden_size are not used here; the inputs
        # are always (4, 8). Confirm whether that is intentional.
        x, weight = rms_norm_inputs(4, 8, dtype)
        baseline = rms_norm_native(x, weight, epsilon=epsilon)

        # Check shape, dtype, device
        for attr in ("shape", "dtype", "device"):
            assert getattr(baseline, attr) == getattr(x, attr)

        # Check the scaling property of rms norm
        scaled = rms_norm_native(x * 2.0, weight, epsilon=epsilon)
        torch.testing.assert_close(
            scaled, baseline, rtol=get_default_rtol(baseline), atol=1e-3
        )

        # Check behavior with and without weight
        unit_weight = torch.ones_like(weight)
        with_unit_weight = rms_norm_native(x, unit_weight, epsilon=epsilon)
        without_weight = rms_norm_native(x, None, epsilon=epsilon)
        torch.testing.assert_close(with_unit_weight, without_weight)

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
    def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Compare one provider's implementation against the native one."""
        impl = ir.ops.rms_norm.impls[provider]
        if not impl.supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        args = (x, weight, epsilon, None)

        assert impl.supported

        half_precision = (torch.float16, torch.bfloat16)
        if provider == "aiter" and dtype not in half_precision:
            assert not impl.supports_args(*args)
            return

        assert impl.supports_args(*args)

        direct_out = impl.impl_fn(*args)
        native_out = rms_norm_native(*args)
        torch.testing.assert_close(
            direct_out, native_out, rtol=get_default_rtol(direct_out), atol=1e-3
        )

        # check that dispatched call matches direct call
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            dispatched_out = ir.ops.rms_norm(*args)

        # exact match
        torch.testing.assert_close(dispatched_out, direct_out, rtol=0.0, atol=0.0)

        # none of these support variance_size override
        assert not impl.supports_args(x, weight, epsilon, 4)
        assert not impl.supports_args(x, weight, epsilon, variance_size=4)

        # test weight=None behavior
        no_weight_out = impl.impl_fn(x, None, epsilon)
        unit_weight_out = impl.impl_fn(x, torch.ones_like(weight), epsilon)
        torch.testing.assert_close(
            no_weight_out,
            unit_weight_out,
            rtol=get_default_rtol(no_weight_out),
            atol=2e-4,
        )

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
    def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Run ``torch.library.opcheck`` on the torch op with ``provider`` prioritized."""
        impl = ir.ops.rms_norm.impls[provider]
        if not impl.supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        args = (x, weight, epsilon, None)

        # When checking the torch op, we have to set priority and use dispatch
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)
......@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import (
RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm,
rms_norm,
)
from vllm.platforms import current_platform
......@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
assert topk_func == vllm_topk_sigmoid
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("add_residual", [False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
......@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
def test_rms_norm_dispatch(
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
):
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)
should_use_rocm_aiter = (
current_platform.is_rocm()
......@@ -173,11 +172,7 @@ def test_rms_norm_dispatch(
and dtype in RMS_NORM_SUPPORTED_DTYPES
)
if add_residual and should_use_rocm_aiter:
if should_use_rocm_aiter:
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:
assert rms_norm_func == rms_norm
assert rms_norm_func == fused_add_rms_norm
......@@ -6,12 +6,14 @@ import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch
import pydantic
import pytest
from pydantic import ValidationError
from vllm.compilation.backends import VllmBackend
from vllm.config import (
CompilationConfig,
KernelConfig,
ModelConfig,
ParallelConfig,
PoolerConfig,
......@@ -21,6 +23,7 @@ from vllm.config import (
update_config,
)
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.kernel import IrOpPriorityConfig
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
from vllm.config.vllm import (
......@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
def test_fusion_pass_op_priority():
    """Check how custom-op enablement and IR op priority drive default fusions.

    Each scenario builds a ``VllmConfig`` and inspects the resulting
    ``pass_config.fuse_norm_quant`` flag (and, for the last scenario, the
    ``custom_ops`` list as well).
    """
    # Default config, O2, rms_norm+quant fusion disabled
    default_cfg = VllmConfig()
    assert not default_cfg.compilation_config.pass_config.fuse_norm_quant

    # rms_norm manually enabled, O1, rms_norm+quant fusion enabled
    custom_op_compilation = CompilationConfig(custom_ops=["+rms_norm"])
    custom_op_cfg = VllmConfig(
        optimization_level=OptimizationLevel.O1,
        compilation_config=custom_op_compilation,
    )
    assert custom_op_cfg.compilation_config.pass_config.fuse_norm_quant

    # using custom kernel for RMSNorm via IR:
    # Note that vLLM IR currently supports only the non-residual rms_norm;
    # this limitation is expected to be lifted soon.
    ir_priority = IrOpPriorityConfig(rms_norm=["vllm_c"])
    ir_priority_cfg = VllmConfig(
        kernel_config=KernelConfig(ir_op_priority=ir_priority),
    )
    assert ir_priority_cfg.compilation_config.pass_config.fuse_norm_quant

    # block-fp8 model should enable quant_fp8 automatically
    fp8_cfg = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
    fp8_compilation = fp8_cfg.compilation_config
    assert "+quant_fp8" in fp8_compilation.custom_ops
    assert fp8_compilation.pass_config.fuse_norm_quant
def test_scheduler_config_init():
with pytest.raises(ValidationError):
# Positional InitVars missing
......@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config():
assert draft_model_config.hf_text_config.model_type == "eagle"
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
def test_ir_op_priority_default():
    """Check that ``IrOpPriorityConfig.with_default`` applies and overrides defaults."""
    from vllm.config.kernel import IrOpPriorityConfig

    # Assert default is applied to ops
    defaulted = IrOpPriorityConfig.with_default(["vllm_c", "native"])
    assert defaulted.rms_norm == ["vllm_c", "native"]

    # Assert single ops override the default
    overridden = IrOpPriorityConfig.with_default(
        ["vllm_c", "native"], rms_norm=["oink", "native"]
    )
    explicit = IrOpPriorityConfig(rms_norm=["oink", "native"])
    assert overridden == explicit
def test_ir_op_priority_str():
    """Check that a comma-delimited string is accepted as an op priority."""
    from vllm.config.kernel import IrOpPriorityConfig

    # (raw string passed in, priority list expected back)
    cases = [
        ("vllm_c", ["vllm_c"]),
        ("vllm_c,native", ["vllm_c", "native"]),
        (" native, vllm_c ", ["native", "vllm_c"]),
    ]
    for raw, expected in cases:
        priority_config = IrOpPriorityConfig(rms_norm=raw)
        assert priority_config.rms_norm == expected

    with pytest.raises(pydantic.ValidationError):
        # must be list of only strings
        priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])
......@@ -3,6 +3,7 @@
import contextlib
from importlib.util import find_spec
from types import ModuleType
from typing import Any
import torch
import torch._inductor.pattern_matcher as pm
......@@ -10,6 +11,7 @@ import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
......@@ -28,7 +30,7 @@ from vllm.utils.torch_utils import (
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -258,6 +260,12 @@ class BasePattern:
self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size()
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Allocate an uninitialized tensor with this pattern's dtype and device.

    Positional and keyword arguments are forwarded to ``torch.empty``;
    ``dtype`` and ``device`` are taken from ``self`` and must not be passed
    again by the caller.
    """
    return torch.empty(*args, **kwargs, dtype=self.dtype, device=self.device)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Allocate an uninitialized float32 tensor on this pattern's device.

    Positional and keyword arguments are forwarded to ``torch.empty``;
    ``dtype`` is fixed to ``torch.float32`` and ``device`` comes from ``self``.
    """
    return torch.empty(*args, **kwargs, dtype=torch.float32, device=self.device)
class AllReduceRMSNormPattern(BasePattern):
"""
......@@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
# input, weight
return [self.empty(5, 16), self.empty(16)]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(allreduce_output, weight)
rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)
return rms, allreduce_output
......@@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
# input, weight
return [self.empty(5, 16), self.empty(16), scale]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
......@@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
......@@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
......@@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight)
rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
input=rms,
......
......@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
......@@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
return result
class MatcherRMSNorm(MatcherCustomOp):
    """Matcher for the non-residual RMSNorm op.

    Provides three forward variants: the custom op wrapped in
    ``auto_functionalized``, the ROCm AITER op, and the native
    ``RMSNorm.forward_static`` path.
    """

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        """Store epsilon and select the op to match.

        Args:
            epsilon: Epsilon forwarded to whichever rms_norm op is used.
            enabled: Passed to the base class; when ``None`` it falls back
                to ``RMSNorm.enabled()``.
            match_rocm_aiter: When true, the op returned by
                ``rocm_aiter_ops.get_rmsnorm_op()`` is used instead of
                ``RMS_OP``.
        """
        super().__init__(RMSNorm.enabled() if enabled is None else enabled)
        self.epsilon = epsilon
        self._rmsnorm_op = (
            rocm_aiter_ops.get_rmsnorm_op() if match_rocm_aiter else RMS_OP
        )
        self.match_rocm_aiter = match_rocm_aiter

    def inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, weight]`` tensors of shape (5, 16) and (16,)."""
        # The input comes from ``empty`` when enabled, otherwise ``empty_f32``.
        make_input = self.empty if self.enabled else self.empty_f32
        return [make_input(5, 16), self.empty(16)]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Call the selected op with the AITER keyword signature."""
        op = self._rmsnorm_op
        return op(x=input, weight=weight, variance_epsilon=self.epsilon)

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Run the custom op, or the AITER op when ``match_rocm_aiter`` is set."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight)

        # The op writes into ``result``; auto_functionalized returns the
        # written buffer as the second element of its output.
        out_buffer = torch.empty_like(input)
        _, normed = auto_functionalized(
            self._rmsnorm_op,
            result=out_buffer,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )
        return normed

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``RMSNorm.forward_static`` using the last dim of ``input``."""
        last_dim = input.size(-1)
        return RMSNorm.forward_static(
            input, self.epsilon, last_dim, self.model_dtype, weight
        )
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(
self,
......
......@@ -10,6 +10,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
......@@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .matcher_utils import MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
logger = init_logger(__name__)
......@@ -64,7 +65,6 @@ class QkNormRopePattern:
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding(
......@@ -129,14 +129,14 @@ class QkNormRopePattern:
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
q_normed_by_head = vllm.ir.ops.rms_norm(q_by_head, q_weight, self.eps)
q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
k_normed_by_head = vllm.ir.ops.rms_norm(k_by_head, k_weight, self.eps)
k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment