Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from ..utils import compare_two_settings
def test_cpu_offload():
@pytest.mark.parametrize("disable_pin_memory", [False, True])
@pytest.mark.parametrize("disable_uva", [False, True])
def test_cpu_offload(disable_pin_memory, disable_uva):
env_vars = {
"VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY": str(int(disable_pin_memory)),
"VLLM_WEIGHT_OFFLOADING_DISABLE_UVA": str(int(disable_uva)),
}
args = ["--cpu-offload-gb", "1"]
# cuda graph only works with UVA offloading
if disable_uva:
args.append("--enforce-eager")
compare_two_settings(
"hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"]
model="hmellor/tiny-random-LlamaForCausalLM",
arg1=[],
arg2=args,
env1=None,
env2=env_vars,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test prefetch offloading correctness with Llama model."""
from ..utils import compare_two_settings
def test_prefetch_offload_llama():
"""Test prefetch CPU offloading with Llama-3.2-1B-Instruct.
Compares outputs between:
1. Baseline (no offloading)
2. Prefetch offloading (group_size=8, num_in_group=2, prefetch_step=1)
This tests prefetching-based offloading on a dense model.
"""
compare_two_settings(
"meta-llama/Llama-3.2-1B-Instruct",
[
# Prefetch offloading configuration
"--offload-group-size",
"8",
"--offload-num-in-group",
"2",
"--offload-prefetch-step",
"1",
# Selective offloading: only MLP weights
"--offload-params",
"gate_up_proj",
"down_proj",
],
[], # Baseline: no offloading
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Callable
from pathlib import Path
from unittest.mock import patch
from vllm.benchmarks.sweep.param_sweep import ParameterSweepItem
from vllm.benchmarks.sweep.serve_sla import _get_sla_run_path, solve_sla
from vllm.benchmarks.sweep.server import ServerProcess
from vllm.benchmarks.sweep.sla_sweep import (
SLACriterionBase,
SLALessThan,
SLALessThanOrEqualTo,
SLASweepItem,
)
def _set_return_value(
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
):
"""
Create a patch for run_sla with a specific function
indicating the relationship between the benchmark combination
(which includes the SLA variable) and the SLA criterion.
"""
def mock_run_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
iter_path: Path,
num_runs: int,
dry_run: bool,
):
iter_data = var2metric(bench_comb)
summary_path = _get_sla_run_path(iter_path, run_number=None)
summary_path.parent.mkdir(parents=True, exist_ok=True)
with summary_path.open("w") as f:
json.dump(iter_data, f, indent=4)
return iter_data
return patch("vllm.benchmarks.sweep.serve_sla.run_sla", side_effect=mock_run_sla)
def _var2metric_linear():
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = x
return [{"request_throughput": y}]
return wrapped
def _var2metric_concave(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 0.5 * (x - elbow_point) + elbow_point
else:
y = 1.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_convex(elbow_point: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
if x < elbow_point:
y = 1.5 * (x - elbow_point) + elbow_point
else:
y = 0.5 * (x - elbow_point) + elbow_point
return [{"request_throughput": y}]
return wrapped
def _var2metric_quadratic(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 0.1 * x**2
return [{"request_throughput": y}]
return wrapped
def _var2metric_sqrt(y_intercept: float):
def wrapped(bench_comb):
x = float(bench_comb["request_rate"])
y = y_intercept + 10 * x**0.5
return [{"request_throughput": y}]
return wrapped
def _run_solve_sla(
var2metric: Callable[[ParameterSweepItem], list[dict[str, float]]],
criterion: SLACriterionBase,
base_path: Path,
min_value: int = 1,
max_value: int = 100,
):
with _set_return_value(var2metric):
result = solve_sla(
server=None,
bench_cmd=[],
serve_comb=ParameterSweepItem(),
bench_comb=ParameterSweepItem(),
sla_comb=SLASweepItem({"request_throughput": criterion}),
base_path=base_path,
num_runs=1,
dry_run=False,
sla_variable="request_rate",
sla_min_value=min_value,
sla_max_value=max_value,
)
assert result is not None
return result
def test_solve_linear_sla_le(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32),
tmp_path,
)
assert history.get_max_passing() == 32
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
32: True,
33: False,
}
def test_solve_linear_sla_lt(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThan(target=32),
tmp_path,
)
assert history.get_max_passing() == 31
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
31: True,
32: False,
}
def test_solve_linear_sla_oob(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=32),
tmp_path,
min_value=64,
)
assert history.get_max_passing() == 64
assert history.get_min_failing() == 64
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
64: False,
}
def test_solve_concave_sla_le(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_concave(elbow_point=32),
SLALessThanOrEqualTo(target=24),
tmp_path,
)
assert history.get_max_passing() == 16
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
7: True,
13: True,
15: True,
16: True,
17: False,
}
def test_solve_convex_sla_le(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_convex(elbow_point=32),
SLALessThanOrEqualTo(target=24),
tmp_path,
)
assert history.get_max_passing() == 26
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
48: False,
30: False,
24: True,
26: True,
27: False,
}
def test_solve_quadratic_sla_le(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_quadratic(y_intercept=10),
SLALessThanOrEqualTo(target=50),
tmp_path,
)
assert history.get_max_passing() == 20
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
4: True,
20: True,
21: False,
}
def test_solve_sqrt_sla_le(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_sqrt(y_intercept=10),
SLALessThanOrEqualTo(target=100),
tmp_path,
)
assert history.get_max_passing() == 81
assert {val: margin <= 0 for val, margin in history.items()} == {
100: False,
1: True,
89: False,
81: True,
82: False,
}
def test_solve_reuse_history(tmp_path):
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=10),
tmp_path,
min_value=1,
max_value=20,
)
assert history.get_max_passing() == 10
assert {val: margin <= 0 for val, margin in history.items()} == {
20: False,
1: True,
10: True,
11: False,
}
sla_data, history = _run_solve_sla(
_var2metric_linear(),
SLALessThanOrEqualTo(target=30),
tmp_path,
min_value=21,
max_value=40,
)
assert history.get_max_passing() == 30
assert {val: margin <= 0 for val, margin in history.items()} == {
# Items from the past run
# (the margins are different because the target changed)
20: True,
1: True,
10: True,
11: True,
# Items from this run
40: False,
30: True,
31: False,
}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import tempfile
import time
from pathlib import Path
import pytest
import requests
import urllib3
from ..utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.fixture(scope="module")
def generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
"""Generate a self-signed certificate for testing."""
cert_file = cert_dir / "cert.pem"
key_file = cert_dir / "key.pem"
# Generate self-signed certificate using openssl
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:2048",
"-keyout",
str(key_file),
"-out",
str(cert_file),
"-days",
"1",
"-nodes",
"-subj",
"/CN=localhost",
],
check=True,
capture_output=True,
)
return cert_file, key_file
class RemoteOpenAIServerSSL(RemoteOpenAIServer):
"""RemoteOpenAIServer subclass that supports SSL with self-signed certs."""
@property
def url_root(self) -> str:
return f"https://{self.host}:{self.port}"
def _wait_for_server(self, *, url: str, timeout: float):
"""Override to use HTTPS with SSL verification disabled."""
# Suppress InsecureRequestWarning for self-signed certs
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
start = time.time()
while True:
try:
if requests.get(url, verify=False).status_code == 200:
break
except Exception:
result = self._poll()
if result is not None and result != 0:
raise RuntimeError("Server exited unexpectedly.") from None
time.sleep(0.5)
if time.time() - start > timeout:
raise RuntimeError("Server failed to start in time.") from None
@pytest.fixture(scope="function")
def server():
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
......@@ -17,6 +78,27 @@ def server():
yield remote_server
@pytest.fixture(scope="function")
def ssl_server():
"""Start a vLLM server with SSL enabled using a self-signed certificate."""
with tempfile.TemporaryDirectory() as cert_dir:
cert_file, key_file = generate_self_signed_cert(Path(cert_dir))
args = [
"--max-model-len",
"1024",
"--enforce-eager",
"--load-format",
"dummy",
"--ssl-certfile",
str(cert_file),
"--ssl-keyfile",
str(key_file),
]
with RemoteOpenAIServerSSL(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.benchmark
def test_bench_serve(server):
# Test default model detection and input/output len
......@@ -42,6 +124,31 @@ def test_bench_serve(server):
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark
def test_bench_serve_insecure(ssl_server):
"""Test --insecure flag with an HTTPS server using a self-signed certificate."""
base_url = f"https://{ssl_server.host}:{ssl_server.port}"
command = [
"vllm",
"bench",
"serve",
"--base-url",
base_url,
"--input-len",
"32",
"--output-len",
"4",
"--num-prompts",
"5",
"--insecure",
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark
def test_bench_serve_chat(server):
command = [
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from vllm.platforms.interface import DeviceCapability
@pytest.fixture
def mock_cuda_platform():
"""
Fixture that returns a factory for creating mocked CUDA platforms.
Usage:
def test_something(mock_cuda_platform):
with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
# test code
"""
@contextmanager
def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
mock_platform = MagicMock()
mock_platform.is_cuda.return_value = is_cuda
if capability is not None:
mock_platform.get_device_capability.return_value = DeviceCapability(
*capability
)
with patch("vllm.platforms.current_platform", mock_platform):
yield mock_platform
return _mock_platform
......@@ -31,7 +31,12 @@ def test_async_tp_pass_correctness(
distributed_backend: str,
eager_mode: bool,
num_gpus_available: int,
monkeypatch,
):
# Disable FlashInfer FP8 scaled_mm kernel as it is incompatible with
# async TP patterns. No-op on H100 (kernel requires CC >= 100).
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
model_info.check_available_online(on_fail="skip")
......
......@@ -229,7 +229,7 @@ def _compare_sp(
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
common_args.append("-cc.cudagraph_mode=none")
if runner != "auto":
common_args.extend(["--runner", runner])
if trust_remote_code:
......
......@@ -27,10 +27,29 @@ from ...utils import create_new_process_for_each_test
from ..silly_attention import get_global_counter, reset_global_counter
# Custom op that returns an unbacked symint during graph capture
@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(x: torch.Tensor) -> int:
return 3
@foo.register_fake
def _(x):
return torch.library.get_ctx().new_dynamic_size()
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
intermediate_unbacked=False,
**kwargs,
) -> None:
super().__init__()
self.intermediate_unbacked = intermediate_unbacked
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
......@@ -44,6 +63,13 @@ class SillyModel(nn.Module):
torch.ops.silly.attention(x, x, x, out)
x = out
x = x - 2
if self.intermediate_unbacked:
# Test for unbacked symints: the following is a fancy way to multiply by 1
u0 = foo(x)
ones = x.new_ones(x.shape[0], u0).sum(-1) / 3
x = x * ones
x = x - 1
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
......@@ -52,6 +78,7 @@ class SillyModel(nn.Module):
return x
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
......@@ -60,6 +87,8 @@ def _run_simple_model(
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
expected_num_cudagraph_captured,
*,
intermediate_unbacked=False,
):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
......@@ -72,7 +101,11 @@ def _run_simple_model(
)
)
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix="")
model = SillyModel(
vllm_config=vllm_config,
prefix="",
intermediate_unbacked=intermediate_unbacked,
)
inputs = torch.randn(100).cuda()
......@@ -125,9 +158,10 @@ def _run_simple_model(
@pytest.mark.parametrize("backend", ["inductor", "eager"])
@pytest.mark.parametrize("intermediate_unbacked", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(backend):
def test_simple_piecewise_compile(backend, intermediate_unbacked):
_run_simple_model(
splitting_ops=["silly::attention"],
use_inductor_graph_partition=False,
......@@ -140,6 +174,7 @@ def test_simple_piecewise_compile(backend):
expected_num_backend_compilations=3,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=6,
intermediate_unbacked=intermediate_unbacked,
)
......
......@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
class Matches(NamedTuple):
# simple pointwise
aiter_rms_quant_fusion: int = 0
rms_quant_fusion: int = 0
act_quant_fusion: int = 0
norm_rope_fusion: int = 0
......@@ -82,6 +83,9 @@ INDUCTOR_GRAPH_PARTITION = [
]
FUSION_LOG_PATTERNS: dict[str, re.Pattern] = {
"aiter_rms_quant_fusion": re.compile(
r"RocmAiterRMSNormQuantFusionPass Replaced (\d+) patterns"
),
"rms_quant_fusion": re.compile(r"rms_quant_fusion.py:\d+] Replaced (\d+) patterns"),
"act_quant_fusion": re.compile(r"act_quant_fusion.py:\d+] Replaced (\d+) patterns"),
"norm_rope_fusion": re.compile(
......
......@@ -46,10 +46,10 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Get the compile ranges split points after vllm config post init
# Get the compile ranges endpoints after vllm config post init
# in order to compute compile ranges correctly
compilation_config.compile_ranges_split_points = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
compilation_config.compile_ranges_endpoints = (
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
)
......@@ -63,9 +63,24 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
compilation_config: dict,
matches_check: list[str],
use_deepgemm: bool = False,
use_aiter: bool = False,
tp_size: int = 1,
):
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1" if use_deepgemm else "0")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_aiter else "0")
from vllm._aiter_ops import rocm_aiter_ops
rocm_aiter_ops.refresh_env_variables()
# Filter here to reduce code duplication
requires_mla = "deepseek" in model_name.lower()
is_mla = "mla" in attn_backend.backend.name.lower()
if requires_mla != is_mla:
pytest.skip(
f"Incompatible model '{model_name}' and "
f"attention backend '{attn_backend.backend.name}'"
)
# Disable, compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
......@@ -94,7 +109,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2]
assert num_compile_ranges in [1, 2, 3]
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
print("Fusion results:")
......@@ -107,12 +122,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# Now check the matches
for match_name in matches_check:
num_ranges_activated = (
1 if match_name == "ar_rms_fusion" else num_compile_ranges
)
n_expected = tp_size * num_ranges_activated
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
# AR+RMS skips the largest range; SP skips the smallest.
# When both are enabled, AR+RMS activation count is
# model-dependent (hidden_size affects threshold), so derive
# from log data.
if (
match_name == "ar_rms_fusion"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
assert (
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
), (
f"Expected multiple of {tp_size} ar_rms log entries, "
f"found {len(log_matches)}"
)
num_ranges_activated = len(log_matches) // tp_size
elif (
match_name in ("ar_rms_fusion", "sequence_parallel")
and num_compile_ranges >= 2
):
num_ranges_activated = num_compile_ranges - 1
else:
num_ranges_activated = num_compile_ranges
n_expected = tp_size * num_ranges_activated
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
......@@ -122,8 +158,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
# AR+rms+quant takes precedence over rms+quant if activated.
# That means we get full matching where ar+rms+quant was not activated,
# and less where it was
# That means we get full matching where ar+rms+quant was not
# activated, and less where it was (only the smallest range).
assert sum(m == expected_matches for m in log_matches) == tp_size * (
num_ranges_activated - 1
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
......@@ -135,6 +171,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
f"where ar+rms+quant was activated"
)
elif (
match_name == "async_tp"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
# AsyncTP only finds patterns on ranges where SP ran.
n_sp_ranges = num_compile_ranges - 1
assert (
sum(m == expected_matches for m in log_matches)
== tp_size * n_sp_ranges
), (
f"Expecting {expected_matches} async_tp on "
f"{tp_size * n_sp_ranges} SP-range entries, "
f"found: {log_matches}"
)
assert sum(m == 0 for m in log_matches) == tp_size, (
f"Expecting 0 async_tp on {tp_size} small-range entries "
f"(no SP), found: {log_matches}"
)
elif (
match_name == "ar_rms_fusion"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
# SP consumes allreduce patterns first, so AR+RMS finds
# full matches only on the smallest range (no SP).
assert sum(m == expected_matches for m in log_matches) == tp_size, (
f"Expecting {expected_matches} ar_rms on "
f"{tp_size} small-range entries, found: {log_matches}"
)
assert sum(m == 0 for m in log_matches) == tp_size * (
num_ranges_activated - 1
), (
f"Expecting 0 ar_rms on "
f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}"
)
else:
expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, (
......@@ -142,7 +215,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"found: {sorted(log_matches)}"
)
if match_name == "ar_rms_fusion":
if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
log_matches = re.findall(
r"pass_manager.py:\d+] Skipping "
r".*AllReduceFusionPass.* with compile range",
......@@ -155,4 +228,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
log_matches = re.findall(
r"pass_manager.py:\d+] Skipping "
r".*SequenceParallelismPass.* with compile range",
log_holder.text,
)
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
assert len(log_matches) == n_expected, (
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
return run
......@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm._aiter_ops import is_aiter_found_and_supported
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backends.registry import AttentionBackendEnum
......@@ -24,6 +26,38 @@ TRITON_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.TRITON_ATTN), id="TRITON_ATTN"
)
ROCM_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.ROCM_ATTN),
id="ROCM_ATTN",
marks=pytest.mark.skipif(
not current_platform.is_rocm(),
reason="ROCm attention only for AMD",
),
)
ROCM_AITER_UNIFIED_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN),
id="ROCM_AITER_UNIFIED_ATTN",
marks=pytest.mark.skipif(
not is_aiter_found_and_supported(),
reason="ROCM_AITER_UNIFIED_ATTN only for AMD when AITER is installed",
),
)
FLASHINFER_MLA_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.FLASHINFER_MLA),
id="FLASHINFER_MLA",
marks=pytest.mark.skipif(
not is_blackwell() or not has_flashinfer(),
reason="FI backend requires Blackwell and FlashInfer",
),
)
TRITON_MLA_ATTN = pytest.param(
AttentionBackendCase(backend=AttentionBackendEnum.TRITON_MLA),
id="TRITON_MLA",
)
# Models
llama3_8b = ModelFusionInfo(
model_name="meta-llama/Llama-3.1-8B-Instruct",
......@@ -49,7 +83,6 @@ llama3_8b_fp8 = ModelFusionInfo(
llama3_8b_fp4 = ModelFusionInfo(
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
matches=lambda n_layers: Matches(
rms_quant_fusion=0,
act_quant_fusion=n_layers,
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2 + 1,
......@@ -79,7 +112,6 @@ llama4_scout_fp4 = ModelFusionInfo(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-NVFP4",
hf_overrides=lambda n_layers: {"text_config": {"num_hidden_layers": n_layers}},
matches=lambda n_layers: Matches(
rms_quant_fusion=0,
attn_quant_fusion=n_layers,
ar_rms_fusion=n_layers * 2,
sequence_parallel=n_layers * 2,
......@@ -108,3 +140,25 @@ qwen3_a3b_fp8 = ModelFusionInfo(
async_tp=n_layers * 2,
),
)
deepseek_v3_fp8 = ModelFusionInfo(
model_name="deepseek-ai/DeepSeek-V3",
matches=lambda n_layers: Matches(
# 3 per dense layer (first 3):
# - input_rms + qkv_proj
# - q_a_layernorm + q_b_proj (inside MLA wrapper)
# - post_attn_layernorm + MLP
# 2 per MoE layer (remaining) due to MoE wrapping
rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers
# TODO silu+block quant
# act_quant_fusion=min(3, n_layers), # dense layers only
act_quant_fusion=0,
# MLA attn + quant not supported yet:
# https://github.com/vllm-project/vllm/issues/35792
attn_quant_fusion=0,
ar_rms_fusion=n_layers * 2 + 1,
# TODO
# sequence_parallel= n_layers * 2 + 1,
# async_tp=n_layers * 2,
),
)
......@@ -5,6 +5,8 @@ from collections.abc import Callable
import pytest
from vllm.config import PassConfig
from vllm.platforms import current_platform
from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported
from .common import (
INDUCTOR_GRAPH_PARTITION,
......@@ -15,7 +17,12 @@ from .common import (
)
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
TRITON_MLA_ATTN,
deepseek_v3_fp8,
llama3_8b_fp4,
llama3_8b_fp8,
llama4_scout_fp4,
......@@ -28,12 +35,31 @@ from .models import (
"model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm",
[
(*llama3_8b_fp8, False),
(*llama4_scout_fp8, False),
(*qwen3_a3b_fp8, False),
(*qwen3_a3b_fp8, True),
(*deepseek_v3_fp8, False),
(*deepseek_v3_fp8, True),
pytest.param(
*llama4_scout_fp8,
False,
marks=pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Llama4 Scout FP8 only supported on CUDA",
),
),
],
)
@pytest.mark.parametrize(
"attn_backend",
[
TRITON_ATTN,
FLASHINFER_ATTN,
ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN,
FLASHINFER_MLA_ATTN,
TRITON_MLA_ATTN,
],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
......@@ -50,15 +76,22 @@ def test_tp1_fp8_fusions(
run_e2e_fusion_test,
monkeypatch,
):
if use_deepgemm:
# TODO(luka/eliza) DeepGEMM uses different quants, matching not supported
if use_deepgemm and not current_platform.is_cuda():
pytest.skip("DeepGemm only supported on CUDA")
if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
# Flashinfer block FP8 GEMM has internal quantization, so it can't
# be fused with other ops.
pytest.skip("FlashInfer block FP8 GEMM not supported")
if use_deepgemm and is_blackwell():
# TODO(luka) DeepGEMM uses different quants, matching not supported
# - on Blackwell, uses a special quant fp8, currently not supported
# - on Hopper, tma-aligned scales inhibit matching (fix WIP)
pytest.skip("DeepGEMM & quant matching not currently supported")
matches = matches_fn(n_layers)
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
if block_fp8 and "-quant_fp8" in custom_ops:
# This is why config forces +quant_fp8 by default
pytest.skip("native QuantFP8 matching not supported for group quant")
......@@ -66,7 +99,6 @@ def test_tp1_fp8_fusions(
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
......@@ -78,6 +110,8 @@ def test_tp1_fp8_fusions(
),
)
use_aiter = current_platform.is_rocm() and ("qwen" in model_name.lower())
matches_check = [
"rms_quant_fusion",
"act_quant_fusion",
......@@ -85,6 +119,15 @@ def test_tp1_fp8_fusions(
"attn_quant_fusion",
]
if use_aiter:
matches_check[0] = "aiter_rms_quant_fusion"
matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion)
# TODO: enable the `norm_rope_fusion` test,
# On ROCm norm_rope_fusion is only supported without
# enabling AITER.
matches_check.remove("norm_rope_fusion")
run_e2e_fusion_test(
model_name,
matches,
......@@ -93,6 +136,7 @@ def test_tp1_fp8_fusions(
compilation_config,
matches_check,
use_deepgemm=use_deepgemm,
use_aiter=use_aiter,
)
......
......@@ -5,6 +5,7 @@ from collections.abc import Callable
import pytest
from vllm.config import PassConfig
from vllm.platforms import current_platform
from ...utils import multi_gpu_test
from .common import (
......@@ -16,7 +17,9 @@ from .common import (
)
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
TRITON_ATTN,
deepseek_v3_fp8,
llama3_8b,
llama3_8b_fp4,
llama3_8b_fp8,
......@@ -26,14 +29,18 @@ from .models import (
qwen3_a3b_fp8,
)
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
# qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8],
# qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
)
@pytest.mark.parametrize(
"attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
......@@ -51,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions(
):
matches = matches_fn(n_layers)
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
if block_fp8 and "-quant_fp8" in custom_ops:
# This is why config forces +quant_fp8 by default
pytest.skip("native QuantFP8 matching not supported for group quant")
......
......@@ -5,6 +5,7 @@ from collections.abc import Callable
import pytest
from vllm.config import PassConfig
from vllm.platforms import current_platform
from ...utils import multi_gpu_test
from .common import (
......@@ -23,6 +24,8 @@ from .models import (
qwen3_a3b,
)
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
......@@ -66,6 +69,9 @@ def test_tp2_async_tp_fp8_fusions(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
......@@ -123,11 +129,141 @@ def test_tp2_async_tp_fusions(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"norm_rope_fusion",
"sequence_parallel",
"async_tp",
]
run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp8, llama4_scout_fp8],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fp8_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
monkeypatch,
):
matches = matches_fn(n_layers)
if is_blackwell():
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
fuse_norm_quant=True,
fuse_act_quant=True,
fuse_attn_quant=True,
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"rms_quant_fusion",
"act_quant_fusion",
"norm_rope_fusion",
"attn_quant_fusion",
"ar_rms_fusion",
"sequence_parallel",
"async_tp",
]
run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
):
matches = matches_fn(n_layers)
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"norm_rope_fusion",
"ar_rms_fusion",
"sequence_parallel",
"async_tp",
]
......
......@@ -300,7 +300,7 @@ def async_tp_pass_on_test_model(
set_random_seed(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
......@@ -316,7 +316,6 @@ def async_tp_pass_on_test_model(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
......@@ -334,11 +333,10 @@ def async_tp_pass_on_test_model(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
async_tp_pass = AsyncTPPass(vllm_config)
# Set the global vllm_config for TestBackend which calls
# get_current_vllm_config()
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
assert (
......
......@@ -142,7 +142,6 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
......@@ -180,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default,
torch.ops._C.scaled_fp4_quant.out,
]
......@@ -199,12 +198,14 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion",
"is not compiled with allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
test_model: torch.nn.Module,
......@@ -214,6 +215,7 @@ def test_all_reduce_fusion_pass_replace(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
num_processes = 2
if (
......@@ -237,6 +239,7 @@ def test_all_reduce_fusion_pass_replace(
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
),
nprocs=nprocs,
)
......@@ -254,11 +257,12 @@ def all_reduce_fusion_pass_on_test_model(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
set_random_seed(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
......@@ -269,11 +273,11 @@ def all_reduce_fusion_pass_on_test_model(
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": flashinfer_allreduce_backend,
}
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
custom_ops = []
if enable_rms_norm_custom_op:
......@@ -299,6 +303,7 @@ def all_reduce_fusion_pass_on_test_model(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
......@@ -316,6 +321,10 @@ def all_reduce_fusion_pass_on_test_model(
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
results_unfused = model(hidden_states)
results_fused = compiled_model(hidden_states)
torch.testing.assert_close(results_unfused, results_fused, atol=1e-2, rtol=1e-2)
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)
......
......@@ -36,6 +36,8 @@ from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
"Hello, my name is",
......@@ -226,7 +228,7 @@ def sequence_parallelism_pass_on_test_model(
set_random_seed(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
......@@ -242,7 +244,6 @@ def sequence_parallelism_pass_on_test_model(
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else []
......@@ -272,6 +273,7 @@ def sequence_parallelism_pass_on_test_model(
)
with set_current_vllm_config(vllm_config):
initialize_model_parallel(tensor_model_parallel_size=world_size)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
import torch
import vllm.envs as envs
from tests.compile.backend import TestBackend
from tests.utils import TestFP8Layer
from vllm.compilation.passes.fusion.act_quant_fusion import (
......@@ -31,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -198,23 +200,82 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
return [torch.ops.aten.slice_scatter.default]
MODELS = [
TestSiluMul,
TestFusedAddRMSNorm,
TestRotaryEmbedding,
TestRotaryEmbeddingSliceScatter,
]
class TestFunctionWithMutatedArgsAndReturn(torch.nn.Module):
OP_REGISTERED = False
def __init__(self):
super().__init__()
self.register_test_custom_op()
@classmethod
def register_test_custom_op(cls):
if not cls.OP_REGISTERED:
def function_with_mutated_args_and_return_impl(
x: torch.Tensor,
) -> torch.Tensor:
ret = x + 1
x.add_(2)
return ret
def function_with_mutated_args_and_return_fake(
x: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)
direct_register_custom_op(
op_name="function_with_mutated_args_and_return",
op_func=function_with_mutated_args_and_return_impl,
mutates_args=["x"],
fake_impl=function_with_mutated_args_and_return_fake,
)
cls.OP_REGISTERED = True
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Clone x to avoid mutating the original tensor
ret = torch.ops.vllm.function_with_mutated_args_and_return(x)
return x, ret
def example_inputs(self, num_tokens=32):
hidden_states = torch.randn(num_tokens)
return (hidden_states,)
def ops_in_model(self, do_fusion):
return [torch.ops.vllm.function_with_mutated_args_and_return.default]
def ops_not_in_model(self):
return []
MODELS_AND_DO_FUSION = {
TestSiluMul: [True, False],
TestFusedAddRMSNorm: [True, False],
TestRotaryEmbedding: [False],
TestRotaryEmbeddingSliceScatter: [False],
TestFunctionWithMutatedArgsAndReturn: [False],
}
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
@pytest.mark.parametrize(
"model_class, do_fusion",
[
(model_class, do_fusion)
for model_class, fusions in MODELS_AND_DO_FUSION.items()
for do_fusion in fusions
],
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="Only test on cuda and rocm platform",
)
def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
......@@ -246,8 +307,17 @@ def test_fix_functionalization(
backend_no_func = TestBackend(*passes)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
inputs_func = model.example_inputs()
inputs_no_func = copy.deepcopy(inputs_func)
model_func = copy.deepcopy(model)
model_no_func = copy.deepcopy(model)
model_func = torch.compile(model_func, backend=backend_func)
model_no_func = torch.compile(model_no_func, backend=backend_no_func)
# deepcopy inputs to prevent potential in place mutation
outputs_func = model_func(*copy.deepcopy(inputs_func))
outputs_no_func = model_no_func(*copy.deepcopy(inputs_no_func))
torch.testing.assert_close(outputs_func, outputs_no_func)
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
......
......@@ -26,24 +26,16 @@ from vllm.config import (
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
from vllm.model_executor.kernels.linear import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
......
......@@ -92,6 +92,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize attention metadata."""
# TODO (Rohan138) reuse utils from vllm/v1/worker/gpu/attn_utils.py
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
......@@ -100,58 +102,31 @@ class AttentionQuantPatternModel(torch.nn.Module):
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
backend = self.attn.backend
# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend
if backend == AttentionBackendEnum.ROCM_ATTN:
# k/v as 1st dimention
# HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.block_size,
self.num_kv_heads,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
).permute(0, 1, 3, 2, 4)
else:
raise ValueError(f"Unsupported backend: {backend}")
# Fetch the attention backend and kv cache shape and stride order
attn_backend = self.attn.attn_backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
# Create dummy KV cache
raw_tensor = torch.zeros(
2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_cache = raw_tensor.permute(*inv_order)
self.attn.kv_cache = [kv_cache]
# Build attn metadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment