Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
......@@ -4,6 +4,7 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref
from unittest.mock import Mock
......@@ -12,14 +13,14 @@ import pytest
import torch
from vllm import LLM
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from vllm.v1.engine.llm_engine import LLMEngine
from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test
MODELS = [
"google/gemma-2-2b-it",
"hmellor/tiny-random-Gemma2ForCausalLM",
"meta-llama/Llama-3.2-1B-Instruct",
]
......@@ -28,7 +29,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("distilbert/distilgpt2")
llm = LLM("hmellor/tiny-random-LlamaForCausalLM")
weak_llm = weakref.ref(llm)
del llm
# If there's any circular reference to vllm, this fails
......@@ -37,16 +38,21 @@ def test_vllm_gc_ed():
def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
example_prompts: list[str]) -> list[tuple[list[int], str]]:
vllm_outputs: list[tuple[list[int], str]],
hf_model: HfRunner,
example_prompts: list[str],
) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts),
example_prompts):
vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts
):
hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
prompt + vllm_output[1]))
(
hf_input_ids + vllm_output[0][len(hf_input_ids) :],
prompt + vllm_output[1],
)
)
return fixed_vllm_outputs
......@@ -69,8 +75,7 @@ def test_models(
enable_prompt_embeds: bool,
) -> None:
if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip(
f"{backend} does not support gemma2 with full context length.")
pytest.skip(f"{backend} does not support gemma2 with full context length.")
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend)
......@@ -78,34 +83,35 @@ def test_models(
# 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window
prompt = "The following numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:"
prompt = (
"The following numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt]
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
with VllmRunner(
model,
max_model_len=8192,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
async_scheduling=async_scheduling,
distributed_executor_backend=model_executor,
model,
max_model_len=8192,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
async_scheduling=async_scheduling,
distributed_executor_backend=model_executor,
) as vllm_model:
if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
vllm_outputs, hf_model, example_prompts
)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
......@@ -117,21 +123,18 @@ def test_models(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, "
"test_suite, extra_env", [
("distilbert/distilgpt2", "ray", "", "L4", {}),
("distilbert/distilgpt2", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "L4", {
"VLLM_SLEEP_WHEN_IDLE": "1"
}),
("distilbert/distilgpt2", "mp", "", "L4", {
"VLLM_SLEEP_WHEN_IDLE": "1"
}),
"model, distributed_executor_backend, attention_backend, test_suite, extra_env",
[
("facebook/opt-125m", "ray", "", "L4", {}),
("facebook/opt-125m", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "A100", {}),
("distilbert/distilgpt2", "mp", "", "A100", {}),
])
("facebook/opt-125m", "ray", "", "A100", {}),
("facebook/opt-125m", "mp", "", "A100", {}),
],
)
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed(
monkeypatch: pytest.MonkeyPatch,
......@@ -149,13 +152,14 @@ def test_models_distributed(
pytest.skip(f"Skip test for {test_suite}")
with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
if enable_prompt_embeds:
pytest.skip(
"enable_prompt_embeds does not work with ray compiled dag."
)
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
if (
model == "meta-llama/Llama-3.2-1B-Instruct"
and distributed_executor_backend == "ray"
and attention_backend == ""
and test_suite == "L4"
and enable_prompt_embeds
): # noqa
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
if attention_backend:
monkeypatch_context.setenv(
......@@ -175,30 +179,26 @@ def test_models_distributed(
# will hurt multiprocessing backend with fork method
# (the default method).
with vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
model,
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
) as vllm_model:
if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
vllm_outputs, hf_model, example_prompts
)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
......@@ -209,27 +209,18 @@ def test_models_distributed(
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
from vllm.envs import VLLM_USE_V1
if not VLLM_USE_V1:
pytest.skip("Skipping V0 test, dump input not supported")
# Needed to mock an error in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1):
with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
if isinstance(vllm_model.llm.llm_engine, LLMEngine):
v1_test_failed_model_execution(vllm_model)
def v1_test_failed_model_execution(vllm_model):
engine = vllm_model.llm.llm_engine
mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error"))
engine.engine_core.engine_core.model_executor.execute_model =\
mocked_execute_model
mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
with pytest.raises(RuntimeError) as exc_info:
prompts = [
......
......@@ -5,5 +5,6 @@ from ..utils import compare_two_settings
def test_cpu_offload():
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
["--cpu-offload-gb", "1"])
compare_two_settings(
"hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"]
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.utils import GiB_bytes
from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from ..utils import create_new_process_for_each_test
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_python_error():
"""
Test if Python error occurs when there's low-level
......@@ -23,13 +26,13 @@ def test_python_error():
tensors = []
with allocator.use_memory_pool():
# allocate 70% of the total memory
x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
tensors.append(x)
# release the memory
allocator.sleep()
# allocate more memory than the total memory
y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
tensors.append(y)
with pytest.raises(RuntimeError):
# when the allocator is woken up, it should raise an error
......@@ -37,21 +40,21 @@ def test_python_error():
allocator.wake_up()
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_basic_cumem():
# some tensors from default memory pool
shape = (1024, 1024)
x = torch.empty(shape, device='cuda')
x = torch.empty(shape, device="cuda")
x.zero_()
# some tensors from custom memory pool
allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool():
# custom memory pool
y = torch.empty(shape, device='cuda')
y = torch.empty(shape, device="cuda")
y.zero_()
y += 1
z = torch.empty(shape, device='cuda')
z = torch.empty(shape, device="cuda")
z.zero_()
z += 2
......@@ -70,20 +73,20 @@ def test_basic_cumem():
assert torch.allclose(output, torch.ones_like(output) * 3)
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_cumem_with_cudagraph():
allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool():
weight = torch.eye(1024, device='cuda')
weight = torch.eye(1024, device="cuda")
with allocator.use_memory_pool(tag="discard"):
cache = torch.empty(1024, 1024, device='cuda')
cache = torch.empty(1024, 1024, device="cuda")
def model(x):
out = x @ weight
cache[:out.size(0)].copy_(out)
cache[: out.size(0)].copy_(out)
return out + 1
x = torch.empty(128, 1024, device='cuda')
x = torch.empty(128, 1024, device="cuda")
# warmup
model(x)
......@@ -109,80 +112,72 @@ def test_cumem_with_cudagraph():
model_graph.replay()
# cache content is as expected
assert torch.allclose(x, cache[:x.size(0)])
assert torch.allclose(x, cache[: x.size(0)])
# output content is as expected
assert torch.allclose(y, x + 1)
@create_new_process_for_each_test()
@create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@pytest.mark.parametrize(
"model, use_v1",
"model",
[
# sleep mode with safetensors
("meta-llama/Llama-3.2-1B", True),
"hmellor/tiny-random-LlamaForCausalLM",
# sleep mode with pytorch checkpoint
("facebook/opt-125m", True),
])
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
with monkeypatch.context() as m:
assert use_v1
m.setenv("VLLM_USE_V1", "1")
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
llm = LLM(model, enable_sleep_mode=True)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm.sleep(level=1)
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be releasesd from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if use_v1:
assert used_bytes < 7 * GiB_bytes
else:
assert used_bytes < 2 * GiB_bytes
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text
"facebook/opt-125m",
],
)
def test_end_to_end(model: str):
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
llm = LLM(model, enable_sleep_mode=True)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
llm.sleep(level=1)
llm.wake_up(tags=["weights"])
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm.sleep(level=1)
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# should just reallocate memory for weights (1B model, ~2GiB weights)
if use_v1:
assert used_bytes < 10 * GiB_bytes
else:
assert used_bytes < 6 * GiB_bytes
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be releasesd from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
assert used_bytes < 7 * GiB_bytes
# now allocate kv cache memory
llm.wake_up(tags=["kv_cache"])
output3 = llm.generate(prompt, sampling_params)
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text
# cmp output
assert output[0].outputs[0].text == output3[0].outputs[0].text
llm.sleep(level=1)
llm.wake_up(tags=["weights"])
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
# should just reallocate memory for weights (1B model, ~2GiB weights)
assert used_bytes < 10 * GiB_bytes
# now allocate kv cache memory
llm.wake_up(tags=["kv_cache"])
output3 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output3[0].outputs[0].text
@create_new_process_for_each_test()
def test_deep_sleep():
model = "Qwen/Qwen3-0.6B"
model = "hmellor/tiny-random-LlamaForCausalLM"
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
llm = LLM(model, enable_sleep_mode=True)
......@@ -209,3 +204,42 @@ def test_deep_sleep():
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text
@create_new_process_for_each_test()
def test_deep_sleep_async():
async def test():
model = "hmellor/tiny-random-LlamaForCausalLM"
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
engine_args = AsyncEngineArgs(
model=model,
enable_sleep_mode=True,
)
llm = AsyncLLMEngine.from_engine_args(engine_args)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
async for output in outputs:
pass
# Put the engine to deep sleep
await llm.sleep(level=2)
await llm.wake_up(tags=["weights"])
await llm.collective_rpc("reload_weights")
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
assert used_bytes < 4 * GiB_bytes
# now allocate kv cache and cuda graph memory
await llm.wake_up(tags=["kv_cache"])
outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
async for output2 in outputs2:
pass
# cmp output
assert output.outputs[0].text == output2.outputs[0].text
asyncio.run(test())
......@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark
def test_bench_latency():
command = [
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32",
"--output-len", "1", "--enforce-eager", "--load-format", "dummy"
"vllm",
"bench",
"latency",
"--model",
MODEL_NAME,
"--input-len",
"32",
"--output-len",
"1",
"--enforce-eager",
"--load-format",
"dummy",
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Any, NamedTuple, Optional, cast
from typing import Any, NamedTuple, cast
import numpy as np
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
SampleRequest)
from vllm.benchmarks.datasets import (
RandomDataset,
RandomMultiModalDataset,
SampleRequest,
)
@pytest.fixture(scope="session")
......@@ -27,11 +30,9 @@ class Params(NamedTuple):
@pytest.fixture(scope="session")
def random_dataset_params() -> Params:
return Params(num_requests=16,
prefix_len=7,
range_ratio=0.3,
input_len=50,
output_len=20)
return Params(
num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
)
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
......@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
return (req.prompt, req.prompt_len, req.expected_output_len)
def _collect_samples(dataset: RandomDataset,
tokenizer: PreTrainedTokenizerBase,
num_requests: int = 16,
prefix_len: int = 7,
range_ratio: float = 0.3,
input_len: int = 50,
output_len: int = 20) -> list[tuple[str, int, int]]:
def _collect_samples(
dataset: RandomDataset,
tokenizer: PreTrainedTokenizerBase,
num_requests: int = 16,
prefix_len: int = 7,
range_ratio: float = 0.3,
input_len: int = 50,
output_len: int = 20,
) -> list[tuple[str, int, int]]:
samples = dataset.sample(
tokenizer=tokenizer,
num_requests=num_requests,
......@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
@pytest.mark.benchmark
def test_random_dataset_same_seed(
hf_tokenizer: PreTrainedTokenizerBase,
random_dataset_params: Params) -> None:
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
) -> None:
"""Same seed should yield identical outputs, even if global RNGs change.
This guards against accidental reliance on Python's random or np.random
......@@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
common_seed = 123
dataset_a = RandomDataset(random_seed=common_seed)
dataset_b = RandomDataset(random_seed=common_seed)
a = _collect_samples(dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
a = _collect_samples(
dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len,
)
# Perturb global RNG state to ensure isolation
random.seed(999)
......@@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
np.random.seed(888)
_ = [np.random.random() for _ in range(100)]
b = _collect_samples(dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
b = _collect_samples(
dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len,
)
assert a == b
@pytest.mark.benchmark
def test_random_dataset_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase,
random_dataset_params: Params) -> None:
hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
) -> None:
"""Different seeds should change outputs with overwhelming likelihood."""
p = random_dataset_params
seed_a = 0
dataset_a = RandomDataset(random_seed=seed_a)
a = _collect_samples(dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
a = _collect_samples(
dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len,
)
seed_b = 999
dataset_b = RandomDataset(random_seed=seed_b)
# Perturb global RNG with same seed as dataset_a to ensure isolation
random.seed(seed_a)
np.random.seed(seed_a)
b = _collect_samples(dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
b = _collect_samples(
dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len,
)
assert a != b
......@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
# RandomMultiModalDataset tests
# -----------------------------
def _mm_fingerprint_sample(
req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]:
......@@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
item_prefixes.append(f"video:{url[:22]}")
else:
item_prefixes.append("unknown:")
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
item_prefixes)
return (
req.prompt,
req.prompt_len,
req.expected_output_len,
len(items),
item_prefixes,
)
def _collect_mm_samples(
......@@ -167,8 +185,8 @@ def _collect_mm_samples(
output_len: int = 5,
base_items_per_request: int = 2,
num_mm_items_range_ratio: float = 0.0,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
limit_mm_per_prompt: dict[str, int] | None = None,
bucket_config: dict[tuple[int, int, int], float] | None = None,
enable_multimodal_chat: bool = False,
) -> list[SampleRequest]:
if limit_mm_per_prompt is None:
......@@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa != fb
@pytest.mark.benchmark
def test_random_mm_respects_limits(
hf_tokenizer: PreTrainedTokenizerBase,
......@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
for s in samples:
assert s.multi_modal_data == []
@pytest.mark.benchmark
def test_random_mm_num_items_per_prompt(
hf_tokenizer: PreTrainedTokenizerBase) -> None:
def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Fixed number of images per prompt
# set num_mm_items_range_ratio to 0.0
......@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
def test_random_mm_bucket_config_not_mutated(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# This bucket config is not normalized to sum to 1
# and has more buckets than requested images
......@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
# Ensure the original dict content is unchanged
assert original == snapshot
# Vary number of mm items per prompt
# set num_mm_items_range_ratio to 0.5
samples_varying_items = _collect_mm_samples(
......@@ -342,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
assert len(mm_data) >= 1
for it in mm_data:
assert it.get("type") == "image_url"
@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
"""Test video sampling functionality in RandomMultiModalDataset."""
ds = RandomMultiModalDataset(random_seed=42)
# Test with video bucket configuration
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}
limit_mm_per_prompt = {"image": 2, "video": 2}
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
assert len(samples) == 5
# Check that we have both images and videos
video_count = 0
image_count = 0
for s in samples:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 1
item = mm_data[0]
if item.get("type") == "video_url":
video_count += 1
# Verify video URL format
url = item.get("video_url", {}).get("url", "")
assert url.startswith("data:video/mp4;base64,")
elif item.get("type") == "image_url":
image_count += 1
# Verify image URL format
url = item.get("image_url", {}).get("url", "")
assert url.startswith("data:image/jpeg;base64,")
# Should have some videos due to 0.7 probability
assert video_count > 0
assert image_count > 0
@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
"""Test sampling with only video buckets."""
ds = RandomMultiModalDataset(random_seed=42)
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 1}
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
assert len(samples) == 3
for s in samples:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 1
item = mm_data[0]
assert item.get("type") == "video_url"
url = item.get("video_url", {}).get("url", "")
assert url.startswith("data:video/mp4;base64,")
@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
"""Test that video sampling is deterministic with same seed."""
seed = 123
ds_a = RandomMultiModalDataset(random_seed=seed)
ds_b = RandomMultiModalDataset(random_seed=seed)
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 1}
a = _collect_mm_samples(
ds_a,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
b = _collect_mm_samples(
ds_b,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa == fb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
from tempfile import NamedTemporaryFile
from typing import Any, cast
import cv2
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import RandomMultiModalDataset, SampleRequest
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
"""Use a small, commonly available tokenizer."""
return AutoTokenizer.from_pretrained("gpt2")
@pytest.fixture
def video_dataset() -> RandomMultiModalDataset:
"""Create a RandomMultiModalDataset instance for testing."""
return RandomMultiModalDataset(random_seed=42)
@pytest.mark.benchmark
def test_generate_synthetic_video_different_seeds():
"""Test that different seeds produce different videos."""
dataset1 = RandomMultiModalDataset(random_seed=123)
dataset2 = RandomMultiModalDataset(random_seed=456)
width, height, num_frames = 64, 48, 8
video1 = dataset1.generate_synthetic_video(width, height, num_frames)
video2 = dataset2.generate_synthetic_video(width, height, num_frames)
# Videos should be different due to different seeds
assert video1["bytes"] != video2["bytes"]
@pytest.mark.benchmark
def test_map_config_to_modality(video_dataset: RandomMultiModalDataset):
"""Test modality mapping for different configurations."""
# Test image configuration (num_frames = 1)
assert video_dataset.map_config_to_modality((256, 256, 1)) == "image"
assert video_dataset.map_config_to_modality((720, 1280, 1)) == "image"
# Test video configurations (num_frames > 1)
assert video_dataset.map_config_to_modality((256, 256, 8)) == "video"
assert video_dataset.map_config_to_modality((720, 1280, 16)) == "video"
assert video_dataset.map_config_to_modality((64, 64, 32)) == "video"
# Test invalid configurations
with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
video_dataset.map_config_to_modality((256, 256, 0))
with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
video_dataset.map_config_to_modality((256, 256, -1))
@pytest.mark.benchmark
def test_generate_mm_item_video(video_dataset: RandomMultiModalDataset):
"""Test generating multimodal items for video configurations."""
# Test video item generation
video_config = (64, 48, 8) # height, width, num_frames
result = video_dataset.generate_mm_item(video_config)
# Check the result structure matches OpenAI API format
assert isinstance(result, dict)
assert result["type"] == "video_url"
assert "video_url" in result
assert "url" in result["video_url"]
# Check that the URL is a data URL with base64 encoded video
url = result["video_url"]["url"]
assert url.startswith("data:video/mp4;base64,")
# Decode and verify the video content
base64_data = url.split(",")[1]
video_bytes = base64.b64decode(base64_data)
assert len(video_bytes) > 0
# Verify the video can be decoded
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(video_bytes)
try:
cap = cv2.VideoCapture(temp_path)
assert cap.isOpened()
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
assert frame_count == 8
assert frame_width == 48
assert frame_height == 64
cap.release()
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
@pytest.mark.benchmark
def test_generate_mm_item_image(video_dataset: RandomMultiModalDataset):
"""Test generating multimodal items for image configurations."""
# Test image item generation
image_config = (64, 48, 1) # height, width, num_frames=1
result = video_dataset.generate_mm_item(image_config)
# Check the result structure matches OpenAI API format
assert isinstance(result, dict)
assert result["type"] == "image_url"
assert "image_url" in result
assert "url" in result["image_url"]
# Check that the URL is a data URL with base64 encoded image
url = result["image_url"]["url"]
assert url.startswith("data:image/jpeg;base64,")
@pytest.mark.benchmark
def test_generate_mm_item_invalid_config(video_dataset: RandomMultiModalDataset):
"""Test error handling for invalid configurations."""
with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
video_dataset.generate_mm_item((256, 256, 0))
@pytest.mark.benchmark
def test_sample_with_video_buckets(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test sampling with video bucket configurations."""
# Configure bucket with video probability > 0
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}
limit_mm_per_prompt = {"image": 5, "video": 3}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=5,
base_items_per_request=2,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 5
# Check that samples contain both images and videos
video_count = 0
image_count = 0
for sample in samples:
assert isinstance(sample, SampleRequest)
assert sample.multi_modal_data is not None
assert isinstance(sample.multi_modal_data, list)
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
assert len(mm_data) == 2 # base_items_per_request
for item in mm_data:
if item["type"] == "video_url":
video_count += 1
# Verify video URL format
url = item["video_url"]["url"]
assert url.startswith("data:video/mp4;base64,")
elif item["type"] == "image_url":
image_count += 1
# Verify image URL format
url = item["image_url"]["url"]
assert url.startswith("data:image/jpeg;base64,")
# Should have some videos due to 0.7 probability
assert video_count > 0
assert image_count > 0
@pytest.mark.benchmark
def test_sample_video_only_buckets(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test sampling with only video buckets."""
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 2}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 3
for sample in samples:
assert isinstance(sample, SampleRequest)
assert sample.multi_modal_data is not None
assert isinstance(sample.multi_modal_data, list)
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
assert len(mm_data) == 1
item = mm_data[0]
assert item["type"] == "video_url"
url = item["video_url"]["url"]
assert url.startswith("data:video/mp4;base64,")
@pytest.mark.benchmark
def test_sample_respects_video_limits(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test that sampling respects video limits per prompt."""
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
# Set very low video limit
limit_mm_per_prompt = {"image": 0, "video": 1}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 3
for sample in samples:
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
assert len(mm_data) <= 1 # Should respect video limit
@pytest.mark.benchmark
def test_sample_mixed_buckets_with_zero_probability(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test sampling with mixed buckets including zero probability entries."""
bucket_config = {
(64, 64, 1): 0.5, # Images
(64, 64, 8): 0.5, # Videos
(128, 128, 16): 0.0, # Zero probability videos (should be ignored)
}
limit_mm_per_prompt = {"image": 2, "video": 2}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=4,
base_items_per_request=2,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 4
# Should only see 64x64 videos, not 128x128 videos
for sample in samples:
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
for item in mm_data:
if item["type"] == "video_url":
# Decode video to verify dimensions
url = item["video_url"]["url"]
base64_data = url.split(",")[1]
video_bytes = base64.b64decode(base64_data)
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file: # noqa
temp_path = temp_file.name
temp_file.write(video_bytes)
try:
cap = cv2.VideoCapture(temp_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
# Should be 64x64, not 128x128
assert frame_width == 64
assert frame_height == 64
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
@pytest.mark.benchmark
def test_sample_deterministic_with_videos(hf_tokenizer: PreTrainedTokenizerBase):
"""Test that sampling with videos is deterministic with same seed."""
dataset1 = RandomMultiModalDataset(random_seed=123)
dataset2 = RandomMultiModalDataset(random_seed=123)
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}
limit_mm_per_prompt = {"image": 2, "video": 2}
samples1 = dataset1.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
samples2 = dataset2.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples1) == len(samples2)
# Compare multimodal data
for s1, s2 in zip(samples1, samples2):
assert s1.multi_modal_data == s2.multi_modal_data
@pytest.mark.benchmark
def test_sample_different_seeds_produce_different_videos(
hf_tokenizer: PreTrainedTokenizerBase,
):
"""Test that different seeds produce different video content."""
dataset1 = RandomMultiModalDataset(random_seed=123)
dataset2 = RandomMultiModalDataset(random_seed=456)
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 1}
samples1 = dataset1.sample(
tokenizer=hf_tokenizer,
num_requests=2,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
samples2 = dataset2.sample(
tokenizer=hf_tokenizer,
num_requests=2,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
# Video content should be different
for s1, s2 in zip(samples1, samples2):
mm_data1 = cast(list[dict[str, Any]], s1.multi_modal_data)
mm_data2 = cast(list[dict[str, Any]], s2.multi_modal_data)
assert len(mm_data1) == len(mm_data2) == 1
url1 = mm_data1[0]["video_url"]["url"]
url2 = mm_data2[0]["video_url"]["url"]
assert url1 != url2 # Different video content
......@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
]
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
......@@ -46,6 +44,7 @@ def test_bench_serve(server):
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark
def test_bench_serve_chat(server):
command = [
......
......@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark
def test_bench_throughput():
command = [
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len",
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy"
"vllm",
"bench",
"throughput",
"--model",
MODEL_NAME,
"--input-len",
"32",
"--output-len",
"1",
"--enforce-eager",
"--load-format",
"dummy",
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
......
......@@ -5,13 +5,16 @@ These envs only work for a small part of the tests, fix what you need!
"""
import os
from typing import TYPE_CHECKING, Any, Callable, Optional
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from vllm.envs import maybe_convert_bool
if TYPE_CHECKING:
VLLM_CI_NO_SKIP: bool = False
VLLM_CI_DTYPE: Optional[str] = None
VLLM_CI_HEAD_DTYPE: Optional[str] = None
VLLM_CI_HF_DTYPE: Optional[str] = None
VLLM_CI_DTYPE: str | None = None
VLLM_CI_HEAD_DTYPE: str | None = None
VLLM_CI_HF_DTYPE: str | None = None
environment_variables: dict[str, Callable[[], Any]] = {
# A model family has many models with the same architecture.
......@@ -24,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
# Allow changing the head dtype used by transformers in tests
"VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
# Allow control over whether tests use enforce_eager
"VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
os.getenv("VLLM_CI_ENFORCE_EAGER", None)
),
}
......
......@@ -2,18 +2,23 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from copy import deepcopy
from typing import Callable, Union
import depyf
from torch import fx
from torch._ops import OpOverload
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger("vllm.tests.compile.backend")
class LazyInitPass(InductorPass):
......@@ -23,8 +28,7 @@ class LazyInitPass(InductorPass):
and then immediately invoke it.
"""
def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
......@@ -45,24 +49,34 @@ class TestBackend:
Inductor config is default-initialized from VllmConfig.CompilationConfig.
"""
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
vllm_config = get_current_vllm_config()
compile_config = vllm_config.compilation_config
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
if debug_dump_path := vllm_config.compile_debug_dump_path():
logger.debug("Dumping depyf output to %s", debug_dump_path)
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
else:
self.debug_ctx = nullcontext()
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.inductor_config)
with self.debug_ctx:
return compile_fx(
graph, example_inputs, config_patches=self.inductor_config
)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
......@@ -72,6 +86,7 @@ class TestBackend:
VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph)
lazy_format_graph_code("graph_post_pass", graph.owning_module)
# assign by reference, will reflect the final state of the graph
self.final_graph = graph
......@@ -82,8 +97,7 @@ class TestBackend:
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops:
......
......@@ -3,15 +3,15 @@
import contextlib
import os
import weakref
from dataclasses import dataclass
from typing import Optional
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
@contextlib.contextmanager
......@@ -33,121 +33,44 @@ def temporary_environ(env_vars):
os.environ[k] = v
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
test_params_full_cudagraph = []
model_backends_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
model_backends_full_cudagraph.append(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
)
# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))
@pytest.fixture(scope="class")
def llm_pair(request):
model, backend_config = request.param
model, backend_config, use_inductor_graph_partition = request.param
backend_config.comp_config["use_inductor_graph_partition"] = (
use_inductor_graph_partition
)
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
if (
backend_config.specific_gpu_arch
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
):
if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0):
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
env_vars = {
"VLLM_USE_V1": "1",
# Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0",
......@@ -160,8 +83,7 @@ def llm_pair(request):
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=128,
compilation_config=\
CompilationConfig(**backend_config.comp_config),
compilation_config=CompilationConfig(**backend_config.comp_config),
generation_config="vllm",
seed=42,
)
......@@ -187,7 +109,15 @@ def llm_pair(request):
)
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
@pytest.mark.parametrize(
"llm_pair",
[
pytest.param((model, backend_config, use_inductor_graph_partition))
for model, backend_config in model_backends_full_cudagraph
for use_inductor_graph_partition in [True, False]
],
indirect=True,
)
class TestFullCUDAGraph:
"""
Use a class such that an llm pair is constructed once for all
......@@ -197,20 +127,22 @@ class TestFullCUDAGraph:
meaning there would be multiple LLM instances hogging memory simultaneously.
"""
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(1, 10),
(7, 10),
(16, 10),
(25, 10),
(32, 10),
(45, 10),
(64, 10),
(123, 10),
(8, 5),
(8, 30),
])
def test_full_cudagraph(self, batch_size, max_tokens,
llm_pair: tuple[LLM, LLM]):
@pytest.mark.parametrize(
("batch_size", "max_tokens"),
[
(1, 10),
(7, 10),
(16, 10),
(25, 10),
(32, 10),
(45, 10),
(64, 10),
(123, 10),
(8, 5),
(8, 30),
],
)
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
"""
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
......@@ -221,26 +153,33 @@ class TestFullCUDAGraph:
prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=1.0)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=max_tokens, top_p=1.0
)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text.lower() == \
full_res.outputs[0].text.lower()
for piecewise_res, full_res in zip(piecewise_responses, full_responses):
assert (
piecewise_res.outputs[0].text.lower()
== full_res.outputs[0].text.lower()
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
# Flex_Attention is not supported with full cuda graph
}), pytest.raises(RuntimeError):
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
with (
temporary_environ(
{
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
# Flex_Attention is not supported with full cuda graph
}
),
pytest.raises(RuntimeError),
):
LLM(
model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
)
......@@ -5,16 +5,24 @@ Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""
import pytest
import torch
from torch import nn
from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ...utils import create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401
......@@ -27,12 +35,7 @@ RANDOM_SEED = 0
@support_torch_compile
class ParentModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -40,7 +43,6 @@ class ParentModel(nn.Module):
class Attention(nn.Module):
def __init__(self, mlp_size: int, hidden_size: int) -> None:
super().__init__()
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
......@@ -51,17 +53,21 @@ class Attention(nn.Module):
nn.init.xavier_normal_(
self.pre_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
gain=0.001,
)
nn.init.xavier_normal_(
self.post_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
gain=0.001,
)
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
x_f32 = x.float()
return (x_f32 * torch.rsqrt(
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) *
self.rms_norm_weight).to(x.dtype)
return (
x_f32
* torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
* self.rms_norm_weight
).to(x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x)
......@@ -76,14 +82,15 @@ class Attention(nn.Module):
@support_torch_compile
class CompiledAttention(nn.Module):
def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.attn = Attention(mlp_size, hidden_size)
......@@ -93,21 +100,21 @@ class CompiledAttention(nn.Module):
@support_torch_compile
class CompiledAttentionTwo(CompiledAttention):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x) + x
@ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel):
def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)"
......@@ -142,118 +149,174 @@ class SimpleModelWithTwoGraphs(ParentModel):
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
cudagraph_runtime_mode: CUDAGraphMode):
def run_model(
vllm_config: VllmConfig,
model: nn.Module,
inputs: torch.Tensor,
cudagraph_runtime_mode: CUDAGraphMode,
):
with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE
model(inputs)
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(inputs[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(inputs[:1])
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(inputs[:2])
output = output.cpu()
return output.cpu()
def test_multi_graph_piecewise_compile_outputs_equal():
@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@create_new_process_for_each_test("spawn")
def test_multi_graph_piecewise_compile(
use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
):
# Set the environment variable for this test
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
outputs = []
# piecewise compile
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
# vllmcompile compile
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
model = (
SimpleModelWithTwoGraphs(
mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
# Pre-allocate memory for CUDAGraph which expects
# static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
if use_inductor_graph_partition:
# Splitting happens at Inductor lowering level,
# total piecewise fx graphs is equal to total graphs
num_piecewise_fx = 2
num_piecewise_capturable_fx = 2
else:
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_fx = 6
# attn_one, attn_two has pre attn and post attn each, total=4
num_piecewise_capturable_fx = 4
with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4,
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=8, # num_cudagraph_sizes * num_partitions
):
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# no compile or cudagraph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.NONE,
)
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
model = (
SimpleModelWithTwoGraphs(
mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# piecewise compile without CUDA graph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=False,
splitting_ops=["silly.attention"],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_mode=CUDAGraphMode.NONE,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()
model = (
SimpleModelWithTwoGraphs(
mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=0, # no cudagraph captured
num_graphs_seen=2,
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=0, # no cudagraph captured
):
outputs.append(
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# Generally don't expect outputs with and without inductor
# to be bitwise equivalent
......
......@@ -11,11 +11,17 @@ from torch import nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ...utils import create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
......@@ -23,12 +29,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -60,53 +61,64 @@ def _run_simple_model(
expected_num_backend_compilations,
expected_num_cudagraph_captured,
):
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
)
)
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
model = SillyModel(vllm_config=vllm_config, prefix="")
inputs = torch.randn(100).cuda()
with compilation_counter.expect(
with (
compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=
expected_num_piecewise_capturable_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
), set_forward_context(None,
vllm_config=vllm_config): # background context
),
set_forward_context(None, vllm_config=vllm_config),
): # background context
# warm up with background context
model(inputs)
# capturing/replaying should under context of cudagraph dispatching
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2).cuda())
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )):
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda()
reset_global_counter()
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input)
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
......@@ -114,42 +126,42 @@ def _run_simple_model(
@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
_run_simple_model(
splitting_ops=["silly.attention"],
splitting_ops=["silly::attention"],
use_inductor_graph_partition=False,
use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
expected_num_backend_compilations=
3, # num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
# 2 * num_layers + 1
expected_num_piecewise_graphs_seen=5,
# 1 + num_layers
expected_num_piecewise_capturable_graphs_seen=3,
# num_piecewise_capturable_graphs_seen
expected_num_backend_compilations=3,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=6,
)
@torch.inference_mode()
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1
def test_simple_inductor_graph_partition(monkeypatch):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# disable compile cache so that we run separately for different splitting_ops
# and get the expected number of cudagraphs captured.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
_run_simple_model(
# inductor graph partition automatically resets splitting_ops
# to be an empty list
splitting_ops=splitting_ops,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=True,
use_inductor=True,
expected_num_piecewise_graphs_seen=
1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=
1, # since not splitting at fx graph level
expected_num_backend_compilations=
1, # since not splitting at fx graph level
expected_num_cudagraph_captured=
6, # inductor graph partition still captures 6
# graph, same as fx graph partition.
# Since not splitting at fx graph level
expected_num_piecewise_graphs_seen=1,
# Since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=1,
# Since not splitting at fx graph level
expected_num_backend_compilations=1,
# Inductor graph partition still captures 6 graph, same as fx graph partition
expected_num_cudagraph_captured=6,
)
......@@ -8,8 +8,10 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import pytest
import torch
......@@ -17,9 +19,17 @@ from torch import nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ...utils import create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401
......@@ -43,15 +53,14 @@ class LlamaConfig:
factors.append((k, v))
factors.sort()
import hashlib
return hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
def __post_init__(self):
assert self.mlp_size >= self.hidden_size
class LlamaMLP(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.gate_up_projection = nn.Linear(
......@@ -66,31 +75,31 @@ class LlamaMLP(nn.Module):
)
if config.tractable_init:
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
nn.init.eye_(self.down_projection.weight.data)
else:
nn.init.xavier_normal_(self.gate_up_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(self.down_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(
self.gate_up_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
nn.init.xavier_normal_(
self.down_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward(self, x):
# for tractable_init and positive input, this is
# essentially an elementwise-square
x = self.gate_up_projection(x)
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
x[:, x.size(1) // 2:])
x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
x = self.down_projection(x)
return x
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.qkv_projection = nn.Linear(
......@@ -106,21 +115,25 @@ class LlamaAttention(nn.Module):
)
if config.tractable_init:
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[2 *
config.hidden_size:])
nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
nn.init.eye_(
self.qkv_projection.weight.data[
config.hidden_size : 2 * config.hidden_size
]
)
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
nn.init.eye_(self.output_projection.weight.data)
else:
nn.init.xavier_normal_(self.qkv_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(self.output_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(
self.qkv_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
nn.init.xavier_normal_(
self.output_projection.weight.data,
generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward(
self,
......@@ -144,7 +157,6 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig) -> None:
super().__init__()
self.self_attention = LlamaAttention(config)
......@@ -154,7 +166,7 @@ class LlamaDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
For tractable computation:
......@@ -164,7 +176,7 @@ class LlamaDecoderLayer(nn.Module):
- if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
""" # noqa
""" # noqa
if residual is None:
residual = hidden_states
hidden_states = hidden_states + 1
......@@ -173,8 +185,9 @@ class LlamaDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = hidden_states + 1
hidden_states = self.self_attention(positions=positions,
hidden_states=hidden_states)
hidden_states = self.self_attention(
positions=positions, hidden_states=hidden_states
)
hidden_states = hidden_states + residual
residual = hidden_states
......@@ -186,27 +199,29 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
config: LlamaConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self,
*,
vllm_config: VllmConfig,
config: LlamaConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]
)
# this is the initial value of the hidden states
self.embedding_tokens.weight.data.fill_(config.init_value)
def forward(
self,
input_ids: Optional[torch.Tensor],
input_ids: torch.Tensor | None,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embedding_tokens(input_ids)
......@@ -216,168 +231,195 @@ class LlamaModel(nn.Module):
return hidden_states
def tractable_computation(input_ids: torch.Tensor,
positions: torch.Tensor,
config: LlamaConfig,
init_value: float = 1.0) -> torch.Tensor:
hidden_states = torch.ones(input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype) * init_value
def tractable_computation(
input_ids: torch.Tensor,
positions: torch.Tensor,
config: LlamaConfig,
init_value: float = 1.0,
) -> torch.Tensor:
hidden_states = (
torch.ones(
input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype,
)
* init_value
)
# first layer
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2
hidden_states = (residual + 1) ** 2
# following layers
for _ in range(config.num_layers - 1):
hidden_states = hidden_states + residual
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2
hidden_states = (residual + 1) ** 2
return hidden_states
@torch.inference_mode
def run_model(llama_config,
use_compile: bool,
use_inductor: bool,
split_attn: bool = False) -> torch.Tensor:
if use_compile:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
# Start with a fresh copy to make sure there's no cache dir sharing
compile_config = deepcopy(compile_config)
cudagraph_runtime_mode = compile_config.cudagraph_mode
vllm_config = VllmConfig(
compilation_config=compile_config, additional_config=llama_config
)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
model = (
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
.eval()
.cuda()
)
with set_forward_context({},
vllm_config=vllm_config): # background context
with set_forward_context({}, vllm_config=vllm_config): # background context
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE
model(input_ids, positions)
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(input_ids[:2], positions[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(input_ids[:1], positions[:1])
input_ids[:2].zero_()
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input_ids[:2], positions[:2])
output = output.cpu()
if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2],
positions[:2],
llama_config).cpu()
expected_output = tractable_computation(
input_ids[:2], positions[:2], llama_config
).cpu()
assert torch.allclose(output, expected_output)
else:
return output.cpu()
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
@pytest.mark.parametrize(
"backend, use_inductor_graph_partition",
[
("eager", False), # No inductor
("inductor", False), # Inductor, Dynamo partition
("inductor", True), # Inductor, Inductor partition
],
)
@create_new_process_for_each_test("spawn")
def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
# We disable the vLLM compile cache into a new tmp dir for 1 reason:
# 1. To make sure we can properly track the number of Inductor compilations.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")
# compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=12)
llama_config = LlamaConfig(
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
)
tractable_config = LlamaConfig(
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
)
compile_config_no_compile = CompilationConfig(
mode=CompilationMode.NONE,
cudagraph_mode=CUDAGraphMode.NONE,
backend="eager",
)
compile_config_no_split = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
tractable_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=2,
tractable_init=True)
compile_config_split = deepcopy(compile_config_no_split)
compile_config_split.splitting_ops = ["silly::attention"]
outputs = []
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(
run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
outputs.append(run_model(llama_config, compile_config_no_compile))
if use_inductor:
run_model(tractable_config, compile_config_no_compile)
if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**kwargs,
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2,
**kwargs,
):
outputs.append(
run_model(llama_config,
use_inductor=use_inductor,
use_compile=True))
run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
outputs.append(run_model(llama_config, compile_config_no_split))
run_model(tractable_config, compile_config_no_split)
if use_inductor_graph_partition:
num_piecewise_fx = 1
num_piecewise_capturable_fx = 1
else:
num_piecewise_fx = 2 * llama_config.num_layers + 1
num_piecewise_capturable_fx = 1 + llama_config.num_layers
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers +
1, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=1 +
llama_config.num_layers, # 1 + num_layers
num_backend_compilations=1 +
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2 *
(1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=num_piecewise_capturable_fx,
# num_cudagraph_sizes * num_partitions
num_cudagraph_captured=2 * (1 + llama_config.num_layers),
):
outputs.append(
run_model(llama_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True))
run_model(tractable_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True)
outputs.append(run_model(llama_config, compile_config_split))
run_model(tractable_config, compile_config_split)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
......@@ -388,17 +430,15 @@ def benchmark():
from triton.testing import do_bench
# similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096,
mlp_size=14336,
vocab_size=128 * 1024,
num_layers=32)
llama_config = LlamaConfig(
hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
)
# a tiny model to measure the overhead
# of piecewise cudagraph
llama_config = LlamaConfig(hidden_size=40,
mlp_size=80,
vocab_size=128,
num_layers=2)
llama_config = LlamaConfig(
hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
)
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
......@@ -411,25 +451,27 @@ def benchmark():
for piecewise in [False, True]:
if piecewise:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)
else:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
cudagraph_capture_sizes=cudagraph_sizes,
)
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda().to(torch.bfloat16)
model = (
LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
.eval()
.cuda()
.to(torch.bfloat16)
)
B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda().to(torch.bfloat16)
graphs = {}
......@@ -451,22 +493,31 @@ def benchmark():
# and use it later, because it will look up the name `b` in the
# enclosing scope, and the value of `b` will always be 256.
# it is fine here, because we only use the lambda function once.
runtime = do_bench(lambda: graphs[b][0] # noqa
(input_ids[:b], positions[:b])) # noqa
runtime = do_bench(
lambda: graphs[b][0]( # noqa
input_ids[:b], # noqa
positions[:b], # noqa
)
)
piecewise_cudagraph_time[b] = runtime
else:
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
eager_runtime = do_bench(
lambda: model(input_ids[:b], positions[:b])) # noqa
eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
full_cudagraph_time[b] = runtime
eager_time[b] = eager_runtime
# print in tabular format
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
for b in cudagraph_sizes:
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}")
print(
f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}"
)
if __name__ == "__main__":
benchmark()
# Protect against subprocess reimport when using spawn_new_process_for_each_test
import os
if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
benchmark()
......@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
import torch
from torch.library import Library
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
......@@ -31,8 +31,9 @@ def reset_global_counter():
_global_counter = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""
Unified attention implementation that depends on
all inputs and affects the output.
......@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out.copy_(q + k + v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
def silly_attention_fake(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""Fake implementation for testing"""
return
......@@ -60,5 +62,4 @@ direct_register_custom_op(
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from contextlib import contextmanager
import pytest
import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42
assert x.shape[0] % 2 == 0
for _ in range(3000):
x = x + x.shape[0]
return x
@support_torch_compile
class CompiledMod(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
return reference_fn(x)
def make_vllm_config() -> VllmConfig:
return VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
)
)
@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
yield
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
vllm_config = make_vllm_config()
args = (torch.randn(10, 10),)
expected = reference_fn(*args)
with use_vllm_config(vllm_config):
m.setenv("VLLM_USE_AOT_COMPILE", "0")
with (
pytest.raises(RuntimeError, match="Detected recompile"),
torch.compiler.set_stance("fail_on_recompile"),
):
CompiledMod(vllm_config=vllm_config)(*args)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
actual = CompiledMod(vllm_config=vllm_config)(*args)
assert torch.allclose(actual, expected)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
args = (torch.randn(10, 10),)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
CompiledMod(vllm_config=vllm_config)(*args)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)
with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args)
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
ret = CompiledMod(vllm_config=vllm_config)(*args)
assert torch.allclose(ret, expected)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
"""
Test that the shape environment is correctly serialized and preserved
when loading from cache.
"""
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)
with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args)
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args)
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
......@@ -8,18 +8,31 @@ import torch
import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.config import (
CompilationConfig,
CompilationMode,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from vllm.utils.system_utils import update_environment_variables
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test)
from ..utils import (
compare_two_settings,
create_new_process_for_each_test,
multi_gpu_test,
)
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
......@@ -33,21 +46,20 @@ prompts = [
class TestMMRSModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty(
(self.hidden_size * 2, hidden_size)),
requires_grad=False)
self.gate_proj = torch.nn.Parameter(
torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states):
"""
Forward pass implementing the mm + reduce scatter in the FX graph
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
......@@ -66,14 +78,13 @@ class TestMMRSModel(torch.nn.Module):
class TestAGMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty(
(hidden_size, hidden_size)),
requires_grad=False)
self.weight = torch.nn.Parameter(
torch.empty((hidden_size, hidden_size)), requires_grad=False
)
# Initialize weights
torch.nn.init.normal_(self.weight, std=0.02)
......@@ -96,32 +107,35 @@ class TestAGMMModel(torch.nn.Module):
class _BaseScaledMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__()
self.hidden_size = hidden_size
self.dtype = dtype
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
.contiguous().transpose(0, 1)
self.weight = (
torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
# Initialize scale_b for _scaled_mm.
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the scaled_mm + reduce scatter in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
scaled_mm = torch._scaled_mm(
fp8_input,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype,
)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter
......@@ -129,11 +143,10 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + scaled_mm in the FX graph
......@@ -143,11 +156,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype)
scaled_mm = torch._scaled_mm(
all_gather,
self.weight,
scale_a=scale_a,
scale_b=self.scale_b,
out_dtype=self.dtype,
)
return scaled_mm
def ops_in_model_before(self):
......@@ -158,20 +173,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph
"""
fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device)
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
self.scale_b, None)
mm_out = torch.empty(
(fp8_input.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=input.device,
)
torch.ops._C.cutlass_scaled_mm(
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter
......@@ -179,14 +196,13 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor):
"""
Forward pass implementing the all gather + cutlass_scaled_mm
Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph
"""
# Reshape input
......@@ -195,11 +211,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device)
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
scale_a, self.scale_b, None)
mm_out = torch.empty(
(all_gather.shape[0], self.weight.shape[1]),
dtype=self.dtype,
device=all_gather.device,
)
torch.ops._C.cutlass_scaled_mm(
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
)
return mm_out
def ops_in_model_before(self):
......@@ -210,23 +229,43 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
])
@pytest.mark.parametrize(
"test_model",
[
TestMMRSModel,
TestAGMMModel,
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel) and dtype == torch.float16:
@pytest.mark.parametrize("dynamic", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(
test_model: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
if (
test_model
in (
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
)
and dtype == torch.float16
):
pytest.skip(
"Only bf16 high precision output types are supported for " \
"Only bf16 high precision output types are supported for "
"per-token (row-wise) scaling"
)
......@@ -235,19 +274,33 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
dynamic,
),
nprocs=nprocs,
)
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
def async_tp_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
def async_tp_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
......@@ -255,13 +308,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# initialize distributed
init_distributed_environment()
......@@ -269,27 +324,40 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_async_tp=True, ), )
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_async_tp=True,
),
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
trust_remote_code=True,
dtype=dtype,
seed=42)
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size,
dtype) # Pass dtype to model constructor
assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype,
requires_grad=False)
hidden_states = torch.randn(
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
......@@ -306,10 +374,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", [
"meta-llama/Llama-3.2-1B-Instruct",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
])
@pytest.mark.parametrize(
"model_id",
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
......@@ -342,16 +410,10 @@ def test_async_tp_pass_correctness(
common_args.append("--enforce-eager")
compilation_config = {
'level': 3,
'compile_sizes': [2, 4, 8],
'splitting_ops': [],
'pass_config': {
'enable_async_tp': async_tp_enabled
},
}
async_tp_env = tp_env = {
"VLLM_USE_V1": "1",
"mode": CompilationMode.VLLM_COMPILE,
"compile_sizes": [2, 4, 8],
"splitting_ops": [],
"pass_config": {"enable_async_tp": async_tp_enabled},
}
async_tp_args = [
......@@ -372,9 +434,4 @@ def test_async_tp_pass_correctness(
"mp",
]
compare_two_settings(model_id,
async_tp_args,
tp_args,
async_tp_env,
tp_env,
method="generate")
compare_two_settings(model_id, async_tp_args, tp_args, method="generate")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import dataclasses
import pytest
from vllm.config import CompilationLevel
from vllm.utils import cuda_device_count_stateless
from vllm.config import CompilationMode
from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import compare_all_settings
......@@ -23,7 +21,7 @@ class TestSetting:
# we cannot afford testing the full Cartesian product
# of all models and all levels
# of all models and all modes
@pytest.mark.parametrize(
"test_setting",
[
......@@ -79,14 +77,15 @@ class TestSetting:
method="encode",
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate_with_image",
),
# See https://github.com/vllm-project/vllm/issues/26716.
# TestSetting(
# model="microsoft/Phi-3.5-vision-instruct",
# model_args=["--trust-remote-code", "--max-model-len", "2048"],
# pp_size=2,
# tp_size=1,
# attn_backend="FLASH_ATTN",
# method="generate_with_image",
# ),
],
)
def test_compile_correctness(
......@@ -103,43 +102,54 @@ def test_compile_correctness(
attn_backend = test_setting.attn_backend
method = test_setting.method
if cuda_device_count_stateless() < pp_size * tp_size:
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}")
pytest.skip(
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}"
)
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [
"--enforce-eager", *model_args, "-pp",
str(pp_size), "-tp",
str(tp_size)
*model_args,
"-pp",
str(pp_size),
"-tp",
str(tp_size),
"-O.cudagraph_mode=none",
]
all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = []
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
for comp_mode in [
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
for mode in [CompilationMode.NONE, comp_mode]:
all_args.append(
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
)
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close")
all_envs.clear()
all_args.clear()
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings(
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close",
)
all_envs.clear()
all_args.clear()
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
for mode in [
CompilationMode.NONE,
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]:
all_args.append(final_args + [f"-O{level}"])
all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({})
compare_all_settings(model, all_args * 3, all_envs, method=method)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch
import pytest
from pydantic import ValidationError
import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, VllmConfig
from vllm.utils import _is_torch_equal_or_newer
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
def test_version():
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
def test_version():
# Test the version comparison logic using the private function
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_use_cudagraphs_dynamic(monkeypatch):
assert vllm.envs.VLLM_USE_V1
vllm_config = VllmConfig()
assert vllm_config.compilation_config.use_cudagraph
monkeypatch.setenv('VLLM_USE_V1', '0')
def test_copy_pass():
vllm_config = VllmConfig()
assert not vllm_config.compilation_config.use_cudagraph
inductor_pass = FixFunctionalizationPass(vllm_config)
copied_inductor_pass = copy.deepcopy(inductor_pass)
assert (
copied_inductor_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
assert (
copied_inductor_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
def test_custom_op():
......@@ -41,63 +57,80 @@ def test_custom_op():
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
assert vllm.envs.VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
compilation_config = {
"use_cudagraph": False, # speed things up a bit
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
}
with (
compilation_counter.expect(num_cache_entries_updated=0,
num_compiled_artifacts_saved=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(
num_cache_entries_updated=0, num_compiled_artifacts_saved=0
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
assert vllm.envs.VLLM_USE_V1
@pytest.mark.parametrize(
"cudagraph_mode,num_cudagraph_captured",
[
(CUDAGraphMode.NONE, 0),
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
(CUDAGraphMode.PIECEWISE, 13),
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
],
)
def test_use_cudagraphs(
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = {
"cudagraph_capture_sizes": [100],
"use_cudagraph": enabled,
"cudagraph_mode": cudagraph_mode,
}
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
with (
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
num_cudagraph_captured=num_cudagraph_captured,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch):
def test_stock_torch_compile(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(dynamo_as_is_count=1),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config={"level": 1},
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(stock_torch_compile_count=1),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
gpu_memory_utilization=0.4,
) as _,
):
pass
......@@ -105,15 +138,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
compilation_config={"level": 0},
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m",
compilation_config={"mode": CompilationMode.NONE},
gpu_memory_utilization=0.4,
) as _,
):
pass
......@@ -121,13 +155,223 @@ def test_no_compilation(vllm_runner, monkeypatch):
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m',
enforce_eager=True,
gpu_memory_utilization=0.4) as _):
compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
"facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
) as _,
):
pass
def test_splitting_ops_dynamic():
# Default config
config = VllmConfig()
# Default V1 config leaves cudagraph mode unset; splitting ops are only
# populated when the engine decides to use piecewise compilation.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert not config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
# When attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
assert config.compilation_config.splitting_ops == []
# cudagraph mode also fall back to FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
# splitting_ops can not contain attention ops when attn_fusion
# pass enabled.
with pytest.raises(ValidationError):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# work around for accessing all attntion ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# When both use_inductor_graph_partition and attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_should_split():
import torch
from vllm.compilation.partition_rules import should_split
graph = torch.fx.Graph()
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.aten.add.default,
args=(),
kwargs={},
)
# supports OpOverloadPacket
splitting_ops = ["aten::add"]
assert should_split(node, splitting_ops)
# supports OpOverload
splitting_ops = ["aten::add.default"]
assert should_split(node, splitting_ops)
# supports OpOverload
splitting_ops = ["aten::add.Tensor"]
assert not should_split(node, splitting_ops)
q, k, v, out = [torch.randn(1)] * 4
# supports custom ops as OpOverloadPacket
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.silly.attention,
args=(q, k, v, out),
kwargs={},
)
splitting_ops = ["silly::attention"]
assert should_split(node, splitting_ops)
# supports custom ops as OpOverload
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.silly.attention.default,
args=(q, k, v, out),
kwargs={},
)
splitting_ops = ["silly::attention"]
assert should_split(node, splitting_ops)
splitting_ops = ["silly::attention.default"]
assert should_split(node, splitting_ops)
@pytest.mark.skipif(
not current_platform.support_static_graph_mode(),
reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
(
"cudagraph_capture_sizes",
"max_cudagraph_capture_size",
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"cudagraph_mode",
"expected_max_size",
),
[
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
(
[1, 2, 4],
8,
1,
False,
2048,
CUDAGraphMode.FULL_AND_PIECEWISE,
ValidationError,
),
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# the list should contain at least 1 element when use cudagraph
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
# the max capturing size should be >= 1 when use cudagraph
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
],
)
def test_cudagraph_sizes_post_init(
cudagraph_capture_sizes,
max_cudagraph_capture_size,
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
cudagraph_mode,
expected_max_size,
):
ctx = nullcontext()
if expected_max_size == ValidationError:
ctx = pytest.raises(expected_max_size)
with (
ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
):
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
pass_config={
"enable_sequence_parallelism": enable_sequence_parallelism,
"enable_fusion": True,
"enable_noop": True,
},
cudagraph_mode=cudagraph_mode,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_seqs=min(max_num_batched_tokens, 128),
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
assert (
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
CUDAGraphMode, VllmConfig, set_current_vllm_config)
from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
......@@ -18,56 +25,86 @@ MLP_SIZE = 128
@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
cudagraph_runtime_mode: CUDAGraphMode):
def run_model(
vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
):
with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
# simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2, MLP_SIZE).cuda())
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1, MLP_SIZE).cuda())
# simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
with set_forward_context(
{},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(torch.randn(2, MLP_SIZE).cuda())
output = output.cpu()
return output.cpu()
def test_ignore_torch_compile_decorator():
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# piecewise
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
expected_num_graphs_seen = 1
expected_num_cudagraph_captured = (
4 # num_cudagraph_sizes * num cudagraphs to capture
)
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
expected_num_piecewise_graphs_seen = 3
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2
@support_torch_compile
class A(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -79,66 +116,58 @@ def test_ignore_torch_compile_decorator():
return x
@ignore_torch_compile
class B(A):
...
class B(A): ...
@support_torch_compile
class C(B):
...
class C(B): ...
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
# B's ignore_torch_compile should override A's support_torch_compile
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
# C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
):
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
kv_sharing_fast_prefill)
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class B(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -152,15 +181,11 @@ class B(nn.Module):
# Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
cache_config.kv_sharing_fast_prefill)
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class A(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
super().__init__()
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
......@@ -174,55 +199,88 @@ class A(nn.Module):
return x
def test_conditional_compile_enable_if():
vllm_config = VllmConfig(cache_config=CacheConfig(
kv_sharing_fast_prefill=True, ),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_sharing_fast_prefill=True,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 2
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2
else:
expected_num_piecewise_graphs_seen = 6
expected_num_piecewise_capturable_graphs_seen = 4
expected_num_backend_compilations = 4
# A has support_torch_compile but enable_if fn returns False
# enalbe_if will be True for B, so we expect mod1 and mod2
# to be compiled
with compilation_counter.expect(
num_graphs_seen=2,
num_piecewise_graphs_seen=6,
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=2,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
# 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num cudagraphable graphs to capture
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
# Set kv_sharing_fast_prefill=False
# which will cause A to be compiled and B to not be compiled
vllm_config = VllmConfig(cache_config=CacheConfig(
kv_sharing_fast_prefill=False, ),
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_sharing_fast_prefill=False,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
),
)
with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
# 3 attn ops and 4 non-attn ops
expected_num_piecewise_graphs_seen = 7
expected_num_piecewise_capturable_graphs_seen = 4
expected_num_backend_compilations = 4
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=7,
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=4,
num_backend_compilations=4,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_graphs_seen=1,
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
# 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=8,
# num_cudagraph_sizes * num cudagraphable graphs to capture
):
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment