Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`. Run `pytest tests/basic_correctness/test_basic_correctness.py`.
""" """
import os import os
import weakref import weakref
from unittest.mock import Mock from unittest.mock import Mock
...@@ -12,14 +13,14 @@ import pytest ...@@ -12,14 +13,14 @@ import pytest
import torch import torch
from vllm import LLM from vllm import LLM
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from vllm.v1.engine.llm_engine import LLMEngine
from ..conftest import HfRunner, VllmRunner from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
MODELS = [ MODELS = [
"google/gemma-2-2b-it", "hmellor/tiny-random-Gemma2ForCausalLM",
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
] ]
...@@ -28,7 +29,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") ...@@ -28,7 +29,7 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
def test_vllm_gc_ed(): def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted""" """Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("distilbert/distilgpt2") llm = LLM("hmellor/tiny-random-LlamaForCausalLM")
weak_llm = weakref.ref(llm) weak_llm = weakref.ref(llm)
del llm del llm
# If there's any circular reference to vllm, this fails # If there's any circular reference to vllm, this fails
...@@ -37,16 +38,21 @@ def test_vllm_gc_ed(): ...@@ -37,16 +38,21 @@ def test_vllm_gc_ed():
def _fix_prompt_embed_outputs( def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, vllm_outputs: list[tuple[list[int], str]],
example_prompts: list[str]) -> list[tuple[list[int], str]]: hf_model: HfRunner,
example_prompts: list[str],
) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = [] fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip( for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts), vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts
example_prompts): ):
hf_input_ids = hf_input["input_ids"].tolist()[0] hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append( fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):], (
prompt + vllm_output[1])) hf_input_ids + vllm_output[0][len(hf_input_ids) :],
prompt + vllm_output[1],
)
)
return fixed_vllm_outputs return fixed_vllm_outputs
...@@ -69,8 +75,7 @@ def test_models( ...@@ -69,8 +75,7 @@ def test_models(
enable_prompt_embeds: bool, enable_prompt_embeds: bool,
) -> None: ) -> None:
if backend == "XFORMERS" and model == "google/gemma-2-2b-it": if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip( pytest.skip(f"{backend} does not support gemma2 with full context length.")
f"{backend} does not support gemma2 with full context length.")
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", backend) m.setenv("VLLM_ATTENTION_BACKEND", backend)
...@@ -78,34 +83,35 @@ def test_models( ...@@ -78,34 +83,35 @@ def test_models(
# 5042 tokens for gemma2 # 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096 # gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window # we need a prompt with more than 4096 tokens to test the sliding window
prompt = "The following numbers of the sequence " + ", ".join( prompt = (
str(i) for i in range(1024)) + " are:" "The following numbers of the sequence "
+ ", ".join(str(i) for i in range(1024))
+ " are:"
)
example_prompts = [prompt] example_prompts = [prompt]
with hf_runner(model) as hf_model: with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds: if enable_prompt_embeds:
with torch.no_grad(): with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings( prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
example_prompts)
with VllmRunner( with VllmRunner(
model, model,
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
distributed_executor_backend=model_executor, distributed_executor_backend=model_executor,
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs( vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts) vllm_outputs, hf_model, example_prompts
)
else: else:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
...@@ -117,21 +123,18 @@ def test_models( ...@@ -117,21 +123,18 @@ def test_models(
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, " "model, distributed_executor_backend, attention_backend, test_suite, extra_env",
"test_suite, extra_env", [ [
("distilbert/distilgpt2", "ray", "", "L4", {}), ("facebook/opt-125m", "ray", "", "L4", {}),
("distilbert/distilgpt2", "mp", "", "L4", {}), ("facebook/opt-125m", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "L4", { ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
"VLLM_SLEEP_WHEN_IDLE": "1" ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
}),
("distilbert/distilgpt2", "mp", "", "L4", {
"VLLM_SLEEP_WHEN_IDLE": "1"
}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("distilbert/distilgpt2", "ray", "", "A100", {}), ("facebook/opt-125m", "ray", "", "A100", {}),
("distilbert/distilgpt2", "mp", "", "A100", {}), ("facebook/opt-125m", "mp", "", "A100", {}),
]) ],
)
@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed( def test_models_distributed(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
...@@ -149,13 +152,14 @@ def test_models_distributed( ...@@ -149,13 +152,14 @@ def test_models_distributed(
pytest.skip(f"Skip test for {test_suite}") pytest.skip(f"Skip test for {test_suite}")
with monkeypatch.context() as monkeypatch_context: with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa if (
if enable_prompt_embeds: model == "meta-llama/Llama-3.2-1B-Instruct"
pytest.skip( and distributed_executor_backend == "ray"
"enable_prompt_embeds does not work with ray compiled dag." and attention_backend == ""
) and test_suite == "L4"
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") and enable_prompt_embeds
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") ): # noqa
pytest.skip("enable_prompt_embeds does not work with ray compiled dag.")
if attention_backend: if attention_backend:
monkeypatch_context.setenv( monkeypatch_context.setenv(
...@@ -175,30 +179,26 @@ def test_models_distributed( ...@@ -175,30 +179,26 @@ def test_models_distributed(
# will hurt multiprocessing backend with fork method # will hurt multiprocessing backend with fork method
# (the default method). # (the default method).
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
tensor_parallel_size=2, tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
) as vllm_model: ) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
with torch.no_grad(): with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings( prompt_embeds = hf_model.get_prompt_embeddings(example_prompts)
example_prompts) vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs( vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts) vllm_outputs, hf_model, example_prompts
hf_outputs = hf_model.generate_greedy( )
example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
else: else:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
example_prompts, max_tokens)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
...@@ -209,27 +209,18 @@ def test_models_distributed( ...@@ -209,27 +209,18 @@ def test_models_distributed(
def test_failed_model_execution(vllm_runner, monkeypatch) -> None: def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
from vllm.envs import VLLM_USE_V1
if not VLLM_USE_V1:
pytest.skip("Skipping V0 test, dump input not supported")
# Needed to mock an error in the same process # Needed to mock an error in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model:
if isinstance(vllm_model.llm.llm_engine, LLMEngineV1): if isinstance(vllm_model.llm.llm_engine, LLMEngine):
v1_test_failed_model_execution(vllm_model) v1_test_failed_model_execution(vllm_model)
def v1_test_failed_model_execution(vllm_model): def v1_test_failed_model_execution(vllm_model):
engine = vllm_model.llm.llm_engine engine = vllm_model.llm.llm_engine
mocked_execute_model = Mock( mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
side_effect=RuntimeError("Mocked Critical Error")) engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
engine.engine_core.engine_core.model_executor.execute_model =\
mocked_execute_model
with pytest.raises(RuntimeError) as exc_info: with pytest.raises(RuntimeError) as exc_info:
prompts = [ prompts = [
......
...@@ -5,5 +5,6 @@ from ..utils import compare_two_settings ...@@ -5,5 +5,6 @@ from ..utils import compare_two_settings
def test_cpu_offload(): def test_cpu_offload():
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [], compare_two_settings(
["--cpu-offload-gb", "1"]) "hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"]
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import pytest import pytest
import torch import torch
from vllm import LLM, SamplingParams from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator
from vllm.utils import GiB_bytes from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
@create_new_process_for_each_test() @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_python_error(): def test_python_error():
""" """
Test if Python error occurs when there's low-level Test if Python error occurs when there's low-level
...@@ -23,13 +26,13 @@ def test_python_error(): ...@@ -23,13 +26,13 @@ def test_python_error():
tensors = [] tensors = []
with allocator.use_memory_pool(): with allocator.use_memory_pool():
# allocate 70% of the total memory # allocate 70% of the total memory
x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
tensors.append(x) tensors.append(x)
# release the memory # release the memory
allocator.sleep() allocator.sleep()
# allocate more memory than the total memory # allocate more memory than the total memory
y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda') y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
tensors.append(y) tensors.append(y)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# when the allocator is woken up, it should raise an error # when the allocator is woken up, it should raise an error
...@@ -37,21 +40,21 @@ def test_python_error(): ...@@ -37,21 +40,21 @@ def test_python_error():
allocator.wake_up() allocator.wake_up()
@create_new_process_for_each_test() @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_basic_cumem(): def test_basic_cumem():
# some tensors from default memory pool # some tensors from default memory pool
shape = (1024, 1024) shape = (1024, 1024)
x = torch.empty(shape, device='cuda') x = torch.empty(shape, device="cuda")
x.zero_() x.zero_()
# some tensors from custom memory pool # some tensors from custom memory pool
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(): with allocator.use_memory_pool():
# custom memory pool # custom memory pool
y = torch.empty(shape, device='cuda') y = torch.empty(shape, device="cuda")
y.zero_() y.zero_()
y += 1 y += 1
z = torch.empty(shape, device='cuda') z = torch.empty(shape, device="cuda")
z.zero_() z.zero_()
z += 2 z += 2
...@@ -70,20 +73,20 @@ def test_basic_cumem(): ...@@ -70,20 +73,20 @@ def test_basic_cumem():
assert torch.allclose(output, torch.ones_like(output) * 3) assert torch.allclose(output, torch.ones_like(output) * 3)
@create_new_process_for_each_test() @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
def test_cumem_with_cudagraph(): def test_cumem_with_cudagraph():
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(): with allocator.use_memory_pool():
weight = torch.eye(1024, device='cuda') weight = torch.eye(1024, device="cuda")
with allocator.use_memory_pool(tag="discard"): with allocator.use_memory_pool(tag="discard"):
cache = torch.empty(1024, 1024, device='cuda') cache = torch.empty(1024, 1024, device="cuda")
def model(x): def model(x):
out = x @ weight out = x @ weight
cache[:out.size(0)].copy_(out) cache[: out.size(0)].copy_(out)
return out + 1 return out + 1
x = torch.empty(128, 1024, device='cuda') x = torch.empty(128, 1024, device="cuda")
# warmup # warmup
model(x) model(x)
...@@ -109,80 +112,72 @@ def test_cumem_with_cudagraph(): ...@@ -109,80 +112,72 @@ def test_cumem_with_cudagraph():
model_graph.replay() model_graph.replay()
# cache content is as expected # cache content is as expected
assert torch.allclose(x, cache[:x.size(0)]) assert torch.allclose(x, cache[: x.size(0)])
# output content is as expected # output content is as expected
assert torch.allclose(y, x + 1) assert torch.allclose(y, x + 1)
@create_new_process_for_each_test() @create_new_process_for_each_test("fork" if not current_platform.is_rocm() else "spawn")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, use_v1", "model",
[ [
# sleep mode with safetensors # sleep mode with safetensors
("meta-llama/Llama-3.2-1B", True), "hmellor/tiny-random-LlamaForCausalLM",
# sleep mode with pytorch checkpoint # sleep mode with pytorch checkpoint
("facebook/opt-125m", True), "facebook/opt-125m",
]) ],
def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): )
with monkeypatch.context() as m: def test_end_to_end(model: str):
assert use_v1 free, total = torch.cuda.mem_get_info()
m.setenv("VLLM_USE_V1", "1") used_bytes_baseline = total - free # in case other process is running
free, total = torch.cuda.mem_get_info() llm = LLM(model, enable_sleep_mode=True)
used_bytes_baseline = total - free # in case other process is running prompt = "How are you?"
llm = LLM(model, enable_sleep_mode=True) sampling_params = SamplingParams(temperature=0, max_tokens=10)
prompt = "How are you?" output = llm.generate(prompt, sampling_params)
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm.sleep(level=1)
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be releasesd from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if use_v1:
assert used_bytes < 7 * GiB_bytes
else:
assert used_bytes < 2 * GiB_bytes
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text
llm.sleep(level=1) # the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
llm.wake_up(tags=["weights"]) # which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm.sleep(level=1)
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# should just reallocate memory for weights (1B model, ~2GiB weights) # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
if use_v1: # is captured but cannot be releasesd from PyTorch due to a known bug,
assert used_bytes < 10 * GiB_bytes # therefore high memory usage after `llm.sleep` is called is expected.
else: # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
assert used_bytes < 6 * GiB_bytes # in V1.
assert used_bytes < 7 * GiB_bytes
# now allocate kv cache memory llm.wake_up()
llm.wake_up(tags=["kv_cache"]) output2 = llm.generate(prompt, sampling_params)
output3 = llm.generate(prompt, sampling_params) # cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text
# cmp output llm.sleep(level=1)
assert output[0].outputs[0].text == output3[0].outputs[0].text llm.wake_up(tags=["weights"])
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
# should just reallocate memory for weights (1B model, ~2GiB weights)
assert used_bytes < 10 * GiB_bytes
# now allocate kv cache memory
llm.wake_up(tags=["kv_cache"])
output3 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output3[0].outputs[0].text
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_deep_sleep(): def test_deep_sleep():
model = "Qwen/Qwen3-0.6B" model = "hmellor/tiny-random-LlamaForCausalLM"
free, total = torch.cuda.mem_get_info() free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running used_bytes_baseline = total - free # in case other process is running
llm = LLM(model, enable_sleep_mode=True) llm = LLM(model, enable_sleep_mode=True)
...@@ -209,3 +204,42 @@ def test_deep_sleep(): ...@@ -209,3 +204,42 @@ def test_deep_sleep():
# cmp output # cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text assert output[0].outputs[0].text == output2[0].outputs[0].text
@create_new_process_for_each_test()
def test_deep_sleep_async():
async def test():
model = "hmellor/tiny-random-LlamaForCausalLM"
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
engine_args = AsyncEngineArgs(
model=model,
enable_sleep_mode=True,
)
llm = AsyncLLMEngine.from_engine_args(engine_args)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
async for output in outputs:
pass
# Put the engine to deep sleep
await llm.sleep(level=2)
await llm.wake_up(tags=["weights"])
await llm.collective_rpc("reload_weights")
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
assert used_bytes < 4 * GiB_bytes
# now allocate kv cache and cuda graph memory
await llm.wake_up(tags=["kv_cache"])
outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
async for output2 in outputs2:
pass
# cmp output
assert output.outputs[0].text == output2.outputs[0].text
asyncio.run(test())
...@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" ...@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_latency(): def test_bench_latency():
command = [ command = [
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", "vllm",
"--output-len", "1", "--enforce-eager", "--load-format", "dummy" "bench",
"latency",
"--model",
MODEL_NAME,
"--input-len",
"32",
"--output-len",
"1",
"--enforce-eager",
"--load-format",
"dummy",
] ]
result = subprocess.run(command, capture_output=True, text=True) result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout) print(result.stdout)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random import random
from typing import Any, NamedTuple, Optional, cast from typing import Any, NamedTuple, cast
import numpy as np import numpy as np
import pytest import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, from vllm.benchmarks.datasets import (
SampleRequest) RandomDataset,
RandomMultiModalDataset,
SampleRequest,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
...@@ -27,11 +30,9 @@ class Params(NamedTuple): ...@@ -27,11 +30,9 @@ class Params(NamedTuple):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def random_dataset_params() -> Params: def random_dataset_params() -> Params:
return Params(num_requests=16, return Params(
prefix_len=7, num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20
range_ratio=0.3, )
input_len=50,
output_len=20)
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
...@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: ...@@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
return (req.prompt, req.prompt_len, req.expected_output_len) return (req.prompt, req.prompt_len, req.expected_output_len)
def _collect_samples(dataset: RandomDataset, def _collect_samples(
tokenizer: PreTrainedTokenizerBase, dataset: RandomDataset,
num_requests: int = 16, tokenizer: PreTrainedTokenizerBase,
prefix_len: int = 7, num_requests: int = 16,
range_ratio: float = 0.3, prefix_len: int = 7,
input_len: int = 50, range_ratio: float = 0.3,
output_len: int = 20) -> list[tuple[str, int, int]]: input_len: int = 50,
output_len: int = 20,
) -> list[tuple[str, int, int]]:
samples = dataset.sample( samples = dataset.sample(
tokenizer=tokenizer, tokenizer=tokenizer,
num_requests=num_requests, num_requests=num_requests,
...@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset, ...@@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset,
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_dataset_same_seed( def test_random_dataset_same_seed(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
random_dataset_params: Params) -> None: ) -> None:
"""Same seed should yield identical outputs, even if global RNGs change. """Same seed should yield identical outputs, even if global RNGs change.
This guards against accidental reliance on Python's random or np.random This guards against accidental reliance on Python's random or np.random
...@@ -70,13 +73,15 @@ def test_random_dataset_same_seed( ...@@ -70,13 +73,15 @@ def test_random_dataset_same_seed(
common_seed = 123 common_seed = 123
dataset_a = RandomDataset(random_seed=common_seed) dataset_a = RandomDataset(random_seed=common_seed)
dataset_b = RandomDataset(random_seed=common_seed) dataset_b = RandomDataset(random_seed=common_seed)
a = _collect_samples(dataset_a, a = _collect_samples(
hf_tokenizer, dataset_a,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
# Perturb global RNG state to ensure isolation # Perturb global RNG state to ensure isolation
random.seed(999) random.seed(999)
...@@ -84,43 +89,50 @@ def test_random_dataset_same_seed( ...@@ -84,43 +89,50 @@ def test_random_dataset_same_seed(
np.random.seed(888) np.random.seed(888)
_ = [np.random.random() for _ in range(100)] _ = [np.random.random() for _ in range(100)]
b = _collect_samples(dataset_b, b = _collect_samples(
hf_tokenizer, dataset_b,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
assert a == b assert a == b
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_dataset_different_seeds( def test_random_dataset_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params
random_dataset_params: Params) -> None: ) -> None:
"""Different seeds should change outputs with overwhelming likelihood.""" """Different seeds should change outputs with overwhelming likelihood."""
p = random_dataset_params p = random_dataset_params
seed_a = 0 seed_a = 0
dataset_a = RandomDataset(random_seed=seed_a) dataset_a = RandomDataset(random_seed=seed_a)
a = _collect_samples(dataset_a, a = _collect_samples(
hf_tokenizer, dataset_a,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
seed_b = 999 seed_b = 999
dataset_b = RandomDataset(random_seed=seed_b) dataset_b = RandomDataset(random_seed=seed_b)
# Perturb global RNG with same seed as dataset_a to ensure isolation # Perturb global RNG with same seed as dataset_a to ensure isolation
random.seed(seed_a) random.seed(seed_a)
np.random.seed(seed_a) np.random.seed(seed_a)
b = _collect_samples(dataset_b, b = _collect_samples(
hf_tokenizer, dataset_b,
num_requests=p.num_requests, hf_tokenizer,
prefix_len=p.prefix_len, num_requests=p.num_requests,
range_ratio=p.range_ratio, prefix_len=p.prefix_len,
input_len=p.input_len, range_ratio=p.range_ratio,
output_len=p.output_len) input_len=p.input_len,
output_len=p.output_len,
)
assert a != b assert a != b
...@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds( ...@@ -128,6 +140,7 @@ def test_random_dataset_different_seeds(
# RandomMultiModalDataset tests # RandomMultiModalDataset tests
# ----------------------------- # -----------------------------
def _mm_fingerprint_sample( def _mm_fingerprint_sample(
req: SampleRequest, req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]: ) -> tuple[str, int, int, int, list[str]]:
...@@ -152,8 +165,13 @@ def _mm_fingerprint_sample( ...@@ -152,8 +165,13 @@ def _mm_fingerprint_sample(
item_prefixes.append(f"video:{url[:22]}") item_prefixes.append(f"video:{url[:22]}")
else: else:
item_prefixes.append("unknown:") item_prefixes.append("unknown:")
return (req.prompt, req.prompt_len, req.expected_output_len, len(items), return (
item_prefixes) req.prompt,
req.prompt_len,
req.expected_output_len,
len(items),
item_prefixes,
)
def _collect_mm_samples( def _collect_mm_samples(
...@@ -167,8 +185,8 @@ def _collect_mm_samples( ...@@ -167,8 +185,8 @@ def _collect_mm_samples(
output_len: int = 5, output_len: int = 5,
base_items_per_request: int = 2, base_items_per_request: int = 2,
num_mm_items_range_ratio: float = 0.0, num_mm_items_range_ratio: float = 0.0,
limit_mm_per_prompt: Optional[dict[str, int]] = None, limit_mm_per_prompt: dict[str, int] | None = None,
bucket_config: Optional[dict[tuple[int, int, int], float]] = None, bucket_config: dict[tuple[int, int, int], float] | None = None,
enable_multimodal_chat: bool = False, enable_multimodal_chat: bool = False,
) -> list[SampleRequest]: ) -> list[SampleRequest]:
if limit_mm_per_prompt is None: if limit_mm_per_prompt is None:
...@@ -214,6 +232,7 @@ def test_random_mm_different_seeds( ...@@ -214,6 +232,7 @@ def test_random_mm_different_seeds(
fb = [_mm_fingerprint_sample(s) for s in b] fb = [_mm_fingerprint_sample(s) for s in b]
assert fa != fb assert fa != fb
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_mm_respects_limits( def test_random_mm_respects_limits(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase,
...@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: ...@@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
for s in samples: for s in samples:
assert s.multi_modal_data == [] assert s.multi_modal_data == []
@pytest.mark.benchmark @pytest.mark.benchmark
def test_random_mm_num_items_per_prompt( def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None:
hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0) ds = RandomMultiModalDataset(random_seed=0)
# Fixed number of images per prompt # Fixed number of images per prompt
# set num_mm_items_range_ratio to 0.0 # set num_mm_items_range_ratio to 0.0
...@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt( ...@@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt(
def test_random_mm_bucket_config_not_mutated( def test_random_mm_bucket_config_not_mutated(
hf_tokenizer: PreTrainedTokenizerBase, hf_tokenizer: PreTrainedTokenizerBase,
) -> None: ) -> None:
ds = RandomMultiModalDataset(random_seed=0) ds = RandomMultiModalDataset(random_seed=0)
# This bucket config is not normalized to sum to 1 # This bucket config is not normalized to sum to 1
# and has more buckets than requested images # and has more buckets than requested images
...@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated( ...@@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated(
# Ensure the original dict content is unchanged # Ensure the original dict content is unchanged
assert original == snapshot assert original == snapshot
# Vary number of mm items per prompt # Vary number of mm items per prompt
# set num_mm_items_range_ratio to 0.5 # set num_mm_items_range_ratio to 0.5
samples_varying_items = _collect_mm_samples( samples_varying_items = _collect_mm_samples(
...@@ -342,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated( ...@@ -342,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
assert len(mm_data) >= 1 assert len(mm_data) >= 1
for it in mm_data: for it in mm_data:
assert it.get("type") == "image_url" assert it.get("type") == "image_url"
@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
"""Test video sampling functionality in RandomMultiModalDataset."""
ds = RandomMultiModalDataset(random_seed=42)
# Test with video bucket configuration
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}
limit_mm_per_prompt = {"image": 2, "video": 2}
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
assert len(samples) == 5
# Check that we have both images and videos
video_count = 0
image_count = 0
for s in samples:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 1
item = mm_data[0]
if item.get("type") == "video_url":
video_count += 1
# Verify video URL format
url = item.get("video_url", {}).get("url", "")
assert url.startswith("data:video/mp4;base64,")
elif item.get("type") == "image_url":
image_count += 1
# Verify image URL format
url = item.get("image_url", {}).get("url", "")
assert url.startswith("data:image/jpeg;base64,")
# Should have some videos due to 0.7 probability
assert video_count > 0
assert image_count > 0
@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
"""Test sampling with only video buckets."""
ds = RandomMultiModalDataset(random_seed=42)
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 1}
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
assert len(samples) == 3
for s in samples:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 1
item = mm_data[0]
assert item.get("type") == "video_url"
url = item.get("video_url", {}).get("url", "")
assert url.startswith("data:video/mp4;base64,")
@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
"""Test that video sampling is deterministic with same seed."""
seed = 123
ds_a = RandomMultiModalDataset(random_seed=seed)
ds_b = RandomMultiModalDataset(random_seed=seed)
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 1}
a = _collect_mm_samples(
ds_a,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
b = _collect_mm_samples(
ds_b,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)
fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa == fb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
from tempfile import NamedTemporaryFile
from typing import Any, cast
import cv2
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import RandomMultiModalDataset, SampleRequest
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
"""Use a small, commonly available tokenizer."""
return AutoTokenizer.from_pretrained("gpt2")
@pytest.fixture
def video_dataset() -> RandomMultiModalDataset:
"""Create a RandomMultiModalDataset instance for testing."""
return RandomMultiModalDataset(random_seed=42)
@pytest.mark.benchmark
def test_generate_synthetic_video_different_seeds():
"""Test that different seeds produce different videos."""
dataset1 = RandomMultiModalDataset(random_seed=123)
dataset2 = RandomMultiModalDataset(random_seed=456)
width, height, num_frames = 64, 48, 8
video1 = dataset1.generate_synthetic_video(width, height, num_frames)
video2 = dataset2.generate_synthetic_video(width, height, num_frames)
# Videos should be different due to different seeds
assert video1["bytes"] != video2["bytes"]
@pytest.mark.benchmark
def test_map_config_to_modality(video_dataset: RandomMultiModalDataset):
"""Test modality mapping for different configurations."""
# Test image configuration (num_frames = 1)
assert video_dataset.map_config_to_modality((256, 256, 1)) == "image"
assert video_dataset.map_config_to_modality((720, 1280, 1)) == "image"
# Test video configurations (num_frames > 1)
assert video_dataset.map_config_to_modality((256, 256, 8)) == "video"
assert video_dataset.map_config_to_modality((720, 1280, 16)) == "video"
assert video_dataset.map_config_to_modality((64, 64, 32)) == "video"
# Test invalid configurations
with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
video_dataset.map_config_to_modality((256, 256, 0))
with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
video_dataset.map_config_to_modality((256, 256, -1))
@pytest.mark.benchmark
def test_generate_mm_item_video(video_dataset: RandomMultiModalDataset):
"""Test generating multimodal items for video configurations."""
# Test video item generation
video_config = (64, 48, 8) # height, width, num_frames
result = video_dataset.generate_mm_item(video_config)
# Check the result structure matches OpenAI API format
assert isinstance(result, dict)
assert result["type"] == "video_url"
assert "video_url" in result
assert "url" in result["video_url"]
# Check that the URL is a data URL with base64 encoded video
url = result["video_url"]["url"]
assert url.startswith("data:video/mp4;base64,")
# Decode and verify the video content
base64_data = url.split(",")[1]
video_bytes = base64.b64decode(base64_data)
assert len(video_bytes) > 0
# Verify the video can be decoded
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(video_bytes)
try:
cap = cv2.VideoCapture(temp_path)
assert cap.isOpened()
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
assert frame_count == 8
assert frame_width == 48
assert frame_height == 64
cap.release()
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
@pytest.mark.benchmark
def test_generate_mm_item_image(video_dataset: RandomMultiModalDataset):
"""Test generating multimodal items for image configurations."""
# Test image item generation
image_config = (64, 48, 1) # height, width, num_frames=1
result = video_dataset.generate_mm_item(image_config)
# Check the result structure matches OpenAI API format
assert isinstance(result, dict)
assert result["type"] == "image_url"
assert "image_url" in result
assert "url" in result["image_url"]
# Check that the URL is a data URL with base64 encoded image
url = result["image_url"]["url"]
assert url.startswith("data:image/jpeg;base64,")
@pytest.mark.benchmark
def test_generate_mm_item_invalid_config(video_dataset: RandomMultiModalDataset):
"""Test error handling for invalid configurations."""
with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
video_dataset.generate_mm_item((256, 256, 0))
@pytest.mark.benchmark
def test_sample_with_video_buckets(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test sampling with video bucket configurations."""
# Configure bucket with video probability > 0
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}
limit_mm_per_prompt = {"image": 5, "video": 3}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=5,
base_items_per_request=2,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 5
# Check that samples contain both images and videos
video_count = 0
image_count = 0
for sample in samples:
assert isinstance(sample, SampleRequest)
assert sample.multi_modal_data is not None
assert isinstance(sample.multi_modal_data, list)
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
assert len(mm_data) == 2 # base_items_per_request
for item in mm_data:
if item["type"] == "video_url":
video_count += 1
# Verify video URL format
url = item["video_url"]["url"]
assert url.startswith("data:video/mp4;base64,")
elif item["type"] == "image_url":
image_count += 1
# Verify image URL format
url = item["image_url"]["url"]
assert url.startswith("data:image/jpeg;base64,")
# Should have some videos due to 0.7 probability
assert video_count > 0
assert image_count > 0
@pytest.mark.benchmark
def test_sample_video_only_buckets(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test sampling with only video buckets."""
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 2}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 3
for sample in samples:
assert isinstance(sample, SampleRequest)
assert sample.multi_modal_data is not None
assert isinstance(sample.multi_modal_data, list)
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
assert len(mm_data) == 1
item = mm_data[0]
assert item["type"] == "video_url"
url = item["video_url"]["url"]
assert url.startswith("data:video/mp4;base64,")
@pytest.mark.benchmark
def test_sample_respects_video_limits(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test that sampling respects video limits per prompt."""
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
# Set very low video limit
limit_mm_per_prompt = {"image": 0, "video": 1}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 3
for sample in samples:
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
assert len(mm_data) <= 1 # Should respect video limit
@pytest.mark.benchmark
def test_sample_mixed_buckets_with_zero_probability(
video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
"""Test sampling with mixed buckets including zero probability entries."""
bucket_config = {
(64, 64, 1): 0.5, # Images
(64, 64, 8): 0.5, # Videos
(128, 128, 16): 0.0, # Zero probability videos (should be ignored)
}
limit_mm_per_prompt = {"image": 2, "video": 2}
samples = video_dataset.sample(
tokenizer=hf_tokenizer,
num_requests=4,
base_items_per_request=2,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples) == 4
# Should only see 64x64 videos, not 128x128 videos
for sample in samples:
mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
for item in mm_data:
if item["type"] == "video_url":
# Decode video to verify dimensions
url = item["video_url"]["url"]
base64_data = url.split(",")[1]
video_bytes = base64.b64decode(base64_data)
with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file: # noqa
temp_path = temp_file.name
temp_file.write(video_bytes)
try:
cap = cv2.VideoCapture(temp_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()
# Should be 64x64, not 128x128
assert frame_width == 64
assert frame_height == 64
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
@pytest.mark.benchmark
def test_sample_deterministic_with_videos(hf_tokenizer: PreTrainedTokenizerBase):
"""Test that sampling with videos is deterministic with same seed."""
dataset1 = RandomMultiModalDataset(random_seed=123)
dataset2 = RandomMultiModalDataset(random_seed=123)
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}
limit_mm_per_prompt = {"image": 2, "video": 2}
samples1 = dataset1.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
samples2 = dataset2.sample(
tokenizer=hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
assert len(samples1) == len(samples2)
# Compare multimodal data
for s1, s2 in zip(samples1, samples2):
assert s1.multi_modal_data == s2.multi_modal_data
@pytest.mark.benchmark
def test_sample_different_seeds_produce_different_videos(
hf_tokenizer: PreTrainedTokenizerBase,
):
"""Test that different seeds produce different video content."""
dataset1 = RandomMultiModalDataset(random_seed=123)
dataset2 = RandomMultiModalDataset(random_seed=456)
bucket_config = {
(64, 64, 8): 1.0, # Only videos
}
limit_mm_per_prompt = {"image": 0, "video": 1}
samples1 = dataset1.sample(
tokenizer=hf_tokenizer,
num_requests=2,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
samples2 = dataset2.sample(
tokenizer=hf_tokenizer,
num_requests=2,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
input_len=20,
output_len=5,
)
# Video content should be different
for s1, s2 in zip(samples1, samples2):
mm_data1 = cast(list[dict[str, Any]], s1.multi_modal_data)
mm_data2 = cast(list[dict[str, Any]], s2.multi_modal_data)
assert len(mm_data1) == len(mm_data2) == 1
url1 = mm_data1[0]["video_url"]["url"]
url2 = mm_data2[0]["video_url"]["url"]
assert url1 != url2 # Different video content
...@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" ...@@ -11,9 +11,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
...@@ -46,6 +44,7 @@ def test_bench_serve(server): ...@@ -46,6 +44,7 @@ def test_bench_serve(server):
assert result.returncode == 0, f"Benchmark failed: {result.stderr}" assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_serve_chat(server): def test_bench_serve_chat(server):
command = [ command = [
......
...@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" ...@@ -10,8 +10,18 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_throughput(): def test_bench_throughput():
command = [ command = [
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", "vllm",
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" "bench",
"throughput",
"--model",
MODEL_NAME,
"--input-len",
"32",
"--output-len",
"1",
"--enforce-eager",
"--load-format",
"dummy",
] ]
result = subprocess.run(command, capture_output=True, text=True) result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout) print(result.stdout)
......
...@@ -5,13 +5,16 @@ These envs only work for a small part of the tests, fix what you need! ...@@ -5,13 +5,16 @@ These envs only work for a small part of the tests, fix what you need!
""" """
import os import os
from typing import TYPE_CHECKING, Any, Callable, Optional from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from vllm.envs import maybe_convert_bool
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_CI_NO_SKIP: bool = False VLLM_CI_NO_SKIP: bool = False
VLLM_CI_DTYPE: Optional[str] = None VLLM_CI_DTYPE: str | None = None
VLLM_CI_HEAD_DTYPE: Optional[str] = None VLLM_CI_HEAD_DTYPE: str | None = None
VLLM_CI_HF_DTYPE: Optional[str] = None VLLM_CI_HF_DTYPE: str | None = None
environment_variables: dict[str, Callable[[], Any]] = { environment_variables: dict[str, Callable[[], Any]] = {
# A model family has many models with the same architecture. # A model family has many models with the same architecture.
...@@ -24,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -24,6 +27,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None), "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
# Allow changing the head dtype used by transformers in tests # Allow changing the head dtype used by transformers in tests
"VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None), "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
# Allow control over whether tests use enforce_eager
"VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool(
os.getenv("VLLM_CI_ENFORCE_EAGER", None)
),
} }
......
...@@ -2,18 +2,23 @@ ...@@ -2,18 +2,23 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref import weakref
from collections.abc import Sequence from collections.abc import Callable, Sequence
from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from typing import Callable, Union
import depyf
from torch import fx from torch import fx
from torch._ops import OpOverload from torch._ops import OpOverload
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.pass_manager import with_pattern_match_debug from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger("vllm.tests.compile.backend")
class LazyInitPass(InductorPass): class LazyInitPass(InductorPass):
...@@ -23,8 +28,7 @@ class LazyInitPass(InductorPass): ...@@ -23,8 +28,7 @@ class LazyInitPass(InductorPass):
and then immediately invoke it. and then immediately invoke it.
""" """
def __init__(self, pass_cls: type[VllmInductorPass], def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig):
vllm_config: VllmConfig):
self.pass_cls = pass_cls self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
...@@ -45,24 +49,34 @@ class TestBackend: ...@@ -45,24 +49,34 @@ class TestBackend:
Inductor config is default-initialized from VllmConfig.CompilationConfig. Inductor config is default-initialized from VllmConfig.CompilationConfig.
""" """
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
None]]):
self.custom_passes = list(passes) self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config vllm_config = get_current_vllm_config()
self.inductor_config = compile_config.inductor_compile_config compile_config = vllm_config.compilation_config
self.inductor_config['force_disable_caches'] = True # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass self.inductor_config = deepcopy(compile_config.inductor_compile_config)
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
if debug_dump_path := vllm_config.compile_debug_dump_path():
logger.debug("Dumping depyf output to %s", debug_dump_path)
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
else:
self.debug_ctx = nullcontext()
def __call__(self, graph: fx.GraphModule, example_inputs): def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph) self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs, with self.debug_ctx:
config_patches=self.inductor_config) return compile_fx(
graph, example_inputs, config_patches=self.inductor_config
)
@with_pattern_match_debug @with_pattern_match_debug
def post_pass(self, graph: fx.Graph): def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph) self.graph_pre_pass = deepcopy(graph)
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
VllmInductorPass.dump_prefix = 0 VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes: for pass_ in self.custom_passes:
...@@ -72,6 +86,7 @@ class TestBackend: ...@@ -72,6 +86,7 @@ class TestBackend:
VllmInductorPass.dump_prefix = None VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph) self.graph_post_pass = deepcopy(graph)
lazy_format_graph_code("graph_post_pass", graph.owning_module)
# assign by reference, will reflect the final state of the graph # assign by reference, will reflect the final state of the graph
self.final_graph = graph self.final_graph = graph
...@@ -82,8 +97,7 @@ class TestBackend: ...@@ -82,8 +97,7 @@ class TestBackend:
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph" assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}" assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced: if fully_replaced:
assert num_post == 0, \ assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"
f"Unexpected op {op.name()} in post-pass graph"
def check_after_ops(self, ops: Sequence[OpOverload]): def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops: for op in ops:
......
...@@ -3,15 +3,15 @@ ...@@ -3,15 +3,15 @@
import contextlib import contextlib
import os import os
import weakref import weakref
from dataclasses import dataclass
from typing import Optional
import pytest import pytest
from tests.utils import wait_for_gpu_memory_to_clear from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig from vllm.config import CompilationConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
@contextlib.contextmanager @contextlib.contextmanager
...@@ -33,121 +33,44 @@ def temporary_environ(env_vars): ...@@ -33,121 +33,44 @@ def temporary_environ(env_vars):
os.environ[k] = v os.environ[k] = v
@dataclass model_backends_full_cudagraph = []
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
test_params_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA # deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends: for mla_backend in MLA_backends:
test_params_full_cudagraph.append( model_backends_full_cudagraph.append(
pytest.param( ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))) )
# Qwen/Qwen2-1.5B-Instruct with other backends # Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [ other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends backend_configs[c] for c in backend_configs if c not in MLA_backends
] ]
for backend_config in other_backend_configs: for backend_config in other_backend_configs:
test_params_full_cudagraph.append( model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def llm_pair(request): def llm_pair(request):
model, backend_config = request.param model, backend_config, use_inductor_graph_partition = request.param
backend_config.comp_config["use_inductor_graph_partition"] = (
use_inductor_graph_partition
)
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")
# Dynamically skip test if GPU capability is not met # Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ if (
!= current_platform.get_device_capability(): backend_config.specific_gpu_arch
and backend_config.specific_gpu_arch != current_platform.get_device_capability()
):
if backend_config.specific_gpu_arch == (9, 0): if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0): elif backend_config.specific_gpu_arch == (10, 0):
pytest.skip("Only Blackwell GPUs support Cutlass MLA") pytest.skip("Only Blackwell GPUs support Cutlass MLA")
env_vars = { env_vars = {
"VLLM_USE_V1": "1",
# Force native sampler to avoid potential nondeterminism in FlashInfer # Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1. # when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0", "VLLM_USE_FLASHINFER_SAMPLER": "0",
...@@ -160,8 +83,7 @@ def llm_pair(request): ...@@ -160,8 +83,7 @@ def llm_pair(request):
trust_remote_code=True, trust_remote_code=True,
max_model_len=1024, max_model_len=1024,
max_num_seqs=128, max_num_seqs=128,
compilation_config=\ compilation_config=CompilationConfig(**backend_config.comp_config),
CompilationConfig(**backend_config.comp_config),
generation_config="vllm", generation_config="vllm",
seed=42, seed=42,
) )
...@@ -187,7 +109,15 @@ def llm_pair(request): ...@@ -187,7 +109,15 @@ def llm_pair(request):
) )
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True) @pytest.mark.parametrize(
"llm_pair",
[
pytest.param((model, backend_config, use_inductor_graph_partition))
for model, backend_config in model_backends_full_cudagraph
for use_inductor_graph_partition in [True, False]
],
indirect=True,
)
class TestFullCUDAGraph: class TestFullCUDAGraph:
""" """
Use a class such that an llm pair is constructed once for all Use a class such that an llm pair is constructed once for all
...@@ -197,20 +127,22 @@ class TestFullCUDAGraph: ...@@ -197,20 +127,22 @@ class TestFullCUDAGraph:
meaning there would be multiple LLM instances hogging memory simultaneously. meaning there would be multiple LLM instances hogging memory simultaneously.
""" """
@pytest.mark.parametrize(("batch_size", "max_tokens"), [ @pytest.mark.parametrize(
(1, 10), ("batch_size", "max_tokens"),
(7, 10), [
(16, 10), (1, 10),
(25, 10), (7, 10),
(32, 10), (16, 10),
(45, 10), (25, 10),
(64, 10), (32, 10),
(123, 10), (45, 10),
(8, 5), (64, 10),
(8, 30), (123, 10),
]) (8, 5),
def test_full_cudagraph(self, batch_size, max_tokens, (8, 30),
llm_pair: tuple[LLM, LLM]): ],
)
def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
""" """
Test various batch sizes and max_tokens to ensure that the Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too. full cudagraph compilation works for padded cases too.
...@@ -221,26 +153,33 @@ class TestFullCUDAGraph: ...@@ -221,26 +153,33 @@ class TestFullCUDAGraph:
prompts = ["the quick brown fox"] * batch_size prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity # Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes. # that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(
max_tokens=max_tokens, temperature=0.0, max_tokens=max_tokens, top_p=1.0
top_p=1.0) )
piecewise_responses = piecewise_llm.generate(prompts, sampling_params) piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
# Check that all responses are the same # Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses, for piecewise_res, full_res in zip(piecewise_responses, full_responses):
full_responses): assert (
assert piecewise_res.outputs[0].text.lower() == \ piecewise_res.outputs[0].text.lower()
full_res.outputs[0].text.lower() == full_res.outputs[0].text.lower()
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend(): def test_full_cudagraph_with_invalid_backend():
with temporary_environ({ with (
"VLLM_USE_V1": "1", temporary_environ(
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION" {
# Flex_Attention is not supported with full cuda graph "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
}), pytest.raises(RuntimeError): # Flex_Attention is not supported with full cuda graph
LLM(model="Qwen/Qwen2-1.5B-Instruct", }
compilation_config=CompilationConfig(cudagraph_mode="FULL")) ),
pytest.raises(RuntimeError),
):
LLM(
model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
)
...@@ -5,16 +5,24 @@ Test (piecewise) compilation with a simple model where multiple submodules ...@@ -5,16 +5,24 @@ Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately. are compiled and graph captured separately.
""" """
import pytest
import torch import torch
from torch import nn from torch import nn
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile, from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
support_torch_compile) from vllm.config import (
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, CompilationConfig,
VllmConfig, set_current_vllm_config) CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ...utils import create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401 from .. import silly_attention # noqa: F401
...@@ -27,12 +35,7 @@ RANDOM_SEED = 0 ...@@ -27,12 +35,7 @@ RANDOM_SEED = 0
@support_torch_compile @support_torch_compile
class ParentModel(nn.Module): class ParentModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -40,7 +43,6 @@ class ParentModel(nn.Module): ...@@ -40,7 +43,6 @@ class ParentModel(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, mlp_size: int, hidden_size: int) -> None: def __init__(self, mlp_size: int, hidden_size: int) -> None:
super().__init__() super().__init__()
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
...@@ -51,17 +53,21 @@ class Attention(nn.Module): ...@@ -51,17 +53,21 @@ class Attention(nn.Module):
nn.init.xavier_normal_( nn.init.xavier_normal_(
self.pre_attn.weight.data, self.pre_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED), generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001) gain=0.001,
)
nn.init.xavier_normal_( nn.init.xavier_normal_(
self.post_attn.weight.data, self.post_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED), generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001) gain=0.001,
)
def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor: def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
x_f32 = x.float() x_f32 = x.float()
return (x_f32 * torch.rsqrt( return (
torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) * x_f32
self.rms_norm_weight).to(x.dtype) * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6)
* self.rms_norm_weight
).to(x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x) x = self.pre_attn(x)
...@@ -76,14 +82,15 @@ class Attention(nn.Module): ...@@ -76,14 +82,15 @@ class Attention(nn.Module):
@support_torch_compile @support_torch_compile
class CompiledAttention(nn.Module): class CompiledAttention(nn.Module):
def __init__(
def __init__(self, self,
*, *,
mlp_size: int, mlp_size: int,
hidden_size: int, hidden_size: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = '', prefix: str = "",
**kwargs) -> None: **kwargs,
) -> None:
super().__init__() super().__init__()
self.attn = Attention(mlp_size, hidden_size) self.attn = Attention(mlp_size, hidden_size)
...@@ -93,21 +100,21 @@ class CompiledAttention(nn.Module): ...@@ -93,21 +100,21 @@ class CompiledAttention(nn.Module):
@support_torch_compile @support_torch_compile
class CompiledAttentionTwo(CompiledAttention): class CompiledAttentionTwo(CompiledAttention):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x) + x return self.attn(x) + x
@ignore_torch_compile @ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel): class SimpleModelWithTwoGraphs(ParentModel):
def __init__(
def __init__(self, self,
*, *,
mlp_size: int, mlp_size: int,
hidden_size: int, hidden_size: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = '', prefix: str = "",
**kwargs) -> None: **kwargs,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
# Test will fail without set_model_tag here with error: # Test will fail without set_model_tag here with error:
# "ValueError: too many values to unpack (expected 3)" # "ValueError: too many values to unpack (expected 3)"
...@@ -142,118 +149,174 @@ class SimpleModelWithTwoGraphs(ParentModel): ...@@ -142,118 +149,174 @@ class SimpleModelWithTwoGraphs(ParentModel):
@torch.inference_mode @torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, def run_model(
cudagraph_runtime_mode: CUDAGraphMode): vllm_config: VllmConfig,
model: nn.Module,
inputs: torch.Tensor,
cudagraph_runtime_mode: CUDAGraphMode,
):
with set_forward_context({}, vllm_config=vllm_config): with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE # warmup for the model with cudagraph_mode NONE
model(inputs) model(inputs)
# simulate cudagraphs capturing # simulate cudagraphs capturing
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(inputs[:2]) model(inputs[:2])
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(inputs[:1]) model(inputs[:1])
# simulate cudagraphs replay # simulate cudagraphs replay
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(inputs[:2]) output = model(inputs[:2])
output = output.cpu() output = output.cpu()
return output.cpu() return output.cpu()
def test_multi_graph_piecewise_compile_outputs_equal(): @pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@create_new_process_for_each_test("spawn")
def test_multi_graph_piecewise_compile(
use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
):
# Set the environment variable for this test
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
outputs = [] outputs = []
# piecewise compile # vllmcompile compile
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=True, mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly.attention"], cudagraph_mode=CUDAGraphMode.PIECEWISE,
cudagraph_capture_sizes=[1, 2], splitting_ops=["silly::attention"],
)) cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, model = (
hidden_size=HIDDEN_SIZE, SimpleModelWithTwoGraphs(
vllm_config=vllm_config, mlp_size=MLP_SIZE,
prefix='').eval().cuda() hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
# Pre-allocate memory for CUDAGraph which expects # Pre-allocate memory for CUDAGraph which expects
# static tensor addresses # static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
if use_inductor_graph_partition:
# Splitting happens at Inductor lowering level,
# total piecewise fx graphs is equal to total graphs
num_piecewise_fx = 2
num_piecewise_capturable_fx = 2
else:
# attn_one, attn_two each has 3 piecewise graphs
# (pre attn, post attn, silly_attention) each
num_piecewise_fx = 6
# attn_one, attn_two has pre attn and post attn each, total=4
num_piecewise_capturable_fx = 4
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6, num_piecewise_graphs_seen=num_piecewise_fx,
# attn_one, attn_two each has 3 piecewise graphs num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
# (pre attn, post attn, silly_attention) each num_backend_compilations=num_piecewise_capturable_fx,
num_piecewise_capturable_graphs_seen=4, num_cudagraph_captured=8, # num_cudagraph_sizes * num_partitions
# attn_one, attn_two has pre attn and post attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
outputs.append( outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# no compile or cudagraph # no compile or cudagraph
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.NO_COMPILATION, )) compilation_config=CompilationConfig(
mode=CompilationMode.NONE,
)
)
cudagraph_runtime_mode = CUDAGraphMode.NONE cudagraph_runtime_mode = CUDAGraphMode.NONE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, model = (
hidden_size=HIDDEN_SIZE, SimpleModelWithTwoGraphs(
vllm_config=vllm_config, mlp_size=MLP_SIZE,
prefix='').eval().cuda() hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=0, num_graphs_seen=0,
num_piecewise_graphs_seen=0, num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0, num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0, num_backend_compilations=0,
num_cudagraph_captured=0, num_cudagraph_captured=0,
): ):
outputs.append( outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# piecewise compile without CUDA graph # piecewise compile without CUDA graph
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=False, mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly.attention"], cudagraph_mode=CUDAGraphMode.NONE,
)) splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, model = (
hidden_size=HIDDEN_SIZE, SimpleModelWithTwoGraphs(
vllm_config=vllm_config, mlp_size=MLP_SIZE,
prefix='').eval().cuda() hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix="",
)
.eval()
.cuda()
)
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=2, num_graphs_seen=2,
num_piecewise_graphs_seen=6, num_piecewise_graphs_seen=num_piecewise_fx,
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_backend_compilations=4, num_backend_compilations=num_piecewise_capturable_fx,
num_cudagraph_captured=0, # no cudagraph captured num_cudagraph_captured=0, # no cudagraph captured
): ):
outputs.append( outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
# Generally don't expect outputs with and without inductor # Generally don't expect outputs with and without inductor
# to be bitwise equivalent # to be bitwise equivalent
......
...@@ -11,11 +11,17 @@ from torch import nn ...@@ -11,11 +11,17 @@ from torch import nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import (
VllmConfig, set_current_vllm_config) CompilationConfig,
from vllm.envs import VLLM_USE_V1 CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from ...utils import create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter from ..silly_attention import get_global_counter, reset_global_counter
...@@ -23,12 +29,7 @@ from ..silly_attention import get_global_counter, reset_global_counter ...@@ -23,12 +29,7 @@ from ..silly_attention import get_global_counter, reset_global_counter
@support_torch_compile @support_torch_compile
class SillyModel(nn.Module): class SillyModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -60,53 +61,64 @@ def _run_simple_model( ...@@ -60,53 +61,64 @@ def _run_simple_model(
expected_num_backend_compilations, expected_num_backend_compilations,
expected_num_cudagraph_captured, expected_num_cudagraph_captured,
): ):
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=True, mode=CompilationMode.VLLM_COMPILE,
use_inductor=use_inductor, use_inductor=use_inductor,
splitting_ops=splitting_ops, splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition, use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True, cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2], cudagraph_capture_sizes=[1, 2],
)) )
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='') model = SillyModel(vllm_config=vllm_config, prefix="")
inputs = torch.randn(100).cuda() inputs = torch.randn(100).cuda()
with compilation_counter.expect( with (
compilation_counter.expect(
num_graphs_seen=1, # one graph for the model num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen= num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations, num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured, num_cudagraph_captured=expected_num_cudagraph_captured,
), set_forward_context(None, ),
vllm_config=vllm_config): # background context set_forward_context(None, vllm_config=vllm_config),
): # background context
# warm up with background context # warm up with background context
model(inputs) model(inputs)
# capturing/replaying should under context of cudagraph dispatching # capturing/replaying should under context of cudagraph dispatching
with set_forward_context( with set_forward_context(
None, None,
vllm_config=vllm_config, vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2).cuda()) model(torch.randn(2).cuda())
with set_forward_context( with set_forward_context(
None, None,
vllm_config=vllm_config, vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1).cuda()) model(torch.randn(1).cuda())
input = torch.zeros(2).cuda() input = torch.zeros(2).cuda()
reset_global_counter() reset_global_counter()
with set_forward_context( with set_forward_context(
None, None,
vllm_config=vllm_config, vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input) output = model(input)
assert get_global_counter() == 2 assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
...@@ -114,42 +126,42 @@ def _run_simple_model( ...@@ -114,42 +126,42 @@ def _run_simple_model(
@pytest.mark.parametrize("use_inductor", [True, False]) @pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode() @torch.inference_mode()
@create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(use_inductor): def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
_run_simple_model( _run_simple_model(
splitting_ops=["silly.attention"], splitting_ops=["silly::attention"],
use_inductor_graph_partition=False, use_inductor_graph_partition=False,
use_inductor=use_inductor, use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1 # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers expected_num_piecewise_graphs_seen=5,
expected_num_backend_compilations= # 1 + num_layers
3, # num_piecewise_capturable_graphs_seen expected_num_piecewise_capturable_graphs_seen=3,
expected_num_cudagraph_captured= # num_piecewise_capturable_graphs_seen
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen expected_num_backend_compilations=3,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=6,
) )
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []]) def test_simple_inductor_graph_partition(monkeypatch):
def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1
if not is_torch_equal_or_newer("2.9.0.dev"): if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available " pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
"in PyTorch 2.9+")
# disable compile cache so that we run separately for different splitting_ops
# and get the expected number of cudagraphs captured.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
_run_simple_model( _run_simple_model(
# inductor graph partition automatically resets splitting_ops splitting_ops=["silly::attention"],
# to be an empty list
splitting_ops=splitting_ops,
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
use_inductor=True, use_inductor=True,
expected_num_piecewise_graphs_seen= # Since not splitting at fx graph level
1, # since not splitting at fx graph level expected_num_piecewise_graphs_seen=1,
expected_num_piecewise_capturable_graphs_seen= # Since not splitting at fx graph level
1, # since not splitting at fx graph level expected_num_piecewise_capturable_graphs_seen=1,
expected_num_backend_compilations= # Since not splitting at fx graph level
1, # since not splitting at fx graph level expected_num_backend_compilations=1,
expected_num_cudagraph_captured= # Inductor graph partition still captures 6 graph, same as fx graph partition
6, # inductor graph partition still captures 6 expected_num_cudagraph_captured=6,
# graph, same as fx graph partition.
) )
...@@ -8,8 +8,10 @@ This is a tractable model, the weights and computation are specially designed ...@@ -8,8 +8,10 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed. initialized randomly with a fixed seed.
""" """
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any
import pytest import pytest
import torch import torch
...@@ -17,9 +19,17 @@ from torch import nn ...@@ -17,9 +19,17 @@ from torch import nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import (
VllmConfig, set_current_vllm_config) CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ...utils import create_new_process_for_each_test
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401 from .. import silly_attention # noqa: F401
...@@ -43,15 +53,14 @@ class LlamaConfig: ...@@ -43,15 +53,14 @@ class LlamaConfig:
factors.append((k, v)) factors.append((k, v))
factors.sort() factors.sort()
import hashlib import hashlib
return hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest() return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
def __post_init__(self): def __post_init__(self):
assert self.mlp_size >= self.hidden_size assert self.mlp_size >= self.hidden_size
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self, config: LlamaConfig) -> None:
super().__init__() super().__init__()
self.gate_up_projection = nn.Linear( self.gate_up_projection = nn.Linear(
...@@ -66,31 +75,31 @@ class LlamaMLP(nn.Module): ...@@ -66,31 +75,31 @@ class LlamaMLP(nn.Module):
) )
if config.tractable_init: if config.tractable_init:
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size]) nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:]) nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :])
nn.init.eye_(self.down_projection.weight.data) nn.init.eye_(self.down_projection.weight.data)
else: else:
nn.init.xavier_normal_(self.gate_up_projection.weight.data, nn.init.xavier_normal_(
generator=torch.Generator().manual_seed( self.gate_up_projection.weight.data,
config.random_seed), generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001) gain=0.001,
nn.init.xavier_normal_(self.down_projection.weight.data, )
generator=torch.Generator().manual_seed( nn.init.xavier_normal_(
config.random_seed), self.down_projection.weight.data,
gain=0.001) generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward(self, x): def forward(self, x):
# for tractable_init and positive input, this is # for tractable_init and positive input, this is
# essentially an elementwise-square # essentially an elementwise-square
x = self.gate_up_projection(x) x = self.gate_up_projection(x)
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu( x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :])
x[:, x.size(1) // 2:])
x = self.down_projection(x) x = self.down_projection(x)
return x return x
class LlamaAttention(nn.Module): class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self, config: LlamaConfig) -> None:
super().__init__() super().__init__()
self.qkv_projection = nn.Linear( self.qkv_projection = nn.Linear(
...@@ -106,21 +115,25 @@ class LlamaAttention(nn.Module): ...@@ -106,21 +115,25 @@ class LlamaAttention(nn.Module):
) )
if config.tractable_init: if config.tractable_init:
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size]) nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 * nn.init.eye_(
config.hidden_size]) self.qkv_projection.weight.data[
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size : 2 * config.hidden_size
config.hidden_size:]) ]
)
nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :])
nn.init.eye_(self.output_projection.weight.data) nn.init.eye_(self.output_projection.weight.data)
else: else:
nn.init.xavier_normal_(self.qkv_projection.weight.data, nn.init.xavier_normal_(
generator=torch.Generator().manual_seed( self.qkv_projection.weight.data,
config.random_seed), generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001) gain=0.001,
nn.init.xavier_normal_(self.output_projection.weight.data, )
generator=torch.Generator().manual_seed( nn.init.xavier_normal_(
config.random_seed), self.output_projection.weight.data,
gain=0.001) generator=torch.Generator().manual_seed(config.random_seed),
gain=0.001,
)
def forward( def forward(
self, self,
...@@ -144,7 +157,6 @@ class LlamaAttention(nn.Module): ...@@ -144,7 +157,6 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module): class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig) -> None: def __init__(self, config: LlamaConfig) -> None:
super().__init__() super().__init__()
self.self_attention = LlamaAttention(config) self.self_attention = LlamaAttention(config)
...@@ -154,7 +166,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -154,7 +166,7 @@ class LlamaDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
For tractable computation: For tractable computation:
...@@ -164,7 +176,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -164,7 +176,7 @@ class LlamaDecoderLayer(nn.Module):
- if residual is not None, the outputs are: - if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3 - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2 - hidden_states = (residual + 1) ** 2
""" # noqa """ # noqa
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = hidden_states + 1 hidden_states = hidden_states + 1
...@@ -173,8 +185,9 @@ class LlamaDecoderLayer(nn.Module): ...@@ -173,8 +185,9 @@ class LlamaDecoderLayer(nn.Module):
residual = hidden_states residual = hidden_states
hidden_states = hidden_states + 1 hidden_states = hidden_states + 1
hidden_states = self.self_attention(positions=positions, hidden_states = self.self_attention(
hidden_states=hidden_states) positions=positions, hidden_states=hidden_states
)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
residual = hidden_states residual = hidden_states
...@@ -186,27 +199,29 @@ class LlamaDecoderLayer(nn.Module): ...@@ -186,27 +199,29 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(
def __init__(self, self,
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
config: LlamaConfig, config: LlamaConfig,
prefix: str = '', prefix: str = "",
**kwargs) -> None: **kwargs,
) -> None:
super().__init__() super().__init__()
self.embedding_tokens = nn.Embedding( self.embedding_tokens = nn.Embedding(
num_embeddings=config.vocab_size, num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size, embedding_dim=config.hidden_size,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_layers)]) [LlamaDecoderLayer(config) for _ in range(config.num_layers)]
)
# this is the initial value of the hidden states # this is the initial value of the hidden states
self.embedding_tokens.weight.data.fill_(config.init_value) self.embedding_tokens.weight.data.fill_(config.init_value)
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embedding_tokens(input_ids) hidden_states = self.embedding_tokens(input_ids)
...@@ -216,168 +231,195 @@ class LlamaModel(nn.Module): ...@@ -216,168 +231,195 @@ class LlamaModel(nn.Module):
return hidden_states return hidden_states
def tractable_computation(input_ids: torch.Tensor, def tractable_computation(
positions: torch.Tensor, input_ids: torch.Tensor,
config: LlamaConfig, positions: torch.Tensor,
init_value: float = 1.0) -> torch.Tensor: config: LlamaConfig,
hidden_states = torch.ones(input_ids.size(0), init_value: float = 1.0,
config.hidden_size, ) -> torch.Tensor:
device=input_ids.device, hidden_states = (
dtype=input_ids.dtype) * init_value torch.ones(
input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype,
)
* init_value
)
# first layer # first layer
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2 hidden_states = (residual + 1) ** 2
# following layers # following layers
for _ in range(config.num_layers - 1): for _ in range(config.num_layers - 1):
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2 hidden_states = (residual + 1) ** 2
return hidden_states return hidden_states
@torch.inference_mode @torch.inference_mode
def run_model(llama_config, def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
use_compile: bool, # Start with a fresh copy to make sure there's no cache dir sharing
use_inductor: bool, compile_config = deepcopy(compile_config)
split_attn: bool = False) -> torch.Tensor: cudagraph_runtime_mode = compile_config.cudagraph_mode
if use_compile: vllm_config = VllmConfig(
compilation_config = CompilationConfig( compilation_config=compile_config, additional_config=llama_config
level=CompilationLevel.PIECEWISE, )
use_cudagraph=True,
use_inductor=use_inductor,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config, model = (
vllm_config=vllm_config, LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
prefix="").eval().cuda() .eval()
.cuda()
)
with set_forward_context({}, with set_forward_context({}, vllm_config=vllm_config): # background context
vllm_config=vllm_config): # background context
B = 16 # max batch size B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda() positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE # warmup for the model with cudagraph_mode NONE
model(input_ids, positions) model(input_ids, positions)
# simulate cudagraphs capturing # simulate cudagraphs capturing
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(input_ids[:2], positions[:2]) model(input_ids[:2], positions[:2])
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(input_ids[:1], positions[:1]) model(input_ids[:1], positions[:1])
input_ids[:2].zero_() input_ids[:2].zero_()
# simulate cudagraphs replay # simulate cudagraphs replay
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(input_ids[:2], positions[:2]) output = model(input_ids[:2], positions[:2])
output = output.cpu() output = output.cpu()
if llama_config.tractable_init: if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2], expected_output = tractable_computation(
positions[:2], input_ids[:2], positions[:2], llama_config
llama_config).cpu() ).cpu()
assert torch.allclose(output, expected_output) assert torch.allclose(output, expected_output)
else: else:
return output.cpu() return output.cpu()
@pytest.mark.parametrize("use_inductor", [True, False]) @pytest.mark.parametrize(
def test_toy_llama(use_inductor: bool): "backend, use_inductor_graph_partition",
[
("eager", False), # No inductor
("inductor", False), # Inductor, Dynamo partition
("inductor", True), # Inductor, Inductor partition
],
)
@create_new_process_for_each_test("spawn")
def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
):
# We disable the vLLM compile cache into a new tmp dir for 1 reason:
# 1. To make sure we can properly track the number of Inductor compilations.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition only supported in torch>=2.9")
# compare output with and without piecewise compilation # compare output with and without piecewise compilation
llama_config = LlamaConfig(hidden_size=128, llama_config = LlamaConfig(
mlp_size=256, hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12
vocab_size=128, )
num_layers=12)
tractable_config = LlamaConfig(
hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
)
compile_config_no_compile = CompilationConfig(
mode=CompilationMode.NONE,
cudagraph_mode=CUDAGraphMode.NONE,
backend="eager",
)
compile_config_no_split = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
backend=backend,
cudagraph_capture_sizes=[1, 2],
)
tractable_config = LlamaConfig(hidden_size=128, compile_config_split = deepcopy(compile_config_no_split)
mlp_size=256, compile_config_split.splitting_ops = ["silly::attention"]
vocab_size=128,
num_layers=2,
tractable_init=True)
outputs = [] outputs = []
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=0, num_graphs_seen=0,
num_piecewise_graphs_seen=0, num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0, num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0, num_backend_compilations=0,
num_cudagraph_captured=0, num_cudagraph_captured=0,
): ):
outputs.append( outputs.append(run_model(llama_config, compile_config_no_compile))
run_model(llama_config, use_inductor=False, use_compile=False))
run_model(tractable_config, use_inductor=False, use_compile=False)
if use_inductor: run_model(tractable_config, compile_config_no_compile)
if backend == "inductor":
kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0} kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
else: else:
kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1, num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1, num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured= num_cudagraph_captured=2,
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen **kwargs,
**kwargs,
): ):
outputs.append( outputs.append(run_model(llama_config, compile_config_no_split))
run_model(llama_config,
use_inductor=use_inductor, run_model(tractable_config, compile_config_no_split)
use_compile=True))
run_model(tractable_config, use_inductor=use_inductor, use_compile=True) if use_inductor_graph_partition:
num_piecewise_fx = 1
num_piecewise_capturable_fx = 1
else:
num_piecewise_fx = 2 * llama_config.num_layers + 1
num_piecewise_capturable_fx = 1 + llama_config.num_layers
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=2 * llama_config.num_layers + num_piecewise_graphs_seen=num_piecewise_fx,
1, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
num_piecewise_capturable_graphs_seen=1 + num_backend_compilations=num_piecewise_capturable_fx,
llama_config.num_layers, # 1 + num_layers # num_cudagraph_sizes * num_partitions
num_backend_compilations=1 + num_cudagraph_captured=2 * (1 + llama_config.num_layers),
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=2 *
(1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
outputs.append( outputs.append(run_model(llama_config, compile_config_split))
run_model(llama_config, run_model(tractable_config, compile_config_split)
use_inductor=use_inductor,
use_compile=True,
split_attn=True))
run_model(tractable_config,
use_inductor=use_inductor,
use_compile=True,
split_attn=True)
for i in range(1, len(outputs)): for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i]) assert torch.allclose(outputs[0], outputs[i])
...@@ -388,17 +430,15 @@ def benchmark(): ...@@ -388,17 +430,15 @@ def benchmark():
from triton.testing import do_bench from triton.testing import do_bench
# similar to llama 3.1-8B # similar to llama 3.1-8B
llama_config = LlamaConfig(hidden_size=4096, llama_config = LlamaConfig(
mlp_size=14336, hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32
vocab_size=128 * 1024, )
num_layers=32)
# a tiny model to measure the overhead # a tiny model to measure the overhead
# of piecewise cudagraph # of piecewise cudagraph
llama_config = LlamaConfig(hidden_size=40, llama_config = LlamaConfig(
mlp_size=80, hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2
vocab_size=128, )
num_layers=2)
cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)] cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)]
...@@ -411,25 +451,27 @@ def benchmark(): ...@@ -411,25 +451,27 @@ def benchmark():
for piecewise in [False, True]: for piecewise in [False, True]:
if piecewise: if piecewise:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True, splitting_ops=["silly::attention"],
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=cudagraph_sizes, cudagraph_capture_sizes=cudagraph_sizes,
) )
else: else:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, mode=CompilationMode.VLLM_COMPILE,
cudagraph_capture_sizes=cudagraph_sizes, cudagraph_capture_sizes=cudagraph_sizes,
) )
vllm_config = VllmConfig(compilation_config=compilation_config) vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config, model = (
vllm_config=vllm_config, LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="")
prefix="").eval().cuda().to(torch.bfloat16) .eval()
.cuda()
.to(torch.bfloat16)
)
B = 256 # max batch size B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda()
positions = torch.arange(B).cuda().to(torch.bfloat16) positions = torch.arange(B).cuda().to(torch.bfloat16)
graphs = {} graphs = {}
...@@ -451,22 +493,31 @@ def benchmark(): ...@@ -451,22 +493,31 @@ def benchmark():
# and use it later, because it will look up the name `b` in the # and use it later, because it will look up the name `b` in the
# enclosing scope, and the value of `b` will always be 256. # enclosing scope, and the value of `b` will always be 256.
# it is fine here, because we only use the lambda function once. # it is fine here, because we only use the lambda function once.
runtime = do_bench(lambda: graphs[b][0] # noqa runtime = do_bench(
(input_ids[:b], positions[:b])) # noqa lambda: graphs[b][0]( # noqa
input_ids[:b], # noqa
positions[:b], # noqa
)
)
piecewise_cudagraph_time[b] = runtime piecewise_cudagraph_time[b] = runtime
else: else:
runtime = do_bench(lambda: graphs[b][0].replay()) # noqa runtime = do_bench(lambda: graphs[b][0].replay()) # noqa
eager_runtime = do_bench( eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa
lambda: model(input_ids[:b], positions[:b])) # noqa
full_cudagraph_time[b] = runtime full_cudagraph_time[b] = runtime
eager_time[b] = eager_runtime eager_time[b] = eager_runtime
# print in tabular format # print in tabular format
print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph") print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
for b in cudagraph_sizes: for b in cudagraph_sizes:
print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" print(
f"\t{piecewise_cudagraph_time[b]:.3f}") f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
f"\t{piecewise_cudagraph_time[b]:.3f}"
)
if __name__ == "__main__": if __name__ == "__main__":
benchmark() # Protect against subprocess reimport when using spawn_new_process_for_each_test
import os
if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
benchmark()
...@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations. ...@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
import torch import torch
from torch.library import Library from torch.library import Library
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
# Shared library for all compilation test operations # Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations # Using "silly" namespace to match existing test expectations
...@@ -31,8 +31,9 @@ def reset_global_counter(): ...@@ -31,8 +31,9 @@ def reset_global_counter():
_global_counter = 0 _global_counter = 0
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def silly_attention(
out: torch.Tensor) -> None: q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
""" """
Unified attention implementation that depends on Unified attention implementation that depends on
all inputs and affects the output. all inputs and affects the output.
...@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, ...@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out.copy_(q + k + v) out.copy_(q + k + v)
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def silly_attention_fake(
out: torch.Tensor) -> None: q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
"""Fake implementation for testing""" """Fake implementation for testing"""
return return
...@@ -60,5 +62,4 @@ direct_register_custom_op( ...@@ -60,5 +62,4 @@ direct_register_custom_op(
mutates_args=["out"], mutates_args=["out"],
fake_impl=silly_attention_fake, fake_impl=silly_attention_fake,
target_lib=silly_lib, target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
) )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from contextlib import contextmanager
import pytest
import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42
assert x.shape[0] % 2 == 0
for _ in range(3000):
x = x + x.shape[0]
return x
@support_torch_compile
class CompiledMod(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
return reference_fn(x)
def make_vllm_config() -> VllmConfig:
return VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
)
)
@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
yield
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
vllm_config = make_vllm_config()
args = (torch.randn(10, 10),)
expected = reference_fn(*args)
with use_vllm_config(vllm_config):
m.setenv("VLLM_USE_AOT_COMPILE", "0")
with (
pytest.raises(RuntimeError, match="Detected recompile"),
torch.compiler.set_stance("fail_on_recompile"),
):
CompiledMod(vllm_config=vllm_config)(*args)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
actual = CompiledMod(vllm_config=vllm_config)(*args)
assert torch.allclose(actual, expected)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
args = (torch.randn(10, 10),)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
CompiledMod(vllm_config=vllm_config)(*args)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)
with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args)
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
ret = CompiledMod(vllm_config=vllm_config)(*args)
assert torch.allclose(ret, expected)
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
"""
Test that the shape environment is correctly serialized and preserved
when loading from cache.
"""
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)
with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args)
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args)
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
...@@ -8,18 +8,31 @@ import torch ...@@ -8,18 +8,31 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, from vllm.config import (
PassConfig, VllmConfig) CompilationConfig,
from vllm.distributed import (tensor_model_parallel_all_gather, CompilationMode,
tensor_model_parallel_reduce_scatter) DeviceConfig,
from vllm.distributed.parallel_state import (init_distributed_environment, ModelConfig,
initialize_model_parallel) PassConfig,
VllmConfig,
)
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test, from ..utils import (
multi_gpu_test) compare_two_settings,
create_new_process_for_each_test,
multi_gpu_test,
)
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -33,21 +46,20 @@ prompts = [ ...@@ -33,21 +46,20 @@ prompts = [
class TestMMRSModel(torch.nn.Module): class TestMMRSModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16): def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.gate_proj = torch.nn.Parameter(torch.empty( self.gate_proj = torch.nn.Parameter(
(self.hidden_size * 2, hidden_size)), torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False
requires_grad=False) )
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states): def forward(self, hidden_states):
""" """
Forward pass implementing the mm + reduce scatter in the FX graph Forward pass implementing the mm + reduce scatter in the FX graph
""" """
# Reshape input # Reshape input
view = hidden_states.reshape(-1, self.hidden_size) view = hidden_states.reshape(-1, self.hidden_size)
...@@ -66,14 +78,13 @@ class TestMMRSModel(torch.nn.Module): ...@@ -66,14 +78,13 @@ class TestMMRSModel(torch.nn.Module):
class TestAGMMModel(torch.nn.Module): class TestAGMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16): def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.weight = torch.nn.Parameter(torch.empty( self.weight = torch.nn.Parameter(
(hidden_size, hidden_size)), torch.empty((hidden_size, hidden_size)), requires_grad=False
requires_grad=False) )
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.weight, std=0.02) torch.nn.init.normal_(self.weight, std=0.02)
...@@ -96,32 +107,35 @@ class TestAGMMModel(torch.nn.Module): ...@@ -96,32 +107,35 @@ class TestAGMMModel(torch.nn.Module):
class _BaseScaledMMModel(torch.nn.Module): class _BaseScaledMMModel(torch.nn.Module):
def __init__(self, hidden_size=16, dtype=torch.float16): def __init__(self, hidden_size=16, dtype=torch.float16):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.dtype = dtype self.dtype = dtype
self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\ self.weight = (
.contiguous().transpose(0, 1) torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)
.contiguous()
.transpose(0, 1)
)
# Initialize scale_b for _scaled_mm. # Initialize scale_b for _scaled_mm.
self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32) self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
class TestScaledMMRSModel(_BaseScaledMMModel): class TestScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the scaled_mm + reduce scatter in the FX graph Forward pass implementing the scaled_mm + reduce scatter in the FX graph
""" """
fp8_input = input.to(FP8_DTYPE) fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(fp8_input, scaled_mm = torch._scaled_mm(
self.weight, fp8_input,
scale_a=scale_a, self.weight,
scale_b=self.scale_b, scale_a=scale_a,
out_dtype=self.dtype) scale_b=self.scale_b,
out_dtype=self.dtype,
)
reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0) reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
return reduce_scatter return reduce_scatter
...@@ -129,11 +143,10 @@ class TestScaledMMRSModel(_BaseScaledMMModel): ...@@ -129,11 +143,10 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default] return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self): def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGScaledMMModel(_BaseScaledMMModel): class TestAGScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the all gather + scaled_mm in the FX graph Forward pass implementing the all gather + scaled_mm in the FX graph
...@@ -143,11 +156,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel): ...@@ -143,11 +156,13 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0) all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
scaled_mm = torch._scaled_mm(all_gather, scaled_mm = torch._scaled_mm(
self.weight, all_gather,
scale_a=scale_a, self.weight,
scale_b=self.scale_b, scale_a=scale_a,
out_dtype=self.dtype) scale_b=self.scale_b,
out_dtype=self.dtype,
)
return scaled_mm return scaled_mm
def ops_in_model_before(self): def ops_in_model_before(self):
...@@ -158,20 +173,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel): ...@@ -158,20 +173,22 @@ class TestAGScaledMMModel(_BaseScaledMMModel):
class TestCutlassScaledMMRSModel(_BaseScaledMMModel): class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the cutlass_scaled_mm + reduce scatter Forward pass implementing the cutlass_scaled_mm + reduce scatter
in the FX graph in the FX graph
""" """
fp8_input = input.to(FP8_DTYPE) fp8_input = input.to(FP8_DTYPE)
scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]), mm_out = torch.empty(
dtype=self.dtype, (fp8_input.shape[0], self.weight.shape[1]),
device=input.device) dtype=self.dtype,
torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a, device=input.device,
self.scale_b, None) )
torch.ops._C.cutlass_scaled_mm(
mm_out, fp8_input, self.weight, scale_a, self.scale_b, None
)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0) reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
return reduce_scatter return reduce_scatter
...@@ -179,14 +196,13 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel): ...@@ -179,14 +196,13 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default] return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self): def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel): class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
""" """
Forward pass implementing the all gather + cutlass_scaled_mm Forward pass implementing the all gather + cutlass_scaled_mm
in the FX graph in the FX graph
""" """
# Reshape input # Reshape input
...@@ -195,11 +211,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel): ...@@ -195,11 +211,14 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]), mm_out = torch.empty(
dtype=self.dtype, (all_gather.shape[0], self.weight.shape[1]),
device=all_gather.device) dtype=self.dtype,
torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight, device=all_gather.device,
scale_a, self.scale_b, None) )
torch.ops._C.cutlass_scaled_mm(
mm_out, all_gather, self.weight, scale_a, self.scale_b, None
)
return mm_out return mm_out
def ops_in_model_before(self): def ops_in_model_before(self):
...@@ -210,23 +229,43 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel): ...@@ -210,23 +229,43 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [ @pytest.mark.parametrize(
TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel, "test_model",
TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel [
]) TestMMRSModel,
TestAGMMModel,
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
],
)
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.parametrize("dynamic", [True, False])
reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, def test_async_tp_pass_replace(
hidden_size: int, dtype: torch.dtype): test_model: str,
if test_model in (TestScaledMMRSModel, TestAGScaledMMModel, batch_size: int,
TestCutlassScaledMMRSModel, seq_len: int,
TestAGCutlassScaledMMModel) and dtype == torch.float16: hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
if (
test_model
in (
TestScaledMMRSModel,
TestAGScaledMMModel,
TestCutlassScaledMMRSModel,
TestAGCutlassScaledMMModel,
)
and dtype == torch.float16
):
pytest.skip( pytest.skip(
"Only bf16 high precision output types are supported for " \ "Only bf16 high precision output types are supported for "
"per-token (row-wise) scaling" "per-token (row-wise) scaling"
) )
...@@ -235,19 +274,33 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, ...@@ -235,19 +274,33 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with # need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda # torch.distributed and cuda
torch.multiprocessing.spawn(fn, torch.multiprocessing.spawn(
args=(num_processes, test_model, fn,
batch_size, seq_len, hidden_size, args=(
dtype), num_processes,
nprocs=nprocs) test_model,
batch_size,
seq_len,
hidden_size,
dtype,
dynamic,
),
nprocs=nprocs,
)
run_torch_spawn(async_tp_pass_on_test_model, num_processes) run_torch_spawn(async_tp_pass_on_test_model, num_processes)
def async_tp_pass_on_test_model(local_rank: int, world_size: int, def async_tp_pass_on_test_model(
test_model_cls: torch.nn.Module, local_rank: int,
batch_size: int, seq_len: int, world_size: int,
hidden_size: int, dtype: torch.dtype): test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
current_platform.seed_everything(0) current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
...@@ -255,13 +308,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, ...@@ -255,13 +308,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
'RANK': str(local_rank), {
'LOCAL_RANK': str(local_rank), "RANK": str(local_rank),
'WORLD_SIZE': str(world_size), "LOCAL_RANK": str(local_rank),
'MASTER_ADDR': 'localhost', "WORLD_SIZE": str(world_size),
'MASTER_PORT': '12345', "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
# initialize distributed # initialize distributed
init_distributed_environment() init_distributed_environment()
...@@ -269,27 +324,40 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, ...@@ -269,27 +324,40 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( vllm_config.compilation_config = CompilationConfig(
enable_async_tp=True, ), ) pass_config=PassConfig(
enable_async_tp=True,
),
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(model=model_name, vllm_config.model_config = ModelConfig(
trust_remote_code=True, model=model_name, trust_remote_code=True, dtype=dtype, seed=42
dtype=dtype, )
seed=42)
async_tp_pass = AsyncTPPass(vllm_config) async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass) backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size, assert (
dtype) # Pass dtype to model constructor async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
hidden_states = torch.randn((batch_size * seq_len, hidden_size), hidden_states = torch.randn(
dtype=dtype, (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
requires_grad=False) )
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
compiled_model = torch.compile(model, backend=backend) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states) compiled_model(hidden_states)
...@@ -306,10 +374,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, ...@@ -306,10 +374,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", [ @pytest.mark.parametrize(
"meta-llama/Llama-3.2-1B-Instruct", "model_id",
"RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
]) )
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True]) @pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"]) @pytest.mark.parametrize("distributed_backend", ["mp"])
...@@ -342,16 +410,10 @@ def test_async_tp_pass_correctness( ...@@ -342,16 +410,10 @@ def test_async_tp_pass_correctness(
common_args.append("--enforce-eager") common_args.append("--enforce-eager")
compilation_config = { compilation_config = {
'level': 3, "mode": CompilationMode.VLLM_COMPILE,
'compile_sizes': [2, 4, 8], "compile_sizes": [2, 4, 8],
'splitting_ops': [], "splitting_ops": [],
'pass_config': { "pass_config": {"enable_async_tp": async_tp_enabled},
'enable_async_tp': async_tp_enabled
},
}
async_tp_env = tp_env = {
"VLLM_USE_V1": "1",
} }
async_tp_args = [ async_tp_args = [
...@@ -372,9 +434,4 @@ def test_async_tp_pass_correctness( ...@@ -372,9 +434,4 @@ def test_async_tp_pass_correctness(
"mp", "mp",
] ]
compare_two_settings(model_id, compare_two_settings(model_id, async_tp_args, tp_args, method="generate")
async_tp_args,
tp_args,
async_tp_env,
tp_env,
method="generate")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import dataclasses import dataclasses
import pytest import pytest
from vllm.config import CompilationLevel from vllm.config import CompilationMode
from vllm.utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import compare_all_settings from ..utils import compare_all_settings
...@@ -23,7 +21,7 @@ class TestSetting: ...@@ -23,7 +21,7 @@ class TestSetting:
# we cannot afford testing the full Cartesian product # we cannot afford testing the full Cartesian product
# of all models and all levels # of all models and all modes
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_setting", "test_setting",
[ [
...@@ -79,14 +77,15 @@ class TestSetting: ...@@ -79,14 +77,15 @@ class TestSetting:
method="encode", method="encode",
), ),
# vision language model # vision language model
TestSetting( # See https://github.com/vllm-project/vllm/issues/26716.
model="microsoft/Phi-3.5-vision-instruct", # TestSetting(
model_args=["--trust-remote-code", "--max-model-len", "2048"], # model="microsoft/Phi-3.5-vision-instruct",
pp_size=2, # model_args=["--trust-remote-code", "--max-model-len", "2048"],
tp_size=1, # pp_size=2,
attn_backend="FLASH_ATTN", # tp_size=1,
method="generate_with_image", # attn_backend="FLASH_ATTN",
), # method="generate_with_image",
# ),
], ],
) )
def test_compile_correctness( def test_compile_correctness(
...@@ -103,43 +102,54 @@ def test_compile_correctness( ...@@ -103,43 +102,54 @@ def test_compile_correctness(
attn_backend = test_setting.attn_backend attn_backend = test_setting.attn_backend
method = test_setting.method method = test_setting.method
if cuda_device_count_stateless() < pp_size * tp_size: if cuda_device_count_stateless() < pp_size * tp_size:
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got " pytest.skip(
f"{cuda_device_count_stateless()}") f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}"
)
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
final_args = [ final_args = [
"--enforce-eager", *model_args, "-pp", *model_args,
str(pp_size), "-tp", "-pp",
str(tp_size) str(pp_size),
"-tp",
str(tp_size),
"-O.cudagraph_mode=none",
] ]
all_args: list[list[str]] = [] all_args: list[list[str]] = []
all_envs: list[dict[str, str] | None] = [] all_envs: list[dict[str, str] | None] = []
for level in [ for comp_mode in [
CompilationLevel.NO_COMPILATION, CompilationMode.STOCK_TORCH_COMPILE,
CompilationLevel.PIECEWISE, CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]: ]:
all_args.append(final_args + [f"-O{level}"]) for mode in [CompilationMode.NONE, comp_mode]:
all_envs.append({}) all_args.append(
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
)
# inductor will change the output, so we only compare if the output # inductor will change the output, so we only compare if the output
# is close, not exactly the same. # is close, not exactly the same.
compare_all_settings( compare_all_settings(
model, model,
all_args, all_args,
all_envs, all_envs,
method=method if method != "generate" else "generate_close") method=method if method != "generate" else "generate_close",
all_envs.clear() )
all_args.clear() all_envs.clear()
all_args.clear()
for level in [ for mode in [
CompilationLevel.NO_COMPILATION, CompilationMode.NONE,
CompilationLevel.DYNAMO_AS_IS, CompilationMode.STOCK_TORCH_COMPILE,
CompilationLevel.DYNAMO_ONCE, CompilationMode.DYNAMO_TRACE_ONCE,
CompilationMode.VLLM_COMPILE,
]: ]:
all_args.append(final_args + [f"-O{level}"]) all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
all_envs.append({})
all_envs.append({}) all_envs.append({})
compare_all_settings(model, all_args * 3, all_envs, method=method) compare_all_settings(model, all_args * 3, all_envs, method=method)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch
import pytest import pytest
from pydantic import ValidationError
import vllm
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, VllmConfig from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.utils import _is_torch_equal_or_newer from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
def test_version():
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
def test_version():
# Test the version comparison logic using the private function
assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_use_cudagraphs_dynamic(monkeypatch):
assert vllm.envs.VLLM_USE_V1
vllm_config = VllmConfig()
assert vllm_config.compilation_config.use_cudagraph
monkeypatch.setenv('VLLM_USE_V1', '0') def test_copy_pass():
vllm_config = VllmConfig() vllm_config = VllmConfig()
assert not vllm_config.compilation_config.use_cudagraph inductor_pass = FixFunctionalizationPass(vllm_config)
copied_inductor_pass = copy.deepcopy(inductor_pass)
assert (
copied_inductor_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
assert (
copied_inductor_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
def test_custom_op(): def test_custom_op():
...@@ -41,63 +57,80 @@ def test_custom_op(): ...@@ -41,63 +57,80 @@ def test_custom_op():
# may be influenced by other tests. # may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"]) @pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
assert vllm.envs.VLLM_USE_V1
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
compilation_config = { compilation_config = {
"use_cudagraph": False, # speed things up a bit "cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
} }
with ( with (
compilation_counter.expect(num_cache_entries_updated=0, compilation_counter.expect(
num_compiled_artifacts_saved=0), num_cache_entries_updated=0, num_compiled_artifacts_saved=0
# loading the model causes compilation (if enabled) to happen ),
vllm_runner('facebook/opt-125m', # loading the model causes compilation (if enabled) to happen
compilation_config=compilation_config, vllm_runner(
gpu_memory_utilization=0.4) as _): "facebook/opt-125m",
compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False]) @pytest.mark.parametrize(
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): "cudagraph_mode,num_cudagraph_captured",
assert vllm.envs.VLLM_USE_V1 [
(CUDAGraphMode.NONE, 0),
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
(CUDAGraphMode.PIECEWISE, 13),
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
],
)
def test_use_cudagraphs(
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = { compilation_config = {
"cudagraph_capture_sizes": [100], "cudagraph_capture_sizes": [100],
"use_cudagraph": enabled, "cudagraph_mode": cudagraph_mode,
} }
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
with ( with (
compilation_counter.expect( compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0, num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
num_cudagraph_captured=13 if enabled else 0, num_cudagraph_captured=num_cudagraph_captured,
), ),
# loading the model causes compilation (if enabled) to happen # loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m', vllm_runner(
compilation_config=compilation_config, "facebook/opt-125m",
gpu_memory_utilization=0.4) as _): compilation_config=compilation_config,
gpu_memory_utilization=0.4,
) as _,
):
pass pass
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked @pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch): def test_stock_torch_compile(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with ( with (
compilation_counter.expect(dynamo_as_is_count=1), compilation_counter.expect(stock_torch_compile_count=1),
# loading the model causes compilation (if enabled) to happen # loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m', vllm_runner(
compilation_config={"level": 1}, "facebook/opt-125m",
gpu_memory_utilization=0.4) as _): compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
gpu_memory_utilization=0.4,
) as _,
):
pass pass
...@@ -105,15 +138,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch): ...@@ -105,15 +138,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
@pytest.mark.forked @pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch): def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with ( with (
compilation_counter.expect(num_graphs_seen=0, compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
dynamo_as_is_count=0), # loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen vllm_runner(
vllm_runner('facebook/opt-125m', "facebook/opt-125m",
compilation_config={"level": 0}, compilation_config={"mode": CompilationMode.NONE},
gpu_memory_utilization=0.4) as _): gpu_memory_utilization=0.4,
) as _,
):
pass pass
...@@ -121,13 +155,223 @@ def test_no_compilation(vllm_runner, monkeypatch): ...@@ -121,13 +155,223 @@ def test_no_compilation(vllm_runner, monkeypatch):
@pytest.mark.forked @pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch): def test_enforce_eager(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process # Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
with ( with (
compilation_counter.expect(num_graphs_seen=0, compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
dynamo_as_is_count=0), # loading the model causes compilation (if enabled) to happen
# loading the model causes compilation (if enabled) to happen vllm_runner(
vllm_runner('facebook/opt-125m', "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
enforce_eager=True, ) as _,
gpu_memory_utilization=0.4) as _): ):
pass pass
def test_splitting_ops_dynamic():
# Default config
config = VllmConfig()
# Default V1 config leaves cudagraph mode unset; splitting ops are only
# populated when the engine decides to use piecewise compilation.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert not config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
# When attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
assert config.compilation_config.splitting_ops == []
# cudagraph mode also fall back to FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
# splitting_ops can not contain attention ops when attn_fusion
# pass enabled.
with pytest.raises(ValidationError):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# work around for accessing all attntion ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# When both use_inductor_graph_partition and attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_should_split():
import torch
from vllm.compilation.partition_rules import should_split
graph = torch.fx.Graph()
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.aten.add.default,
args=(),
kwargs={},
)
# supports OpOverloadPacket
splitting_ops = ["aten::add"]
assert should_split(node, splitting_ops)
# supports OpOverload
splitting_ops = ["aten::add.default"]
assert should_split(node, splitting_ops)
# supports OpOverload
splitting_ops = ["aten::add.Tensor"]
assert not should_split(node, splitting_ops)
q, k, v, out = [torch.randn(1)] * 4
# supports custom ops as OpOverloadPacket
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.silly.attention,
args=(q, k, v, out),
kwargs={},
)
splitting_ops = ["silly::attention"]
assert should_split(node, splitting_ops)
# supports custom ops as OpOverload
node = torch.fx.Node(
graph=graph,
name="dummy_node",
op="call_function",
target=torch.ops.silly.attention.default,
args=(q, k, v, out),
kwargs={},
)
splitting_ops = ["silly::attention"]
assert should_split(node, splitting_ops)
splitting_ops = ["silly::attention.default"]
assert should_split(node, splitting_ops)
@pytest.mark.skipif(
not current_platform.support_static_graph_mode(),
reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
(
"cudagraph_capture_sizes",
"max_cudagraph_capture_size",
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"cudagraph_mode",
"expected_max_size",
),
[
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
(
[1, 2, 4],
8,
1,
False,
2048,
CUDAGraphMode.FULL_AND_PIECEWISE,
ValidationError,
),
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# the list should contain at least 1 element when use cudagraph
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
# the max capturing size should be >= 1 when use cudagraph
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
],
)
def test_cudagraph_sizes_post_init(
cudagraph_capture_sizes,
max_cudagraph_capture_size,
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
cudagraph_mode,
expected_max_size,
):
ctx = nullcontext()
if expected_max_size == ValidationError:
ctx = pytest.raises(expected_max_size)
with (
ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
):
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
pass_config={
"enable_sequence_parallelism": enable_sequence_parallelism,
"enable_fusion": True,
"enable_noop": True,
},
cudagraph_mode=cudagraph_mode,
)
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_seqs=min(max_num_batched_tokens, 128),
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
assert (
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch import torch
from torch import nn from torch import nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile, from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile
support_torch_compile) from vllm.config import (
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CacheConfig,
CUDAGraphMode, VllmConfig, set_current_vllm_config) CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention` # This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401 from . import silly_attention # noqa: F401
...@@ -18,56 +25,86 @@ MLP_SIZE = 128 ...@@ -18,56 +25,86 @@ MLP_SIZE = 128
@torch.inference_mode @torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module, def run_model(
cudagraph_runtime_mode: CUDAGraphMode): vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode
):
with set_forward_context({}, vllm_config=vllm_config): with set_forward_context({}, vllm_config=vllm_config):
# warmup for the model with cudagraph_mode NONE # warmup for the model with cudagraph_mode NONE
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
# simulate cudagraphs capturing # simulate cudagraphs capturing
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
model(torch.randn(2, MLP_SIZE).cuda()) model(torch.randn(2, MLP_SIZE).cuda())
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=1, )): batch_descriptor=BatchDescriptor(
num_tokens=1,
),
):
model(torch.randn(1, MLP_SIZE).cuda()) model(torch.randn(1, MLP_SIZE).cuda())
# simulate cudagraphs replay # simulate cudagraphs replay
with set_forward_context({}, with set_forward_context(
vllm_config=vllm_config, {},
cudagraph_runtime_mode=cudagraph_runtime_mode, vllm_config=vllm_config,
batch_descriptor=BatchDescriptor( cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens=2, )): batch_descriptor=BatchDescriptor(
num_tokens=2,
),
):
output = model(torch.randn(2, MLP_SIZE).cuda()) output = model(torch.randn(2, MLP_SIZE).cuda())
output = output.cpu() output = output.cpu()
return output.cpu() return output.cpu()
def test_ignore_torch_compile_decorator(): @pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
# disable compile cache so that we can count the number of compilations
# appropriately
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# piecewise # piecewise
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(
level=CompilationLevel.PIECEWISE, compilation_config=CompilationConfig(
use_cudagraph=True, mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly.attention"], splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2], cudagraph_capture_sizes=[1, 2],
)) use_inductor_graph_partition=use_inductor_graph_partition,
)
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
expected_num_graphs_seen = 1
expected_num_cudagraph_captured = (
4 # num_cudagraph_sizes * num cudagraphs to capture
)
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
expected_num_piecewise_graphs_seen = 3
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2
@support_torch_compile @support_torch_compile
class A(nn.Module): class A(nn.Module):
def __init__(
def __init__(self, self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs
*, ) -> None:
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -79,66 +116,58 @@ def test_ignore_torch_compile_decorator(): ...@@ -79,66 +116,58 @@ def test_ignore_torch_compile_decorator():
return x return x
@ignore_torch_compile @ignore_torch_compile
class B(A): class B(A): ...
...
@support_torch_compile @support_torch_compile
class C(B): class C(B): ...
...
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
# A has support_torch_compile # A has support_torch_compile
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=3, num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=2, num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=2, num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=4, num_cudagraph_captured=expected_num_cudagraph_captured,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
run_model(vllm_config, mod_A, cudagraph_runtime_mode) run_model(vllm_config, mod_A, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda()
# B's ignore_torch_compile should override A's support_torch_compile # B's ignore_torch_compile should override A's support_torch_compile
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=0, num_graphs_seen=0,
num_piecewise_graphs_seen=0, num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0, num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0, num_backend_compilations=0,
num_cudagraph_captured=0, num_cudagraph_captured=0,
): ):
run_model(vllm_config, mod_B, cudagraph_runtime_mode) run_model(vllm_config, mod_B, cudagraph_runtime_mode)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda()
# C's support_torch_compile should override B's ignore_torch_compile # C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=expected_num_graphs_seen,
num_piecewise_graphs_seen=3, num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=2, num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=2, num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=4, num_cudagraph_captured=expected_num_cudagraph_captured,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
): ):
run_model(vllm_config, mod_C, cudagraph_runtime_mode) run_model(vllm_config, mod_C, cudagraph_runtime_mode)
# Only enable torch.compile if # Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=True # vllm_config.cache_config.kv_sharing_fast_prefill=True
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. @support_torch_compile(
kv_sharing_fast_prefill) enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class B(nn.Module): class B(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -152,15 +181,11 @@ class B(nn.Module): ...@@ -152,15 +181,11 @@ class B(nn.Module):
# Only enable torch.compile if # Only enable torch.compile if
# vllm_config.cache_config.kv_sharing_fast_prefill=False # vllm_config.cache_config.kv_sharing_fast_prefill=False
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. @support_torch_compile(
cache_config.kv_sharing_fast_prefill) enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class A(nn.Module): class A(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__() super().__init__()
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
...@@ -174,55 +199,88 @@ class A(nn.Module): ...@@ -174,55 +199,88 @@ class A(nn.Module):
return x return x
def test_conditional_compile_enable_if(): @pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
vllm_config = VllmConfig(cache_config=CacheConfig( def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
kv_sharing_fast_prefill=True, ), # disable compile cache so that we can count the number of compilations
compilation_config=CompilationConfig( # appropriately
level=CompilationLevel.PIECEWISE, monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
use_cudagraph=True,
splitting_ops=["silly.attention"], if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
cudagraph_capture_sizes=[1, 2], pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
))
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_sharing_fast_prefill=True,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
),
)
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 2
expected_num_piecewise_capturable_graphs_seen = 2
expected_num_backend_compilations = 2
else:
expected_num_piecewise_graphs_seen = 6
expected_num_piecewise_capturable_graphs_seen = 4
expected_num_backend_compilations = 4
# A has support_torch_compile but enable_if fn returns False # A has support_torch_compile but enable_if fn returns False
# enalbe_if will be True for B, so we expect mod1 and mod2 # enalbe_if will be True for B, so we expect mod1 and mod2
# to be compiled # to be compiled
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=2, num_graphs_seen=2,
num_piecewise_graphs_seen=6, num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
# 3 piecewise graphs per instance of B() # 3 piecewise graphs per instance of B()
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=4, num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=8, num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num cudagraphable graphs to capture
): ):
run_model(vllm_config, mod_A, cudagraph_runtime_mode) run_model(vllm_config, mod_A, cudagraph_runtime_mode)
# Set kv_sharing_fast_prefill=False # Set kv_sharing_fast_prefill=False
# which will cause A to be compiled and B to not be compiled # which will cause A to be compiled and B to not be compiled
vllm_config = VllmConfig(cache_config=CacheConfig( vllm_config = VllmConfig(
kv_sharing_fast_prefill=False, ), cache_config=CacheConfig(
compilation_config=CompilationConfig( kv_sharing_fast_prefill=False,
level=CompilationLevel.PIECEWISE, ),
use_cudagraph=True, compilation_config=CompilationConfig(
splitting_ops=["silly.attention"], mode=CompilationMode.VLLM_COMPILE,
cudagraph_capture_sizes=[1, 2], splitting_ops=["silly::attention"],
)) cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
),
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
if use_inductor_graph_partition:
expected_num_piecewise_graphs_seen = 1
expected_num_piecewise_capturable_graphs_seen = 1
expected_num_backend_compilations = 1
else:
# 3 attn ops and 4 non-attn ops
expected_num_piecewise_graphs_seen = 7
expected_num_piecewise_capturable_graphs_seen = 4
expected_num_backend_compilations = 4
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_piecewise_graphs_seen=7, num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
# 3 attn ops and 4 non-attn ops # 3 attn ops and 4 non-attn ops
num_piecewise_capturable_graphs_seen=4, num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=4, num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=8, num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen # num_cudagraph_sizes * num cudagraphable graphs to capture
): ):
run_model(vllm_config, mod_A, cudagraph_runtime_mode) run_model(vllm_config, mod_A, cudagraph_runtime_mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment