Unverified commit 40bb1750, authored by Luka Govedič, committed by GitHub
Browse files

[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
parent 0fab52f0
...@@ -2,6 +2,16 @@ group: Kernels ...@@ -2,6 +2,16 @@ group: Kernels
depends_on: depends_on:
- image-build - image-build
steps: steps:
- label: vLLM IR Tests
timeout_in_minutes: 10
working_dir: "/vllm-workspace/"
source_file_dependencies:
- vllm/ir
- vllm/kernels
commands:
- pytest -v -s tests/ir
- pytest -v -s tests/kernels/ir
- label: Kernels Core Operation Test - label: Kernels Core Operation Test
timeout_in_minutes: 75 timeout_in_minutes: 75
source_file_dependencies: source_file_dependencies:
......
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
/vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy /vllm/model_executor/layers/rotary_embedding.py @vadiklyutiy
/vllm/model_executor/model_loader @22quinn /vllm/model_executor/model_loader @22quinn
/vllm/model_executor/layers/batch_invariant.py @yewentao256 /vllm/model_executor/layers/batch_invariant.py @yewentao256
/vllm/ir @ProExpertProg
/vllm/kernels/ @ProExpertProg @tjtanaa
/vllm/kernels/helion @ProExpertProg @zou3519
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa
/vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni /vllm/vllm_flash_attn @LucasWilkinson @MatthewBonanni
CMakeLists.txt @tlrmchlsmth @LucasWilkinson CMakeLists.txt @tlrmchlsmth @LucasWilkinson
...@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson ...@@ -74,6 +77,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche /tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @aarnphm @NickLucche
/tests/evals @mgoin @vadiklyutiy /tests/evals @mgoin @vadiklyutiy
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels/ir @ProExpertProg @tjtanaa
/tests/models @DarkLight1337 @ywang96 /tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety
......
...@@ -8,7 +8,7 @@ from copy import deepcopy ...@@ -8,7 +8,7 @@ from copy import deepcopy
import depyf import depyf
from torch import fx from torch import fx
from torch._ops import OpOverload from torch._ops import OpOverload, OpOverloadPacket
from torch.fx._utils import lazy_format_graph_code from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.fx_utils import find_op_nodes
...@@ -90,7 +90,9 @@ class TestBackend: ...@@ -90,7 +90,9 @@ class TestBackend:
# assign by reference, will reflect the final state of the graph # assign by reference, will reflect the final state of the graph
self.final_graph = graph self.final_graph = graph
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): def check_before_ops(
self, ops: Sequence[OpOverload | OpOverloadPacket], fully_replaced=True
):
for op in ops: for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
...@@ -99,13 +101,19 @@ class TestBackend: ...@@ -99,13 +101,19 @@ class TestBackend:
if fully_replaced: if fully_replaced:
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph" assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]): def check_after_ops(self, ops: Sequence[OpOverload | OpOverloadPacket]):
for op in ops: for op in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph" assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
def op_count(self, op: OpOverload, before=False) -> int: def op_count(self, op: OpOverload | OpOverloadPacket, before=False) -> int:
graph = self.graph_pre_pass if before else self.graph_post_pass graph = self.graph_pre_pass if before else self.graph_post_pass
return len(list(find_op_nodes(op, graph))) return len(list(find_op_nodes(op, graph)))
def print_graphs(self):
    """Print the generated Python source of both recorded graphs to stdout.

    The pre-pass graph is printed first, then the post-pass graph, each
    preceded by a banner line.
    """
    labelled_graphs = (
        ("=== Graph before custom passes ===", "graph_pre_pass"),
        ("=== Graph after custom passes ===", "graph_post_pass"),
    )
    for banner, attr_name in labelled_graphs:
        print(banner)
        # Look the graph up only now, so the banner is always printed first.
        rendered = getattr(self, attr_name).python_code(
            root_module="self", verbose=True
        )
        print(rendered.src)
...@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions( ...@@ -99,6 +99,8 @@ def test_tp1_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","), custom_ops=custom_ops.split(","),
...@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions( ...@@ -166,6 +168,7 @@ def test_tp1_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
......
...@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions( ...@@ -68,6 +68,7 @@ def test_tp2_ar_rms_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
...@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions( ...@@ -128,6 +129,7 @@ def test_tp2_ar_rms_fp4_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
...@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions( ...@@ -182,6 +184,7 @@ def test_tp2_ar_rms_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
......
...@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions( ...@@ -58,6 +58,7 @@ def test_tp2_async_tp_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
...@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions( ...@@ -121,6 +122,7 @@ def test_tp2_async_tp_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers) model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy" model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024 model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}
compilation_config = dict( compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition, use_inductor_graph_partition=inductor_graph_partition,
......
...@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend ...@@ -9,7 +9,6 @@ from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer, multi_gpu_test from tests.utils import TestFP8Layer, multi_gpu_test
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.passes.fx_utils import find_auto_fn
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
...@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module): ...@@ -86,13 +85,14 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
] ]
def ops_in_model(self): def ops_in_model(self):
if RMSNorm.enabled(): return (
return [ [torch.ops.vllm_ir.rms_norm]
torch.ops._C.rms_norm.default, + [
torch.ops._C.fused_add_rms_norm.default, torch.ops._C.fused_add_rms_norm.default,
] ]
else: if RMSNorm.enabled()
return [] else []
)
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
...@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model( ...@@ -321,4 +321,4 @@ def sequence_parallelism_pass_on_test_model(
assert backend.op_count(op, before=False) == 4 assert backend.op_count(op, before=False) == 4
for op in model.ops_in_model(): for op in model.ops_in_model():
find_auto_fn(backend.graph_post_pass.nodes, op) assert backend.op_count(op, before=False) > 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import nn
import vllm.kernels # noqa: F401 to register kernels
from vllm import ir
from vllm.compilation.passes.ir.lowering_pass import (
VllmIRLoweringPass,
)
from vllm.config import get_current_vllm_config
from vllm.ir import ops
from vllm.platforms import current_platform
from ...backend import TestBackend
class Model(nn.Module):
    """Small module that calls ``ops.rms_norm`` three times.

    The calls are separated by scalar elementwise ops: one call with a weight,
    one with ``None`` as the weight, and one with an extra fourth argument.
    """

    def __init__(self, hidden_size=16, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size
        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)

    def forward(self, x):
        eps = 1e-5

        # weighted norm
        x1 = x + 4.0
        x2 = ops.rms_norm(x1, self.weight, eps)

        # no weight
        x3 = x2 * 5.0
        x4 = ops.rms_norm(x3, None, eps)

        # dispatch to native due to variance_size parameter
        x5 = x4 / 2.0
        x6 = ops.rms_norm(x5, self.weight, eps, self.hidden_size // 2)

        return x6 + 3.0
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
    """Compile ``Model`` with and without ``VllmIRLoweringPass`` and compare.

    Checks which implementation the pass recorded for each of the three
    ``rms_norm`` nodes, and that the lowered outputs match the unlowered ones.
    """
    torch.set_default_device(current_platform.device_type)

    ir_lowering = VllmIRLoweringPass(get_current_vllm_config())
    lowered_backend = TestBackend(ir_lowering)
    reference_backend = TestBackend()

    module = Model()
    inputs = torch.randn(8, 16, dtype=torch.bfloat16)

    with (
        ops.rms_norm.set_priority([rms_provider, "native"]),
        ir.enable_torch_wrap(True),
    ):
        lowered_fn = torch.compile(module, backend=lowered_backend, fullgraph=True)
        reference_fn = torch.compile(
            module, backend=reference_backend, fullgraph=True
        )
        lowered_out = lowered_fn(inputs)
        reference_out = reference_fn(inputs)

    chosen = ir_lowering.selected_impls["rms_norm"]
    assert len(chosen) == 3
    assert chosen["rms_norm"] == rms_provider
    assert chosen["rms_norm_1"] == rms_provider
    assert chosen["rms_norm_2"] == "native"

    # Compiled function guards on global value, avoid recompilation
    with ir.enable_torch_wrap(True):
        lowered_out_again = lowered_fn(inputs)

    torch.testing.assert_close(reference_out, lowered_out)
    torch.testing.assert_close(reference_out, lowered_out_again)
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
import vllm.config import vllm.config
import vllm.ir.ops
import vllm.plugins import vllm.plugins
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer from tests.utils import TestBlockFP8Layer, TestFP8Layer
...@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import ( ...@@ -51,7 +52,6 @@ from vllm.utils.deep_gemm import (
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Kernel and group_shape combinations: (kernel, group_shape) # Kernel and group_shape combinations: (kernel, group_shape)
...@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module): ...@@ -246,10 +246,8 @@ class TestModel(torch.nn.Module):
] ]
def ops_in_model_before_partial(self): def ops_in_model_before_partial(self):
return ( return [torch.ops.vllm_ir.rms_norm] + (
[RMS_OP, RMS_ADD_OP] [RMS_ADD_OP] if self.enable_rms_norm_custom_op else [torch.ops.aten.rsqrt]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
) )
...@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant( ...@@ -340,7 +338,10 @@ def test_fusion_rmsnorm_quant(
), ),
) )
with vllm.config.set_current_vllm_config(vllm_config): with (
vllm.config.set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
# Setup device before model creation # Setup device before model creation
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
...@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant( ...@@ -370,8 +371,9 @@ def test_fusion_rmsnorm_quant(
# Hence, we check only 2 add nodes are left (final fused rmsnorm add). # Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op: if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) # rms_norm is IR, not included
assert n_add_nodes(backend.graph_pre_pass) == 7 # 6 = 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 6
assert n_add_nodes(backend.graph_post_pass) == 2 assert n_add_nodes(backend.graph_post_pass) == 2
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
import pytest import pytest
import torch import torch
from torch._ops import OpOverload, OpOverloadPacket
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.compilation.passes.fusion.matcher_utils import ( from vllm.compilation.passes.fusion.matcher_utils import (
FLASHINFER_ROTARY_OP, FLASHINFER_ROTARY_OP,
RMS_OP,
ROTARY_OP, ROTARY_OP,
) )
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import ( from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
...@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module): ...@@ -100,13 +100,8 @@ class QKNormRoPETestModel(torch.nn.Module):
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
return q, k, v return q, k, v
def ops_in_model_before(self) -> list[torch._ops.OpOverload]: def ops_in_model_before(self) -> list[OpOverload | OpOverloadPacket]:
ops = [] ops: list[OpOverload | OpOverloadPacket] = [torch.ops.vllm_ir.rms_norm]
if self.enable_rms_norm_custom_op:
ops.append(RMS_OP)
else:
ops.append(RSQRT_OP)
if self.enable_rope_custom_op: if self.enable_rope_custom_op:
if self.rotary_emb.use_flashinfer: if self.rotary_emb.use_flashinfer:
ops.append(FLASHINFER_ROTARY_OP) ops.append(FLASHINFER_ROTARY_OP)
...@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module): ...@@ -116,7 +111,7 @@ class QKNormRoPETestModel(torch.nn.Module):
ops.append(INDEX_SELECT_OP) ops.append(INDEX_SELECT_OP)
return ops return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]: def ops_in_model_after(self) -> list[OpOverload | OpOverloadPacket]:
return [FUSED_QK_ROPE_OP] return [FUSED_QK_ROPE_OP]
...@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion( ...@@ -166,7 +161,10 @@ def test_qk_norm_rope_fusion(
num_heads, num_kv_heads, head_dim = 16, 4, 128 num_heads, num_kv_heads, head_dim = 16, 4, 128
T = 5 T = 5
with set_current_vllm_config(vllm_config): with (
set_current_vllm_config(vllm_config),
vllm_config.kernel_config.ir_op_priority.set_priority(),
):
model = QKNormRoPETestModel( model = QKNormRoPETestModel(
num_heads=num_heads, num_heads=num_heads,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
......
...@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache): ...@@ -1622,3 +1622,26 @@ def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache):
def enable_pickle(monkeypatch): def enable_pickle(monkeypatch):
"""`LLM.apply_model` requires pickling a function.""" """`LLM.apply_model` requires pickling a function."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
@pytest.fixture(scope="function")
def disable_log_dedup(monkeypatch):
    """
    Disable log deduplication such that warning_once and info_once always print.

    Each ``vllm.logger._print_*_once`` helper is replaced by its
    ``__wrapped__`` attribute, i.e. the function underneath the caching
    decorator. The swap goes through ``monkeypatch``, so the original helpers
    are put back at teardown, including when setup fails partway through.
    """
    from vllm import logger

    helper_names = (
        "_print_warning_once",
        "_print_info_once",
        "_print_debug_once",
    )
    for helper_name in helper_names:
        cached_helper = getattr(logger, helper_name)
        monkeypatch.setattr(logger, helper_name, cached_helper.__wrapped__)
...@@ -523,3 +523,20 @@ def test_human_readable_model_len(): ...@@ -523,3 +523,20 @@ def test_human_readable_model_len():
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]: for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError): with pytest.raises(ArgumentError):
parser.parse_args(["--max-model-len", invalid]) parser.parse_args(["--max-model-len", invalid])
def test_ir_op_priority():
    """Check the two ways of passing ``ir_op_priority`` to ``EngineArgs``.

    Passing it directly and passing it inside a ``KernelConfig`` must give
    equal priorities in the engine config; passing both raises ``ValueError``.
    """
    from vllm.config.kernel import IrOpPriorityConfig, KernelConfig

    priority = IrOpPriorityConfig(rms_norm=["vllm_c"])

    direct_args = EngineArgs(ir_op_priority=priority)
    nested_args = EngineArgs(kernel_config=KernelConfig(ir_op_priority=priority))

    from_direct = direct_args.create_engine_config()
    from_nested = nested_args.create_engine_config()
    assert (
        from_direct.kernel_config.ir_op_priority
        == from_nested.kernel_config.ir_op_priority
    )

    # Supplying the priority through both routes at once is rejected.
    with pytest.raises(ValueError, match="rms_norm"):
        _ = EngineArgs(
            ir_op_priority=priority,
            kernel_config=KernelConfig(ir_op_priority=priority),
        ).create_engine_config()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import logging
from pathlib import Path
from typing import Any
import pytest
import torch
from torch import fx
from torch.fx.experimental.proxy_tensor import make_fx
import vllm.ir.op
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl
# Sanity check at import time: no op named "_custom_add" may be in the registry
# yet, because this module registers that name itself right below.
assert "_custom_add" not in IrOp.registry
class CustomError(Exception):
    """Marker exception that tests in this module raise on purpose."""
@vllm.ir.register_op
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Native implementation of the test op: elementwise ``x + y``."""
    total = x + y
    return total
def test_registration_overloads():
    """Cover the different ways of creating an op.

    Covers ``register_op()`` called with parentheses, ``register_op`` with an
    explicit ``name``, direct ``IrOp`` construction, and a duplicate
    registration that must be rejected.
    """
    fresh_names = ("_custom_sub", "_custom_mul", "_custom_div")
    assert not any(op_name in IrOp.registry for op_name in fresh_names)

    # Calling with decorator
    @vllm.ir.register_op()
    def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x - y

    assert _custom_sub.name == "_custom_sub"
    assert IrOp.registry["_custom_sub"] is _custom_sub

    # Custom name
    @vllm.ir.register_op(name="_custom_mul")
    def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    assert custom_mul.name == "_custom_mul"
    assert IrOp.registry["_custom_mul"] is custom_mul

    # Direct construction does not register directly
    def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x / y

    custom_div = IrOp("_custom_div", _custom_div)
    assert custom_div.name == "_custom_div"
    assert "_custom_div" not in IrOp.registry

    # Duplicate op registration not allowed
    with pytest.raises(AssertionError):

        @vllm.ir.register_op
        def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y - 100
def test_no_kw_only_args():
    """Registering an op with a keyword-only parameter raises ``ValueError``.

    The rejected op must also be absent from the registry afterwards.
    """
    # kw-only args not supported
    with pytest.raises(ValueError, match="keyword-only arguments"):

        @vllm.ir.register_op
        def _custom_kwarg_op(
            x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
        ) -> torch.Tensor:
            return x + y + kwarg

    # The failed registration must not leave an entry behind.
    assert "_custom_kwarg_op" not in IrOp.registry
class TestIrOpCustomAdd:
    """Tests against the module-level ``_custom_add`` op.

    Covers registration, eager semantics, implementation registration, schema
    text and visibility in FX traces.
    """

    # Registration invariants
    def test_decorated_object(self):
        """Make sure that referring directly to an op is correct"""
        assert isinstance(_custom_add, IrOp)
        assert "_custom_add" in IrOp.registry
        assert IrOp.registry["_custom_add"] is _custom_add

    def test_torch_op_is_registered(self):
        """The op is reachable under ``torch.ops.vllm_ir`` with a callable
        ``default`` overload."""
        ir_namespace = torch.ops.vllm_ir
        assert hasattr(ir_namespace, "_custom_add")
        assert callable(ir_namespace._custom_add.default)

    # Semantic correctness
    def test_semantics_match_native(self):
        """Calling the op directly gives the same result as ``x + y``."""
        lhs = torch.randn(4, 5)
        rhs = torch.randn(4, 5)

        # Calls native by default
        actual = _custom_add(lhs, rhs)
        expected = lhs + rhs
        torch.testing.assert_close(actual, expected)

    # -------------------------
    # Implementation registration
    # -------------------------

    def test_register_impl_is_non_intrusive(self):
        """Registering an extra implementation does not change what a plain
        call to the op returns."""

        @_custom_add.register_impl("dummy_provider")
        def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 123

        assert "dummy_provider" in _custom_add.impls
        assert isinstance(_custom_add.impls["dummy_provider"], IrOpImpl)

        lhs = torch.ones(2, 2)
        rhs = torch.ones(2, 2)

        # Native semantics must still hold
        torch.testing.assert_close(_custom_add(lhs, rhs), lhs + rhs)

    def test_schema_contains_tensor_signature(self):
        """The op's schema string mentions ``Tensor`` and a ``Tensor`` return."""
        schema_text = _custom_add._schema_str
        for fragment in ("Tensor", "-> Tensor"):
            assert fragment in schema_text

    # -------------------------
    # FX visibility
    # -------------------------

    @pytest.mark.parametrize("enable_torch_wrap", [True, False])
    @pytest.mark.parametrize("symbolic_trace", [True, False])
    def test_trace_sees_single_custom_op(
        self, symbolic_trace: bool, enable_torch_wrap: bool
    ):
        """Trace a call to the op and count its IR nodes in the graph.

        The trace is taken inside ``enable_torch_wrap(...)`` first, then again
        after that context has been left through an exception.
        """

        def fn(x, y):
            return _custom_add(x, y)

        def trace() -> fx.GraphModule:
            # Either tracer is driven with the same toy function.
            if symbolic_trace:
                return torch.fx.symbolic_trace(fn)
            return make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

        def ir_node_count(graph_module: fx.GraphModule) -> int:
            target: Any = torch.ops.vllm_ir._custom_add.default
            matches = graph_module.graph.find_nodes(
                op="call_function", target=target
            )
            return len(matches)

        with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
            gm = trace()

            x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
            out_fx = gm(x1, y1)
            out_eager = fn(x1, y1)

            # raise error to check enable_torch_wrap context restored correctly
            raise CustomError

        # check behavior matches eager in all cases
        torch.testing.assert_close(out_fx, out_eager)

        # check that IR nodes only appear if enable_torch_wrap=True
        expected_count = 1 if enable_torch_wrap else 0
        assert ir_node_count(gm) == expected_count, gm.code

        # with torch wrapping enabled (default), IR nodes appear
        gm = trace()
        assert ir_node_count(gm) == 1, gm.code
@_custom_add.register_impl("impl_a")
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test implementation: ``x + y`` shifted by 10 so it can be told apart."""
    total = x + y
    return total + 10
@_custom_add.register_impl("impl_b")
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test implementation: ``x + y`` shifted by 20 so it can be told apart."""
    total = x + y
    return total + 20
@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0)
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Test implementation: ``x + y`` shifted by 50.

    Its ``supports_args`` predicate accepts only inputs whose ``x.size(1)`` is
    even.
    """
    total = x + y
    return total + 50
class TestIrOpImplDispatch:
    """Tests for registering implementations on ``_custom_add``, setting the
    provider priority, and dispatching to the selected implementation."""

    def test_register_impl(self):
        """``impl_a`` is stored on the op, and a second registration under the
        same provider name is refused without replacing it."""
        assert "impl_a" in _custom_add.impls
        registered = _custom_add.impls["impl_a"]
        assert registered is impl_a
        assert registered.op is _custom_add
        assert registered.provider == "impl_a"
        assert callable(registered.impl_fn)

        # Test duplicate registration rejected
        with pytest.raises(AssertionError):

            @_custom_add.register_impl("impl_a")
            def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y + 30

        # Check the original impl is still intact
        assert _custom_add.impls["impl_a"] is impl_a

        # Check support all args
        assert impl_a.supports_all_args
        assert impl_b.supports_all_args
        assert not impl_even.supports_all_args

    def test_reserved_provider_rejected(self):
        """Every name in ``RESERVED_PROVIDERS`` is refused as a provider."""
        for reserved_name in RESERVED_PROVIDERS:
            with pytest.raises(AssertionError):

                @_custom_add.register_impl(reserved_name)
                def bad_impl(x, y):
                    return x + y

    def test_set_priority_scoped(self):
        """``set_priority`` applies only inside its ``with`` block, including
        when nested and when the block is left through an exception."""
        current_priority = _custom_add.get_priority
        assert current_priority() == []

        with _custom_add.set_priority(["impl_even", "impl_b"]):
            assert current_priority() == ["impl_even", "impl_b"]

            # Check nesting
            with _custom_add.set_priority(["impl_b"]):
                assert current_priority() == ["impl_b"]

            # Restored
            assert current_priority() == ["impl_even", "impl_b"]

            # Check that exception restores priority
            with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
                assert current_priority() == ["impl_a"]
                raise CustomError

            # Restored again
            assert current_priority() == ["impl_even", "impl_b"]

        # Restored to empty
        assert current_priority() == []

    def test_dispatch_priority_order(self):
        """The first provider in the priority list is the one dispatched to,
        both through the op object and through ``torch.ops.vllm_ir``."""
        x = torch.tensor(1, dtype=torch.int32)
        y = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["impl_b", "impl_a"]):
            assert _custom_add.dispatch(x, y) is impl_b
            direct_b = _custom_add(x, y)
            via_torch_b = torch.ops.vllm_ir._custom_add(x, y)

        with _custom_add.set_priority(["impl_a"]):
            assert _custom_add.dispatch(x, y) is impl_a
            direct_a = _custom_add(x, y)
            via_torch_a = torch.ops.vllm_ir._custom_add(x, y)

        # impl_b
        assert direct_b.item() == 1 + 2 + 20
        assert via_torch_b.item() == 1 + 2 + 20
        # impl_a
        assert direct_a.item() == 1 + 2 + 10
        assert via_torch_a.item() == 1 + 2 + 10

    def test_unsupported_impl_filtered(self):
        """A provider registered with ``supported=False`` is dropped from the
        priority list, so the next provider handles the call."""

        @_custom_add.register_impl("unsupported", supported=False)
        def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 999

        x = torch.tensor(1, dtype=torch.int32)
        y = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["unsupported", "impl_a"]):
            assert _custom_add.get_priority() == ["impl_a"]
            result = _custom_add(x, y)

        # impl_bad skipped → impl_a
        assert result.item() == 1 + 2 + 10

    def test_supports_args_runtime_dispatch_and_warning(
        self, caplog_vllm: pytest.LogCaptureFixture
    ):
        """With only ``impl_even`` prioritised, one warning is logged, "native"
        is appended to the priority list, and dispatch follows ``x.size(1)``."""
        even_x = torch.ones((2, 2), dtype=torch.int32)
        even_y = torch.full((2, 2), 2, dtype=torch.int32)

        odd_x = torch.ones((2, 3), dtype=torch.int32)
        odd_y = torch.full((2, 3), 2, dtype=torch.int32)

        with (
            caplog_vllm.at_level(logging.WARNING),
            _custom_add.set_priority(["impl_even"]),
        ):
            # Test the warning about native fallback is logged (before even dispatching)
            assert len(caplog_vllm.records) == 1
            warning_text = caplog_vllm.records[0].message
            assert "_custom_add" in warning_text
            assert "fallback to native" in warning_text
            assert "priority" in warning_text

            # Check dispatching
            assert _custom_add.get_priority() == ["impl_even", "native"]
            assert _custom_add.dispatch(even_x, even_y) is impl_even
            assert _custom_add.dispatch(odd_x, odd_y) is _custom_add.impls["native"]

            even_out = _custom_add(even_x, even_y)  # size(1) == 2 → impl_even
            odd_out = _custom_add(odd_x, odd_y)  # size(1) == 3 → native fallback

        # no other warnings
        assert len(caplog_vllm.records) == 1
        assert torch.all(even_out == 1 + 2 + 50)
        assert torch.all(odd_out == 1 + 2)

    def test_default_priority(
        self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
    ):
        """With no priority set, the op dispatches to "native" and logs
        warnings mentioning the op name and "priority not set"."""
        # Make sure logs are not deduplicated to properly test the warning
        x = torch.tensor([3], dtype=torch.int32)
        y = torch.tensor([4], dtype=torch.int32)

        # No priority set → falls back to native
        assert _custom_add.get_priority() == []

        with caplog_vllm.at_level(logging.WARNING):
            # Native by default
            assert _custom_add.dispatch(x, y) is _custom_add.impls["native"]
            result = _custom_add(x, y)

        # Check dispatching to native by default
        assert result.item() == 3 + 4

        # Check warning
        assert len(caplog_vllm.records) == 2
        first_warning = caplog_vllm.records[0].message.lower()
        assert "_custom_add" in first_warning
        assert "priority not set" in first_warning
@vllm.ir.register_op
def _custom_mm(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """Native implementation of the test op: ``x @ y``, plus ``bias`` when one
    is given."""
    product = x @ y
    if bias is None:
        return product
    return product + bias
def test_default_args():
    """Check that an omitted ``bias`` works for ``supports_args`` and dispatch.

    Registers ``impl_mm`` on ``_custom_mm`` and calls both with only the two
    required tensors.
    """

    # Test that default args are properly applied when dispatching and calling
    @_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
    def impl_mm(
        x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        product = x @ y
        if bias is None:
            return product + 50
        return product + bias + 100

    lhs = torch.tensor([1, 2], dtype=torch.int32)
    rhs = torch.tensor([3, 4], dtype=torch.int32)

    # Test that supports_args receives the defaulted args
    assert impl_mm.supports_args(lhs, rhs)

    with _custom_mm.set_priority(["impl_mm", "native"]):
        assert _custom_mm.dispatch(lhs, rhs) is impl_mm
def test_bad_impl_registrations():
    """Check that malformed ``register_impl`` calls on ``_custom_mm`` fail.

    Each case must raise ``ValueError`` with the expected message, and none of
    the rejected providers may end up in ``_custom_mm.impls``.

    Each inner function is named after the provider it is registered under, so
    no case redefines an earlier function's name.
    """
    # Check bad schema
    with pytest.raises(ValueError, match="does not match native schema"):

        # Missing the `bias` parameter entirely.
        @_custom_mm.register_impl("impl_mm_bad_schema")
        def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x @ y - 1

    with pytest.raises(ValueError, match="does not match native schema"):

        # Third parameter is named `b` instead of `bias`.
        @_custom_mm.register_impl("impl_mm_bad_schema_2")
        def impl_mm_bad_schema_2(
            x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + b - 2

    with pytest.raises(ValueError, match="does not match native schema"):

        # `bias` is required and non-optional here.
        @_custom_mm.register_impl("impl_mm_bad_schema_3")
        def impl_mm_bad_schema_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            return x @ y + bias - 5

    # check supports_args with incorrect params
    with pytest.raises(ValueError, match="supports_args must be a callable"):

        @_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
        def impl_mm_bad_supports_args(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    with pytest.raises(ValueError, match="number of parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
        )
        def impl_mm_bad_supports_args_2(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    with pytest.raises(ValueError, match="keyword-only parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
        )
        def impl_mm_bad_supports_args_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 20

    with pytest.raises(ValueError, match="does not match native parameter"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
        )
        def impl_mm_bad_supports_args_4(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 30

    with pytest.raises(ValueError, match="does not match native default"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
        )
        def impl_mm_bad_supports_args_5(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 40

    # None of the rejected registrations may have been recorded.
    assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"}
# Source text of an "out-of-tree" implementation of ``_custom_mm``. It is
# written to a temporary file and executed as a module; ``_custom_mm`` is not
# imported here because the loader injects it into the module namespace before
# execution. The text must be valid Python, so the body keeps its indentation.
IMPL_OOT_SRC = """
import torch
@_custom_mm.register_impl("impl_mm_oot")
def impl_mm_oot(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    return x @ y - 99
"""
def load_custom_mm_module(file_path: Path):
    """Execute ``file_path`` as the module ``_custom_mm_oot`` and return it.

    ``_custom_mm`` is placed in the new module's namespace before the file is
    executed, so a ``@_custom_mm.register_impl`` decorator inside the file
    resolves and runs as a side effect of loading.
    """
    module_spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
    assert module_spec is not None
    oot_module = importlib.util.module_from_spec(module_spec)

    # Make the op visible as a global of the module about to be executed.
    setattr(oot_module, "_custom_mm", _custom_mm)

    # Running the file is what triggers the registration decorator.
    loader = module_spec.loader
    assert loader is not None
    loader.exec_module(oot_module)
    return oot_module
def test_uuid_and_oot(tmp_path: Path):
    """The uuid of an out-of-tree impl must follow the contents of its file.

    The same file is loaded three times: with the original source, with a
    modified source, and with the original source again. The uuid is expected
    to change for the modified source and to return to its first value once
    the original source is restored.
    """
    file_path = tmp_path / "_custom_mm_oot.py"

    def register_and_get_uuid(source: str):
        """Write ``source``, load it, and return the registered impl's uuid."""
        file_path.write_text(source)
        assert "impl_mm_oot" not in _custom_mm.impls
        try:
            _ = load_custom_mm_module(file_path)
            assert "impl_mm_oot" in _custom_mm.impls
            return _custom_mm.impls["impl_mm_oot"].uuid()
        finally:
            # Unregister even when an assertion fails, so a failure here does
            # not leave "impl_mm_oot" behind on the module-level op.
            if "impl_mm_oot" in _custom_mm.impls:
                del _custom_mm.impls["impl_mm_oot"]

    original_uuid = register_and_get_uuid(IMPL_OOT_SRC)

    # Replace file source
    modified_uuid = register_and_get_uuid(IMPL_OOT_SRC + " # added file source")
    assert modified_uuid != original_uuid

    # Back to original
    restored_uuid = register_and_get_uuid(IMPL_OOT_SRC)
    assert restored_uuid == original_uuid
    assert restored_uuid != modified_uuid
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
# This registers op implementations
import vllm.kernels # noqa: F401
from tests.kernels.allclose_default import get_default_rtol
from vllm import ir
from vllm.platforms import current_platform
def rms_norm_inputs(n_tokens: int, hidden_size: int, dtype: torch.dtype):
    """Build random rms_norm test inputs.

    Returns a pair ``(x, weight)``: ``x`` is drawn with ``torch.randn`` and has
    shape ``(n_tokens, hidden_size)``; ``weight`` is drawn with ``torch.rand``
    and has shape ``(hidden_size,)``. Both use ``dtype`` and are created on the
    current default device.
    """
    activations = torch.randn((n_tokens, hidden_size), dtype=dtype)
    scale = torch.rand((hidden_size,), dtype=dtype)
    return activations, scale
# Raw function behind the "native" rms_norm implementation, looked up once at
# import time; the tests below call it directly as the comparison baseline.
rms_norm_native = ir.ops.rms_norm.impls["native"].impl_fn
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
def test_rms_norm_registration():
    """Each registered rms_norm provider reports the expected ``supported`` flag."""
    # Expected support per provider, derived from the current platform.
    expected = dict(
        native=True,
        vllm_c=current_platform.is_cuda_alike(),
        aiter=current_platform.is_rocm(),
        oink=False,
        xpu_kernels=current_platform.is_xpu(),
    )

    # What the registry actually reports.
    actual = {}
    for provider, impl in ir.ops.rms_norm.impls.items():
        actual[provider] = impl.supported

    # Dict equality also catches providers that are missing or unexpected.
    assert actual == expected
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("n_tokens", [1, 8, 17])
@pytest.mark.parametrize("hidden_size", [16, 4096, 8192])
@pytest.mark.parametrize("epsilon", [1e-6, 1e-5])
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() and not current_platform.is_xpu(),
    reason="Currently only kernels on CUDA, ROCm and XPU",
)
class TestRMSNorm:
    """Checks of the rms_norm IR op: native semantics, kernel impls, opcheck."""

    @classmethod
    def setup_class(cls, **kwargs):
        # Tensors created by the tests land on the platform's device.
        torch.set_default_device(current_platform.device_type)

    def test_native_semantics(self, dtype, n_tokens, hidden_size, epsilon):
        """Sanity-check basic properties of the native implementation."""
        # NOTE(review): n_tokens and hidden_size are not used here; the inputs
        # are always 4x8, so this test repeats for every size combination.
        x, weight = rms_norm_inputs(4, 8, dtype)
        baseline = rms_norm_native(x, weight, epsilon=epsilon)

        # Check shape, dtype, device
        assert baseline.shape == x.shape
        assert baseline.dtype == x.dtype
        assert baseline.device == x.device

        # Check the scaling property of rms norm
        doubled = rms_norm_native(x * 2.0, weight, epsilon=epsilon)
        torch.testing.assert_close(
            doubled, baseline, rtol=get_default_rtol(baseline), atol=1e-3
        )

        # Check behavior with and without weight
        unit_weight = torch.ones_like(weight)
        with_unit_weight = rms_norm_native(x, unit_weight, epsilon=epsilon)
        without_weight = rms_norm_native(x, None, epsilon=epsilon)
        torch.testing.assert_close(with_unit_weight, without_weight)

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels"])
    def test_impls(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Compare one provider's kernel against the native implementation."""
        impl = ir.ops.rms_norm.impls[provider]
        if not impl.supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        # Fourth positional argument is variance_size; None means no override.
        args = (x, weight, epsilon, None)

        assert impl.supported

        # aiter is expected to reject dtypes other than fp16/bf16; nothing
        # further can be checked for it in that case.
        aiter_rejects_dtype = provider == "aiter" and dtype not in [
            torch.float16,
            torch.bfloat16,
        ]
        if aiter_rejects_dtype:
            assert not impl.supports_args(*args)
            return

        assert impl.supports_args(*args)

        kernel_out = impl.impl_fn(*args)
        native_out = rms_norm_native(*args)
        torch.testing.assert_close(
            kernel_out, native_out, rtol=get_default_rtol(kernel_out), atol=1e-3
        )

        # check that dispatched call matches direct call
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            dispatched_out = ir.ops.rms_norm(*args)

        # exact match
        torch.testing.assert_close(dispatched_out, kernel_out, rtol=0.0, atol=0.0)

        # none of these support variance_size override
        assert not impl.supports_args(x, weight, epsilon, 4)
        assert not impl.supports_args(x, weight, epsilon, variance_size=4)

        # test weight=None behavior
        no_weight_out = impl.impl_fn(x, None, epsilon)
        unit_weight_out = impl.impl_fn(x, torch.ones_like(weight), epsilon)
        torch.testing.assert_close(
            no_weight_out,
            unit_weight_out,
            rtol=get_default_rtol(no_weight_out),
            atol=2e-4,
        )

    @pytest.mark.parametrize("provider", ["vllm_c", "aiter", "xpu_kernels", "native"])
    def test_torch_opcheck(self, dtype, n_tokens, hidden_size, epsilon, provider):
        """Run ``torch.library.opcheck`` on the torch op for one provider."""
        impl = ir.ops.rms_norm.impls[provider]
        if not impl.supported:
            pytest.skip(f"{provider} impl not supported on this platform")

        x, weight = rms_norm_inputs(n_tokens, hidden_size, dtype)
        args = (x, weight, epsilon, None)

        # When checking the torch op, we have to set priority and use dispatch
        with ir.ops.rms_norm.set_priority([provider, "native"]):
            torch.library.opcheck(torch.ops.vllm_ir.rms_norm, args)
...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import ( ...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import (
RMSNorm, RMSNorm,
dispatch_rocm_rmsnorm_func, dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, fused_add_rms_norm,
rms_norm,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool): ...@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
assert topk_func == vllm_topk_sigmoid assert topk_func == vllm_topk_sigmoid
@pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("add_residual", [False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool): ...@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
def test_rms_norm_dispatch( def test_rms_norm_dispatch(
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
): ):
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter) rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)
should_use_rocm_aiter = ( should_use_rocm_aiter = (
current_platform.is_rocm() current_platform.is_rocm()
...@@ -173,11 +172,7 @@ def test_rms_norm_dispatch( ...@@ -173,11 +172,7 @@ def test_rms_norm_dispatch(
and dtype in RMS_NORM_SUPPORTED_DTYPES and dtype in RMS_NORM_SUPPORTED_DTYPES
) )
if add_residual and should_use_rocm_aiter: if should_use_rocm_aiter:
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else: else:
assert rms_norm_func == rms_norm assert rms_norm_func == fused_add_rms_norm
...@@ -6,12 +6,14 @@ import os ...@@ -6,12 +6,14 @@ import os
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch from unittest.mock import patch
import pydantic
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
KernelConfig,
ModelConfig, ModelConfig,
ParallelConfig, ParallelConfig,
PoolerConfig, PoolerConfig,
...@@ -21,6 +23,7 @@ from vllm.config import ( ...@@ -21,6 +23,7 @@ from vllm.config import (
update_config, update_config,
) )
from vllm.config.compilation import CompilationMode, CUDAGraphMode from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.kernel import IrOpPriorityConfig
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.utils import get_field from vllm.config.utils import get_field
from vllm.config.vllm import ( from vllm.config.vllm import (
...@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides(): ...@@ -1077,6 +1080,39 @@ def test_vllm_config_explicit_overrides():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
def test_fusion_pass_op_priority():
"""This test checks that custom op enablement & IR op priority
correctly control default fusions"""
# Default config, O2, rms_norm+quant fusion disabled
cfg1 = VllmConfig()
assert not cfg1.compilation_config.pass_config.fuse_norm_quant
# rms_norm manually enabled, O1, rms_norm+quant fusion enabled
cfg2 = VllmConfig(
optimization_level=OptimizationLevel.O1,
compilation_config=CompilationConfig(
custom_ops=["+rms_norm"],
),
)
assert cfg2.compilation_config.pass_config.fuse_norm_quant
# using custom kernel for RMSNorm via IR:
# Note that vLLM IR only supports the non-residual rms_norm for now;
# soon this will be resolved.
cfg3 = VllmConfig(
kernel_config=KernelConfig(
ir_op_priority=IrOpPriorityConfig(rms_norm=["vllm_c"])
)
)
assert cfg3.compilation_config.pass_config.fuse_norm_quant
# block-fp8 model should enable quant_fp8 automatically
cfg4 = VllmConfig(model_config=ModelConfig("Qwen/Qwen3-4B-FP8"))
assert "+quant_fp8" in cfg4.compilation_config.custom_ops
assert cfg4.compilation_config.pass_config.fuse_norm_quant
def test_scheduler_config_init(): def test_scheduler_config_init():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
# Positional InitVars missing # Positional InitVars missing
...@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config(): ...@@ -1171,3 +1207,35 @@ def test_eagle_draft_model_config():
assert draft_model_config.hf_text_config.model_type == "eagle" assert draft_model_config.hf_text_config.model_type == "eagle"
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"] assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.architecture == "EagleLlamaForCausalLM" assert draft_model_config.architecture == "EagleLlamaForCausalLM"
def test_ir_op_priority_default():
"""Test that IR op priority defaults are set correctly."""
from vllm.config.kernel import IrOpPriorityConfig
# Assert default is applied to ops
priority_config = IrOpPriorityConfig.with_default(["vllm_c", "native"])
assert priority_config.rms_norm == ["vllm_c", "native"]
# Assert single ops override the default
assert IrOpPriorityConfig.with_default(
["vllm_c", "native"], rms_norm=["oink", "native"]
) == IrOpPriorityConfig(rms_norm=["oink", "native"])
def test_ir_op_priority_str():
"""Test that passing a comma-delimited string works"""
from vllm.config.kernel import IrOpPriorityConfig
priority_config = IrOpPriorityConfig(rms_norm="vllm_c")
assert priority_config.rms_norm == ["vllm_c"]
priority_config = IrOpPriorityConfig(rms_norm="vllm_c,native")
assert priority_config.rms_norm == ["vllm_c", "native"]
priority_config = IrOpPriorityConfig(rms_norm=" native, vllm_c ")
assert priority_config.rms_norm == ["native", "vllm_c"]
with pytest.raises(pydantic.ValidationError):
# must be list of only strings
priority_config = IrOpPriorityConfig(rms_norm=["vllm_c", 4, "native"])
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import contextlib import contextlib
from importlib.util import find_spec from importlib.util import find_spec
from types import ModuleType from types import ModuleType
from typing import Any
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
...@@ -10,6 +11,7 @@ import torch.fx as fx ...@@ -10,6 +11,7 @@ import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import Range from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
...@@ -28,7 +30,7 @@ from vllm.utils.torch_utils import ( ...@@ -28,7 +30,7 @@ from vllm.utils.torch_utils import (
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -258,6 +260,12 @@ class BasePattern: ...@@ -258,6 +260,12 @@ class BasePattern:
self.tp = get_tp_group() self.tp = get_tp_group()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=self.dtype, device=self.device, **kwargs)
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
class AllReduceRMSNormPattern(BasePattern): class AllReduceRMSNormPattern(BasePattern):
""" """
...@@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern): ...@@ -276,20 +284,17 @@ class AllReduceRMSNormPattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs() # input, weight
return [self.empty(5, 16), self.empty(16)]
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, weight: torch.Tensor input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(allreduce_output, weight) rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)
return rms, allreduce_output return rms, allreduce_output
...@@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -407,15 +412,13 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs() _, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit # input, weight
return [input.to(self.dtype), weight, scale] return [self.empty(5, 16), self.empty(16), scale]
def register(self, pm_pass: PatternMatcherPass) -> None: def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
...@@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -424,7 +427,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input) all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight) rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant, _ = self.quant_matcher(rms, scale) quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce return quant, all_reduce
...@@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -553,7 +556,6 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self) -> list[torch.Tensor]: def get_inputs(self) -> list[torch.Tensor]:
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
...@@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -575,7 +577,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
output_scale: torch.Tensor, output_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(input) all_reduce = tensor_model_parallel_all_reduce(input)
rms = self.rmsnorm_matcher(all_reduce, weight) rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon)
quant_out_tuple = auto_functionalized( quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP, STATIC_FP4_QUANT_OP,
input=rms, input=rms,
......
...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
...@@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp): ...@@ -160,69 +159,6 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
return result return result
class MatcherRMSNorm(MatcherCustomOp):
def __init__(
self,
epsilon: float,
enabled: bool | None = None,
match_rocm_aiter: bool = False,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
self._rmsnorm_op = RMS_OP
self.match_rocm_aiter = match_rocm_aiter
if match_rocm_aiter:
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_rocm_aiter(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return self._rmsnorm_op(
x=input,
weight=weight,
variance_epsilon=self.epsilon,
)
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
if self.match_rocm_aiter:
return self.forward_rocm_aiter(input, weight)
result = torch.empty_like(input)
_, result = auto_functionalized(
self._rmsnorm_op,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp): class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__( def __init__(
self, self,
......
...@@ -10,6 +10,7 @@ from torch import fx ...@@ -10,6 +10,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
import vllm.ir.ops
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding ...@@ -17,7 +18,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from ..inductor_pass import enable_fake_mode from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding from .matcher_utils import MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64 from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -64,7 +65,6 @@ class QkNormRopePattern: ...@@ -64,7 +65,6 @@ class QkNormRopePattern:
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding( self.rope_matcher = MatcherRotaryEmbedding(
...@@ -129,14 +129,14 @@ class QkNormRopePattern: ...@@ -129,14 +129,14 @@ class QkNormRopePattern:
q_by_head = q.view( q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
) )
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) q_normed_by_head = vllm.ir.ops.rms_norm(q_by_head, q_weight, self.eps)
q_flat = q_normed_by_head.view(q.shape) q_flat = q_normed_by_head.view(q.shape)
# K path: view -> RMS -> view back to k.shape # K path: view -> RMS -> view back to k.shape
k_by_head = k.view( k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
) )
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) k_normed_by_head = vllm.ir.ops.rms_norm(k_by_head, k_weight, self.eps)
k_flat = k_normed_by_head.view(k.shape) k_flat = k_normed_by_head.view(k.shape)
# RoPE: apply to flattened q/k # RoPE: apply to flattened q/k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment