Unverified commit 74f441f4, authored by fhl2000, committed by GitHub
Browse files

[Core] Allow full cudagraph with separate attention routines and orthogonal to...


[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent a0632a3e
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import contextlib import contextlib
import os import os
import weakref import weakref
from contextlib import ExitStack from dataclasses import dataclass
from typing import Optional
import pytest import pytest
...@@ -32,69 +33,133 @@ def temporary_environ(env_vars): ...@@ -32,69 +33,133 @@ def temporary_environ(env_vars):
os.environ[k] = v os.environ[k] = v
@pytest.fixture(scope="class") @dataclass
def llm_pair(request): class BackendConfig:
model = request.param name: str
env_vars: dict
with temporary_environ({ comp_config: dict
"VLLM_USE_V1": "1", specific_gpu_arch: Optional[tuple] = None
"VLLM_FLASH_ATTN_VERSION": "3"
}):
full = LLM( # Define all backend configurations of full cudagraph to be tested
model=model, backend_configs = {
gpu_memory_utilization=0.45, # FA3 on Hopper
trust_remote_code=True, "FA3":
max_model_len=1024, BackendConfig(name="FA3",
compilation_config=CompilationConfig(full_cuda_graph=True), env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
) comp_config={
piecewise = LLM( "cudagraph_mode": "FULL",
model=model, },
gpu_memory_utilization=0.45, specific_gpu_arch=(9, 0)),
trust_remote_code=True, # FlashMLA on Hopper
max_model_len=1024, "FlashMLA":
compilation_config=CompilationConfig(), BackendConfig(name="FlashMLA",
) env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
# PyTest caches the fixture values so we use weakref.proxy to enable GC },
yield weakref.proxy(full), weakref.proxy(piecewise) comp_config={
del full "cudagraph_mode": "FULL_AND_PIECEWISE",
del piecewise },
specific_gpu_arch=(9, 0)),
wait_for_gpu_memory_to_clear( # Cutlass MLA on Blackwell
devices=[0], "CutlassMLA":
threshold_ratio=0.1, BackendConfig(
) name="CutlassMLA",
env_vars={
@pytest.fixture(scope="class")
def cutlass_mla_llm_pair(request):
model = request.param
# force V1 engine and Cutlass MLA backend
with temporary_environ({
"VLLM_USE_V1": "1", "VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS": "FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed "1", # TODO: remove this when hang issue is fixed
}): },
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
test_params_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
@pytest.fixture(scope="class")
def llm_pair(request):
model, backend_config = request.param
# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0):
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
env_vars = {
"VLLM_USE_V1": "1",
# Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0",
**backend_config.env_vars,
}
with temporary_environ(env_vars):
full = LLM( full = LLM(
model=model, model=model,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.43,
trust_remote_code=True, trust_remote_code=True,
max_model_len=1024, max_model_len=1024,
compilation_config=CompilationConfig( max_num_seqs=128,
full_cuda_graph=True, compilation_config=\
cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512], CompilationConfig(**backend_config.comp_config),
), generation_config="vllm",
seed=42,
) )
piecewise = LLM( piecewise = LLM(
model=model, model=model,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.43,
trust_remote_code=True, trust_remote_code=True,
max_model_len=1024, max_model_len=1024,
compilation_config=CompilationConfig(), max_num_seqs=128,
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
generation_config="vllm",
seed=42,
) )
# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield weakref.proxy(full), weakref.proxy(piecewise) yield weakref.proxy(full), weakref.proxy(piecewise)
del full del full
del piecewise del piecewise
...@@ -105,51 +170,7 @@ def cutlass_mla_llm_pair(request): ...@@ -105,51 +170,7 @@ def cutlass_mla_llm_pair(request):
) )
@pytest.mark.parametrize( @pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
"cutlass_mla_llm_pair",
[
# use an MLA model
"deepseek-ai/DeepSeek-V2-Lite",
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
reason="Only Blackwell GPUs support Cutlass MLA")
class TestFullCUDAGraphCutlassMLA:
"""
Validate full CUDA Graph with Cutlass MLA (decode-only capture).
"""
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(8, 8),
])
def test_full_cudagraph_sm100_cutlass_mla(
self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
LLM]):
piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
prompts = ["Hello, my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
@pytest.mark.parametrize(
"llm_pair",
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
class TestFullCUDAGraph: class TestFullCUDAGraph:
""" """
Use a class such that an llm pair is constructed once for all Use a class such that an llm pair is constructed once for all
...@@ -178,12 +199,14 @@ class TestFullCUDAGraph: ...@@ -178,12 +199,14 @@ class TestFullCUDAGraph:
full cudagraph compilation works for padded cases too. full cudagraph compilation works for padded cases too.
""" """
piecewise_llm, full_cudagraph_llm = llm_pair full_cudagraph_llm, piecewise_llm = llm_pair
prompts = ["Hello, my name is"] * batch_size prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=0.95) top_p=1.0)
piecewise_responses = piecewise_llm.generate(prompts, sampling_params) piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
...@@ -191,42 +214,16 @@ class TestFullCUDAGraph: ...@@ -191,42 +214,16 @@ class TestFullCUDAGraph:
# Check that all responses are the same # Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses, for piecewise_res, full_res in zip(piecewise_responses,
full_responses): full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text assert piecewise_res.outputs[0].text.lower() == \
full_res.outputs[0].text.lower()
@pytest.mark.parametrize(
"model, supported",
[
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(RuntimeError))
llm = LLM(model=model,
max_num_seqs=256,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[64, 256, 512]))
llm.generate(["Hello, my name is"] * 10)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend(): def test_full_cudagraph_with_invalid_backend():
with temporary_environ({ with temporary_environ({
"VLLM_USE_V1": "1", "VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
"2" #FA2 not supported with full_cuda_graph # Flex_Attention is not supported with full cuda graph
}), pytest.raises(RuntimeError): }), pytest.raises(RuntimeError):
LLM(model="Qwen/Qwen2-1.5B-Instruct", LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(full_cuda_graph=True)) compilation_config=CompilationConfig(cudagraph_mode="FULL"))
...@@ -11,10 +11,10 @@ from torch.library import Library ...@@ -11,10 +11,10 @@ from torch.library import Library
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
set_current_vllm_config) VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1 from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
global_counter = 0 global_counter = 0
...@@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor): ...@@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured= num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config): ), set_forward_context(None,
vllm_config=vllm_config): # background context
# warm up with background context
model(inputs) model(inputs)
model(torch.randn(2).cuda()) # capturing/replaying should under context of cudagraph dispatching
model(torch.randn(1).cuda()) with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
model(torch.randn(2).cuda())
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )):
model(torch.randn(1).cuda())
input = torch.zeros(2).cuda() input = torch.zeros(2).cuda()
global global_counter global global_counter
global_counter = 0 global_counter = 0
output = model(input) with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
output = model(input)
assert global_counter == 2 assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
...@@ -18,9 +18,9 @@ from torch.library import Library ...@@ -18,9 +18,9 @@ from torch.library import Library
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
set_current_vllm_config) VllmConfig, set_current_vllm_config)
from vllm.forward_context import set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
# create a library to hold the custom op # create a library to hold the custom op
...@@ -276,9 +276,11 @@ def run_model(llama_config, ...@@ -276,9 +276,11 @@ def run_model(llama_config,
) )
if split_attn: if split_attn:
compilation_config.splitting_ops = ["silly.attention"] compilation_config.splitting_ops = ["silly.attention"]
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
else: else:
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ) level=CompilationLevel.NO_COMPILATION, )
cudagraph_runtime_mode = CUDAGraphMode.NONE
vllm_config = VllmConfig(compilation_config=compilation_config, vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config) additional_config=llama_config)
...@@ -287,17 +289,37 @@ def run_model(llama_config, ...@@ -287,17 +289,37 @@ def run_model(llama_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="").eval().cuda() prefix="").eval().cuda()
with set_forward_context({}, vllm_config=vllm_config): with set_forward_context({},
vllm_config=vllm_config): # background context
B = 16 # max batch size B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda() positions = torch.arange(B).cuda()
# warmup for the model with cudagraph_mode NONE
model(input_ids, positions) model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1]) # simulate cudagraphs capturing
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
model(input_ids[:2], positions[:2])
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=1, )):
model(input_ids[:1], positions[:1])
input_ids[:2].zero_() input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2]) # simulate cudagraphs replay
with set_forward_context({},
vllm_config=vllm_config,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=BatchDescriptor(
num_tokens=2, )):
output = model(input_ids[:2], positions[:2])
output = output.cpu() output = output.cpu()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# Helper MLP for testing
class SimpleMLP(nn.Module):
    """Minimal test model: two stacked 10->10 linear layers, no activation."""

    def __init__(self):
        super().__init__()
        # Creation order is kept (fc1 first) so parameter initialisation
        # consumes the RNG in the same sequence.
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        """Return ``fc2(fc1(x))``."""
        hidden = self.fc1(x)
        return self.fc2(hidden)
def _create_vllm_config(compilation_config: CompilationConfig,
                        max_num_seqs: int = 8) -> MagicMock:
    """Build a ``MagicMock`` (spec'd on ``VllmConfig``) for the tests.

    The mock carries the given ``compilation_config``, a ``SchedulerConfig``
    built with ``max_num_seqs`` and a default ``ParallelConfig``.

    Args:
        compilation_config: Compilation settings attached to the mock. It is
            mutated in place via ``set_splitting_ops_for_v1()`` when its
            level is ``CompilationLevel.PIECEWISE``.
        max_num_seqs: Forwarded to ``SchedulerConfig``.

    Returns:
        The configured mock.
    """
    vllm_config = MagicMock(spec=VllmConfig)
    vllm_config.compilation_config = compilation_config
    vllm_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
    vllm_config.parallel_config = ParallelConfig()

    # Mimic the behavior of VllmConfig.__post_init__()
    is_piecewise = compilation_config.level == CompilationLevel.PIECEWISE
    if is_piecewise:
        compilation_config.set_splitting_ops_for_v1()

    return vllm_config
class TestCudagraphDispatcher:
    """Checks key registration and dispatch results of CudagraphDispatcher."""

    @pytest.mark.parametrize(
        "params",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            {
                "case_id": 0,
                "cudagraph_mode": "FULL",
                "compilation_level": CompilationLevel.NO_COMPILATION,
            },
            # Test case 1: Full CG for uniform batches, piecewise for mixed
            {
                "case_id": 1,
                "cudagraph_mode": "FULL_AND_PIECEWISE",
                "compilation_level": CompilationLevel.PIECEWISE,
            },
            # Test case 2: Full CG for uniform batches, no CG for mixed
            {
                "case_id": 2,
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "compilation_level": CompilationLevel.NO_COMPILATION,
            },
            # Test case 3: Piecewise for all
            {
                "case_id": 3,
                "cudagraph_mode": "PIECEWISE",
                "compilation_level": CompilationLevel.PIECEWISE,
            },
        ])
    def test_dispatcher(self, params):
        """Verify registered keys and dispatch outcomes for one mode.

        For the parametrized ``cudagraph_mode`` the test checks the number of
        keys registered per runtime mode, then dispatches a non-uniform
        batch, a uniform-decode batch and a batch whose size is not in the
        capture list, asserting the returned runtime mode and key each time.
        """
        mode_name = params["cudagraph_mode"]
        registers_piecewise = mode_name in ("FULL_AND_PIECEWISE", "PIECEWISE")
        registers_full = mode_name not in ("NONE", "PIECEWISE")

        # Setup dispatcher
        compilation_config = CompilationConfig(
            cudagraph_mode=mode_name,
            level=params["compilation_level"],
            cudagraph_capture_sizes=[1, 8])
        vllm_config = _create_vllm_config(compilation_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(vllm_config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=compilation_config.cudagraph_mode,
            uniform_decode_query_len=1)

        # Verify the key is initialized correctly: one key per capture size
        # (two sizes) for every runtime mode the configuration enables.
        expected_piecewise = 2 if registers_piecewise else 0
        assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) \
            == expected_piecewise
        expected_full = 2 if registers_full else 0
        assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) \
            == expected_full

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        mixed_desc = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(mixed_desc)
        if mode_name == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == mixed_desc
        elif registers_piecewise:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == mixed_desc
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        uniform_desc = BatchDescriptor(num_tokens=8, uniform_decode=True)
        rt_mode, key = dispatcher.dispatch(uniform_desc)
        if mode_name == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == uniform_desc.non_uniform
        elif mode_name in ("FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"):
            assert rt_mode == CUDAGraphMode.FULL
            assert key == uniform_desc
        elif mode_name == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == uniform_desc.non_uniform
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        unmatched_desc = BatchDescriptor(num_tokens=15, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(unmatched_desc)
        assert rt_mode == CUDAGraphMode.NONE
        assert key is None
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Tests of a single ``CUDAGraphWrapper`` around ``SimpleMLP``.

    Every test builds a wrapper with ``runtime_mode=CUDAGraphMode.FULL`` and
    calls it under forward contexts with different runtime modes, using
    ``patch(..., wraps=...)`` spies to observe whether ``torch.cuda.graph``,
    the graph's ``replay`` or the model's ``forward`` was invoked.
    """

    def setup_method(self):
        """Create the mocked config, the CUDA model and the input tensors."""
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        # NOTE(review): not referenced by any test visible in this class.
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    @create_new_process_for_each_test("spawn")
    def test_capture_and_replay(self):
        """Warm up, capture once, then replay and compare with eager output."""
        wrapper = CUDAGraphWrapper(self.model,
                                   self.vllm_config,
                                   runtime_mode=CUDAGraphMode.FULL)
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(attn_metadata=None,
                                 vllm_config=self.vllm_config,
                                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                 batch_descriptor=None):
            wrapper(self.input_tensor)

        # 1. Capture
        # torch.cuda.graph is wrapped (not replaced) so the real capture
        # still happens while the call is recorded.
        with set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor),\
            patch("torch.cuda.graph",
                  wraps=torch.cuda.graph) as mock_cuda_graph:
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        # The wrapper must have stored an entry, keyed by the batch
        # descriptor, that holds a captured graph.
        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor),\
            patch.object(entry.cudagraph, 'replay',
                         wraps=entry.cudagraph.replay) as mock_replay:
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    @create_new_process_for_each_test("spawn")
    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper under a PIECEWISE context must not capture.

        Asserts that ``torch.cuda.graph`` is never called, the model's
        ``forward`` runs exactly once, and no graph entry is recorded.
        """
        wrapper = CUDAGraphWrapper(self.model,
                                   self.vllm_config,
                                   runtime_mode=CUDAGraphMode.FULL)
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor), \
            patch('torch.cuda.graph',
                  wraps=torch.cuda.graph) as mock_cuda_graph, \
            patch.object(self.model, 'forward',
                         wraps=self.model.forward) as mock_forward:
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    @create_new_process_for_each_test("spawn")
    def test_bypass_on_mode_none(self):
        """A FULL wrapper under a NONE context must not capture.

        Asserts that ``torch.cuda.graph`` is never called and no graph entry
        is recorded.
        """
        wrapper = CUDAGraphWrapper(self.model,
                                   self.vllm_config,
                                   runtime_mode=CUDAGraphMode.FULL)
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor), \
            patch('torch.cuda.graph',
                  wraps=torch.cuda.graph) as mock_cuda_graph:
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Tests combining ``CudagraphDispatcher`` with ``CUDAGraphWrapper``.

    Each wrapper call goes through ``_run_and_monitor_call``, which classifies
    what happened as ``"capture_global"``, ``"replay"``, ``"bypass"`` or
    ``"unknown"``.
    """

    def setup_method(self):
        """Build a FULL-mode config and a dispatcher with initialised keys."""
        # only FULL mode for non-uniform batches
        # NOTE(review): capture sizes here are [10, 20], while
        # test_capture_replay_bypass_logic dispatches descriptors of 1, 2 and
        # 3 tokens and expects FULL for the first two -- confirm against the
        # dispatcher's key-matching rules.
        self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE,
                                             cudagraph_mode="FULL",
                                             cudagraph_capture_sizes=[10, 20])
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1)

    def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode,
                              batch_descriptor):
        """Helper to run a single call and monitor the action.

        Args:
            wrapper: The (outer) ``CUDAGraphWrapper`` to call.
            input_tensor: Tensor passed to ``wrapper``.
            runtime_mode: Value given to the forward context as
                ``cudagraph_runtime_mode``.
            batch_descriptor: Value given to the forward context as
                ``batch_descriptor``; also used to look up an existing entry
                in ``wrapper.concrete_cudagraph_entries``.

        Returns:
            ``"capture_global"`` if ``torch.cuda.graph`` was called,
            else ``"replay"`` if the existing entry's ``replay`` was called,
            else ``"bypass"`` if ``wrapper.runnable`` was called,
            else ``"unknown"``. The checks are made in that order.
        """
        with patch('torch.cuda.graph',
                   wraps=torch.cuda.graph) as mock_graph_context, \
            patch.object(wrapper, 'runnable',
                         wraps=wrapper.runnable) as mock_runnable:

            # Look up the entry before the call so that, if a graph already
            # exists, its replay() can be replaced by a mock below.
            entry = wrapper.concrete_cudagraph_entries.get(
                batch_descriptor, None)

            context = set_forward_context(attn_metadata=None,
                                          vllm_config=self.vllm_config,
                                          cudagraph_runtime_mode=runtime_mode,
                                          batch_descriptor=batch_descriptor)
            # Placeholder that is never called; it is rebound to the real
            # patch target when an existing graph entry is found.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                # replay() is replaced (new_callable), not wrapped, so the
                # real graph is not replayed during this call.
                with context, \
                    patch.object(entry.cudagraph, 'replay',
                                 new_callable=MagicMock) as mock_replay:
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # note that this is globally mocked, so it will be detected
                # regardless of whether it was called by the inner or the
                # outer wrapper
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Capture and replay two shapes, bypass a third, then forbid capture.

        The runtime mode and key for each call come from
        ``self.dispatcher.dispatch`` (except step 4 and the final check,
        which pass ``CUDAGraphMode.FULL`` and the descriptor directly).
        """
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
                                        CUDAGraphMode.FULL)
        max_bs = 16
        # All inputs are slices of one buffer.
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(attn_metadata=None,
                                 vllm_config=self.vllm_config,
                                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                 batch_descriptor=None):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(desc_1)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
                                            key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
                                            key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(desc_2)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode,
                                            key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(full_wrapper, input_2,
                                            CUDAGraphMode.FULL, desc_2)
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode,
                                            key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        # NOTE(review): the flag is re-enabled only if the RuntimeError is
        # actually raised; the test runs in its own spawned process, which
        # presumably limits the impact -- confirm.
        set_cudagraph_capturing_enabled(False)
        with pytest.raises(RuntimeError):
            self._run_and_monitor_call(full_wrapper, input_3,
                                       CUDAGraphMode.FULL, desc_3_unseen)
        set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        # NOTE(review): `model` and this first `full_wrapper` are not used;
        # `full_wrapper` is rebound further down before its first call.
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
                                        CUDAGraphMode.FULL)
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config,
                                             CUDAGraphMode.PIECEWISE)
        # Spy on the inner forward so its call count can be asserted.
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(wraps=outer_model.forward,
                                        side_effect=piecewise_wrapper)
        full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config,
                                        CUDAGraphMode.FULL)

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(attn_metadata=None,
                                 vllm_config=self.vllm_config,
                                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                                 batch_descriptor=None):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(full_wrapper, input_1,
                                            CUDAGraphMode.FULL, desc_1)
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(full_wrapper, input_1,
                                            CUDAGraphMode.FULL, desc_1)
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(full_wrapper, input_1,
                                            CUDAGraphMode.PIECEWISE, desc_1)
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(full_wrapper, input_1,
                                            CUDAGraphMode.PIECEWISE, desc_1)
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of the ``with``
    block, then put every touched variable back to its previous state
    (restoring the old value, or removing the variable if it was unset).

    This is used instead of monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot the prior value (None when absent) of each key to be changed.
    saved = {name: os.environ.get(name) for name in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)
@dataclass
class BackendConfig:
    """Settings describing one attention backend used by these tests."""

    # Identifier of the backend; matches its key in ``backend_configs``.
    name: str
    # Environment variables merged into the environment set for the test run.
    env_vars: dict
    # Compilation settings for this backend (presumably keyword arguments
    # for ``CompilationConfig`` -- not consumed by the code visible here).
    comp_config: dict
    # Required ``(major, minor)`` device capability; ``None`` means the
    # backend is not restricted to a particular GPU architecture.
    specific_gpu_arch: Optional[tuple] = None
# Define all backend configurations of full cudagraph to be tested
# Maps a backend name to its BackendConfig. Entries with
# ``specific_gpu_arch`` set are skipped by the tests below when the current
# device capability differs; entries without it run on any device.
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}
# test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported)
# ``backend_name`` is a key of ``backend_configs``; ``supported`` tells the
# test whether engine construction and generation are expected to succeed
# (True) or to raise (False). Every case listed here is marked supported.
combo_cases_1 = [
    ("FA3", "FULL", True),
    ("FA3", "FULL_AND_PIECEWISE", True),
    ("FA2", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
    ("FA2", "FULL_AND_PIECEWISE", True),
    ("FlashInfer", "FULL", True),  # Should fallback to FULL_AND_PIECEWISE
    ("FlashInfer", "FULL_AND_PIECEWISE", True),
]
@pytest.mark.parametrize("combo_case", combo_cases_1)
def test_backend_and_cudagraph_mode_combo(combo_case):
    """Build an LLM for one (backend, cudagraph_mode) pair and generate.

    ``combo_case`` is ``(backend_name, cudagraph_mode, supported)``. The test
    is skipped when FlashInfer is requested but not importable, or when the
    backend requires a GPU capability the current device does not have. When
    ``supported`` is false, construction/generation is expected to raise.
    Afterwards the engine reference is dropped and the test waits for GPU
    memory on device 0 to clear.
    """
    backend_name, cudagraph_mode, supported = combo_case

    if backend_name == "FlashInfer":
        try:
            import flashinfer  # noqa: F401
        except ImportError:
            pytest.skip("FlashInfer is not installed")

    backend_config = backend_configs[backend_name]
    required_arch = backend_config.specific_gpu_arch
    # Dynamically skip test if GPU capability is not met
    if required_arch and \
            required_arch != current_platform.get_device_capability():
        pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")

    # Backend-specific variables are applied after (and may override) the
    # V1 engine switch.
    env_vars = {"VLLM_USE_V1": "1"}
    env_vars.update(backend_config.env_vars)

    with temporary_environ(env_vars), ExitStack() as stack:
        if not supported:
            stack.enter_context(pytest.raises(Exception))

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  max_num_seqs=256,
                  trust_remote_code=True,
                  gpu_memory_utilization=0.45,
                  max_model_len=1024,
                  compilation_config=CompilationConfig(
                      level=3, cudagraph_mode=cudagraph_mode))
        prompts = ["Hello, my name is"] * 10
        llm.generate(prompts)

    # ``llm`` is unbound when construction raised inside the block above;
    # in that case there is nothing to release.
    try:
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )
# cudagraph_mode / compilation level combinations to exercise.
# Each case is (backend_name, cudagraph_mode, compilation_level, supported).
combo_cases_2 = [
    ("FA2", "FULL", 0, True),  # no compilation + full cudagraph
    ("FA2", "FULL", 3, True),  # piecewise compilation + full cudagraph
    ("FA2", "PIECEWISE", 0, False),  # no compilation + piecewise cudagraph
    ("FA2", "PIECEWISE", 3,
     True),  # piecewise compilation + piecewise cudagraph
    ("FA2", "FULL_AND_PIECEWISE", 0,
     False),  # piecewise cudagraph not supported without piecewise compilation
    ("FA2", "FULL_AND_PIECEWISE", 3, True),
    ("FA2", "FULL_DECODE_ONLY", 0, True),
    ("FA2", "FULL_DECODE_ONLY", 3, True),
    ("FA2", "NONE", 0, True),  # no compilation + no cudagraph
    ("FA2", "NONE", 3, True),  # piecewise compilation + no cudagraph
]


@pytest.mark.parametrize("combo_case", combo_cases_2)
def test_cudagraph_compilation_combo(combo_case):
    """Build and run an LLM for one (cudagraph_mode, compilation level) pair.

    When the case is marked as unsupported, constructing or running the
    engine is expected to raise instead.
    """
    backend_name, cudagraph_mode, compilation_level, supported = combo_case

    # Backend-specific variables take precedence over the V1 switch.
    test_env = {"VLLM_USE_V1": "1"}
    test_env.update(backend_configs[backend_name].env_vars)

    with temporary_environ(test_env), ExitStack() as expectations:
        if not supported:
            expectations.enter_context(pytest.raises(Exception))

        compilation_config = CompilationConfig(level=compilation_level,
                                               cudagraph_mode=cudagraph_mode)
        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  max_num_seqs=256,
                  trust_remote_code=True,
                  gpu_memory_utilization=0.45,
                  max_model_len=1024,
                  compilation_config=compilation_config)
        llm.generate(["Hello, my name is"] * 10)

    # Drop the only strong reference to the engine so its GPU memory can be
    # released. `llm` is unbound when construction raised as expected.
    try:
        llm = weakref.proxy(llm)
        del llm
    except UnboundLocalError:
        pass
    finally:
        wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)
...@@ -15,7 +15,7 @@ import torch.fx as fx ...@@ -15,7 +15,7 @@ import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher from torch._dispatch.python import enable_python_dispatcher
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
...@@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule, ...@@ -277,9 +277,6 @@ def split_graph(graph: fx.GraphModule,
return split_gm, outputs return split_gm, outputs
# we share the global graph pool among all the backends
global_graph_pool = None
compilation_start_time = 0.0 compilation_start_time = 0.0
...@@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -339,14 +336,37 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
graph_index=index, graph_index=index,
num_graphs=len(self.compile_submod_names), num_graphs=len(self.compile_submod_names),
runtime_shape=None) runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend
piecewise_backend = resolve_obj_by_qualname( piecewise_backend = PiecewiseBackend(
current_platform.get_piecewise_backend_cls()) submod, self.vllm_config, index,
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices, len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend) compiled_graph_for_dynamic_shape, self.vllm_backend)
if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph))
else:
self.module.__dict__[target] = piecewise_backend
compilation_counter.num_piecewise_capturable_graphs_seen += 1 compilation_counter.num_piecewise_capturable_graphs_seen += 1
return output return output
...@@ -413,9 +433,7 @@ class VllmBackend: ...@@ -413,9 +433,7 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc. # them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag self.prefix = prefix or model_tag
global global_graph_pool global_graph_pool = current_platform.get_global_graph_pool()
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
# TODO: in the future, if we want to use multiple # TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool. # streams, it might not be safe to share a global pool.
...@@ -585,7 +603,7 @@ class VllmBackend: ...@@ -585,7 +603,7 @@ class VllmBackend:
self._called = True self._called = True
if not self.compilation_config.use_cudagraph or \ if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
not self.compilation_config.cudagraph_copy_inputs: not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm return self.split_gm
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Protocol
import torch.fx as fx
from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig
class AbstractPiecewiseBackend(Protocol):
    """
    PiecewiseBackend interface that allows platforms to extend
    piecewise static graph.

    Both methods are stubs that raise `NotImplementedError`; a concrete
    backend has to provide its own `__init__` and `__call__`.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend, **kwargs):
        """
        Initializes the PiecewiseBackend class with compilation and
        execution-related configurations.

        This class handles piecewise compilation, graph capturing,
        and dispatching for specific input shapes.

        Args:
            graph (fx.GraphModule): The graph represented in fx.
            vllm_config (VllmConfig): Global configuration for vLLM.
            graph_pool (Any):
                Graph memory pool handle, e.g.,
                `torch.cuda.graph_pool_handle()`.
            piecewise_compile_index (int):
                Index of the current piecewise subgraph.
            total_piecewise_compiles (int):
                Total number of piecewise-compiled graphs.
            sym_shape_indices (list[int]):
                Indices of the call arguments that carry the symbolic
                shape; `__call__` expects it at `sym_shape_indices[0]`.
            compiled_graph_for_general_shape (Callable):
                Callable that executes the graph compiled for general shapes.
            vllm_backend (VllmBackend):
                Backend compiler that manages compilation and graph runtime
                for vLLM.

        Keyword Args:
            kwargs: Additional keyword arguments reserved for future
                extensions or custom platforms.

        Raises:
            NotImplementedError: Always; this is only the interface stub.
        """
        raise NotImplementedError

    def __call__(self, *args) -> Any:
        """Executes the compiled graph for given input args.

        If this is the first invocation, executes the general compiled graph
        and initiates the compilation process tracking. For subsequent calls,
        dynamically dispatches execution to either a compiled graph or a static
        graph based on the input shape.

        Args:
            *args: Variable length input arguments to be passed into the
                graph. The symbolic shape is expected to be in position
                `sym_shape_indices[0]`.

        Returns:
            Any: Output of the executed graph. This can be from the general
            compiled graph, a specialized compiled version for the given shape,
            or a replayed static graph.

        Raises:
            NotImplementedError: Always; this is only the interface stub.
        """
        raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Protocol
from vllm.config import CUDAGraphMode, VllmConfig
class AbstractStaticGraphWrapper(Protocol):
    """
    StaticGraphWrapper interface that allows platforms to wrap a callable
    to be captured as a static graph.

    Both methods are stubs that raise `NotImplementedError`; a concrete
    wrapper has to provide its own `__init__` and `__call__`.
    """

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
        """
        Initializes the StaticGraphWrapper class with graph capturing and
        execution-related configurations.

        Args:
            runnable (Callable): The callable to be wrapped and captured.
            vllm_config (VllmConfig): Global configuration for vLLM.
            runtime_mode (CUDAGraphMode): The style of the static
                graph runtime. See CUDAGraphMode in vllm/config.py.
                Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
                are used as concrete runtime mode for cudagraph dispatching.
            graph_pool (Any):
                Graph memory pool handle, e.g.,
                `torch.cuda.graph_pool_handle()`.

        Keyword Args:
            kwargs: Additional keyword arguments for platform-specific
                configurations.

        Raises:
            NotImplementedError: Always; this is only the interface stub.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs) -> Any:
        """
        Executes the wrapped callable.

        If the current runtime mode in the ForwardContext matches the runtime
        mode of this instance, it replays the CUDAGraph or captures it using
        the callable if it hasn't been captured yet. Otherwise, it calls the
        original callable directly.

        Args:
            *args: Variable length input arguments to be passed into the
                callable.
            **kwargs: Keyword arguments to be passed into the callable.

        Returns:
            Any: Output of the executed callable.

        Raises:
            NotImplementedError: Always; this is only the interface stub.
        """
        raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass
class CUDAGraphEntry:
    """Cache record kept by `CUDAGraphWrapper` for one batch descriptor."""
    # Key under which this entry is stored in the wrapper's cache.
    batch_descriptor: BatchDescriptor
    # The captured graph; stays None until the wrapper has captured one for
    # this descriptor.
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # Weak-referenced output of the capture run; this is what the wrapper
    # returns after each replay.
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None
@dataclasses.dataclass
class CUDAGraphOptions:
    """Per-wrapper switches that tune how `CUDAGraphWrapper` captures."""
    # Emit a debug log line each time a new cudagraph is captured.
    debug_log_enable: bool = True
    # Replace `gc.collect` and `torch.cuda.empty_cache` with no-ops for the
    # duration of a capture.
    gc_disable: bool = False
    # Convert the capture output to weak references before returning it from
    # the capturing call.
    weak_ref_output: bool = True
class CUDAGraphWrapper:
    """Wraps a runnable to add CUDA graph capturing and replaying ability. And
    provide attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the cudagraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor(key) from the forward context and blindly trust them
    for cudagraph dispatching.
    3. If runtime_mode is NONE or runtime_mode does not match the mode of the
    wrapper, just call the runnable directly.
    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
    the wrapper will perform cudagraph capture(if key does not exist, create
    a new entry and cache it) or replay (if key exists in the cache).

    Note: CUDAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into those buffers for replay. We assume implementing them
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless,
    tracing and checking the input addresses to be consistent during replay is
    guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(self,
                 runnable: Callable,
                 vllm_config: VllmConfig,
                 runtime_mode: CUDAGraphMode,
                 graph_pool: Any = None,
                 cudagraph_options: Optional[CUDAGraphOptions] = None):
        """Set up the wrapper around `runnable`.

        Args:
            runnable: The callable to capture and replay.
            vllm_config: Global vLLM configuration; its `compilation_config`
                is kept on the wrapper.
            runtime_mode: Mode this wrapper reacts to. Must not be
                `CUDAGraphMode.NONE`.
            graph_pool: Graph memory pool passed to `torch.cuda.graph`. When
                None, the platform's global graph pool is used.
            cudagraph_options: Capture options; defaults to
                `CUDAGraphOptions()`.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.graph_pool = graph_pool
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
        # need to initialize a CUDAGraphWrapper.
        assert self.runtime_mode != CUDAGraphMode.NONE
        if self.graph_pool is None:
            self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.cudagraph_options = cudagraph_options
        # the entries for different batch descriptors that we need to capture
        # cudagraphs for.
        self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\
            = {}

    def __getattr__(self, key: str):
        """Forward lookups of unknown attributes to the wrapped runnable.

        Raises:
            AttributeError: If the runnable has no such attribute, or if the
                wrapper has no runnable yet.
        """
        # `__getattr__` only runs after the normal lookup failed. Fetch
        # `runnable` from the instance dict directly: `self.runnable` would
        # re-enter this method and recurse without end whenever the attribute
        # is not set yet (e.g. on an instance created without `__init__`, as
        # copy/pickle do before restoring state).
        try:
            runnable = self.__dict__["runnable"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                f"{key!r}") from None
        # allow accessing the attributes of the runnable.
        try:
            return getattr(runnable, key)
        except AttributeError:
            raise AttributeError(
                f"Attribute {key} not exists in the runnable of "
                f"cudagraph wrapper: {runnable}") from None

    def unwrap(self) -> Callable:
        """Return the original, unwrapped runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    def __call__(self, *args, **kwargs):
        """Run, capture or replay the runnable based on the forward context.

        The runnable is called directly when the forward context's runtime
        mode is NONE or differs from this wrapper's mode. Otherwise the first
        call for a batch descriptor captures a cudagraph and later calls
        replay it.

        Returns:
            The runnable's output on direct calls and on capture; the cached
            (weak-referenced) capture output on replay.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if cudagraph_runtime_mode == CUDAGraphMode.NONE or \
                cudagraph_runtime_mode != self.runtime_mode:
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without cudagraphs.
            # We do not trigger capture/replay if the runtime mode is not
            # matches. This enables properly dispatching to the correct
            # CUDAGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        if batch_descriptor not in self.concrete_cudagraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_cudagraph_entries[batch_descriptor] = \
                CUDAGraphEntry(batch_descriptor=batch_descriptor)

        entry = self.concrete_cudagraph_entries[batch_descriptor]

        if entry.cudagraph is None:
            if self.cudagraph_options.debug_log_enable:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing a cudagraph on (%s,%s)",
                             self.runtime_mode.name, entry.batch_descriptor)
            # validate that cudagraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if self.cudagraph_options.gc_disable:
                    # during every model forward for piecewise cudagraph
                    # mode, we will capture many pieces of cudagraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = self.runnable(*args, **kwargs)
                    if self.cudagraph_options.weak_ref_output:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise cudagraph mode, because
                        # the output of the last graph will not be used by
                        # any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for cudagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}")

        entry.cudagraph.replay()
        return entry.output
...@@ -2,21 +2,15 @@ ...@@ -2,21 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
from contextlib import ExitStack from typing import Any, Callable
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import torch.fx as fx import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,44 +18,29 @@ logger = init_logger(__name__) ...@@ -24,44 +18,29 @@ logger = init_logger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class ConcreteSizeEntry: class ConcreteSizeEntry:
runtime_shape: int runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_cudagraph: bool # the size is in cudagraph_capture_sizes
compiled: bool = False compiled: bool = False
runnable: Callable = None # type: ignore runnable: Callable = None # type: ignore
num_finished_warmup: int = 0
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
class CUDAPiecewiseBackend: class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int, piecewise_compile_index: int, total_piecewise_compiles: int,
total_piecewise_compiles: int, sym_shape_indices: list[int], sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable, compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend): vllm_backend: VllmBackend):
""" """
The backend for piecewise compilation. The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing. It mainly handles the compilation of static shapes and
dispatching based on runtime shape.
We will compile `self.graph` once for the general shape, We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in and then compile for different shapes specified in
`compilation_config.compile_sizes`. `compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes.
If a shape needs both compilation and cudagraph, we will
compile it first, and then capture cudagraph.
""" """
self.graph = graph self.graph = graph
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend self.vllm_backend = vllm_backend
...@@ -70,11 +49,10 @@ class CUDAPiecewiseBackend: ...@@ -70,11 +49,10 @@ class CUDAPiecewiseBackend:
self.is_last_graph = ( self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1) piecewise_compile_index == total_piecewise_compiles - 1)
self.is_full_graph = total_piecewise_compiles == 1
self.compile_sizes: set[int] = set( self.compile_sizes: set[int] = set(
self.compilation_config.compile_sizes) self.compilation_config.compile_sizes)
self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()
self.first_run_finished = False self.first_run_finished = False
...@@ -84,18 +62,18 @@ class CUDAPiecewiseBackend: ...@@ -84,18 +62,18 @@ class CUDAPiecewiseBackend:
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for different shapes that we need to either # the entries for different shapes that we need to compile
# compile or capture cudagraph
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile, # to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it # and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
# We only keep compilation management inside this class directly.
for shape in self.compile_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry( self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape, runtime_shape=shape,
need_to_compile=shape in self.compile_sizes, runnable=self.compiled_graph_for_general_shape,
use_cudagraph=shape in self.cudagraph_capture_sizes,
) )
def check_for_ending_compilation(self): def check_for_ending_compilation(self):
...@@ -112,16 +90,14 @@ class CUDAPiecewiseBackend: ...@@ -112,16 +90,14 @@ class CUDAPiecewiseBackend:
return self.compiled_graph_for_general_shape(*args) return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]] runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries: if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape # we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args) return self.compiled_graph_for_general_shape(*args)
entry = self.concrete_size_entries[runtime_shape] entry = self.concrete_size_entries[runtime_shape]
if entry.runnable is None: if not entry.compiled:
entry.runnable = self.compiled_graph_for_general_shape
if entry.need_to_compile and not entry.compiled:
entry.compiled = True entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape) self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments # args are real arguments
...@@ -138,81 +114,4 @@ class CUDAPiecewiseBackend: ...@@ -138,81 +114,4 @@ class CUDAPiecewiseBackend:
if self.is_last_graph and not self.to_be_compiled_sizes: if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation() self.check_for_ending_compilation()
# Skip CUDA graphs if this entry doesn't use them OR return entry.runnable(*args)
# if we're supposed to skip them globally
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
if not entry.use_cudagraph or skip_cuda_graphs:
return entry.runnable(*args)
if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
runtime_shape)
return entry.runnable(*args)
if self.is_first_graph:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape)
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_captured += 1
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output
...@@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig): ...@@ -37,3 +37,21 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
if context_manager is not None: if context_manager is not None:
context_manager.__exit__(None, None, None) context_manager.__exit__(None, None, None)
context_manager = None context_manager = None
cudagraph_capturing_enabled: bool = True


def validate_cudagraph_capturing_enabled():
    """Check that capturing a cudagraph is legal right now.

    Used to monitor cudagraph capturing at runtime; it should be called
    before any cudagraph capturing.

    Raises:
        RuntimeError: If the module-level `cudagraph_capturing_enabled` flag
            is currently false.
    """
    if cudagraph_capturing_enabled:
        return
    raise RuntimeError("CUDA graph capturing detected at an inappropriate "
                       "time. This operation is currently disabled.")
def set_cudagraph_capturing_enabled(enabled: bool):
    """Set the module-level flag that says whether capturing is allowed.

    Args:
        enabled: New value for `cudagraph_capturing_enabled`.
    """
    # Rebinds the module global rather than creating a local.
    global cudagraph_capturing_enabled
    cudagraph_capturing_enabled = enabled
...@@ -11,7 +11,8 @@ from typing import Callable, Optional ...@@ -11,7 +11,8 @@ from typing import Callable, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationLevel, get_current_vllm_config from vllm.config import (CompilationLevel, CUDAGraphMode,
get_current_vllm_config)
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher: ...@@ -115,8 +116,8 @@ class TorchCompileWrapperWithCustomDispatcher:
except Exception: except Exception:
pass pass
if self.vllm_config.compilation_config.use_cudagraph and \ if self.vllm_config.compilation_config.cudagraph_mode != \
"update" in new_code.co_names: CUDAGraphMode.NONE and "update" in new_code.co_names:
import depyf import depyf
src = depyf.decompile(new_code) src = depyf.decompile(new_code)
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
......
...@@ -32,7 +32,7 @@ from vllm import version ...@@ -32,7 +32,7 @@ from vllm import version
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo) PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel, from vllm.config.compilation import (CompilationConfig, CompilationLevel,
PassConfig) CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config from vllm.config.utils import ConfigType, config
...@@ -3529,11 +3529,21 @@ class VllmConfig: ...@@ -3529,11 +3529,21 @@ class VllmConfig:
else: else:
self.compilation_config.level = \ self.compilation_config.level = \
CompilationLevel.NO_COMPILATION CompilationLevel.NO_COMPILATION
else: else:
# NB: Passing both --enforce-eager and a compilation level # NB: Passing both --enforce-eager and a compilation level
# in V0 means the compilation level wins out. # in V0 means the compilation level wins out.
self.compilation_config.level = CompilationLevel.NO_COMPILATION self.compilation_config.level = CompilationLevel.NO_COMPILATION
# if cudagraph_mode is not explicitly set by users, set default value
if self.compilation_config.cudagraph_mode is None:
if envs.VLLM_USE_V1 and self.compilation_config.level \
== CompilationLevel.PIECEWISE:
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism # async tp is built on top of sequence parallelism
# and requires it to be enabled. # and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp: if self.compilation_config.pass_config.enable_async_tp:
...@@ -3541,12 +3551,13 @@ class VllmConfig: ...@@ -3541,12 +3551,13 @@ class VllmConfig:
True True
if self.compilation_config.pass_config.enable_sequence_parallelism: if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm") self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager: # disable cudagraph when enforce eager execution
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph if self.model_config is not None and self.model_config.enforce_eager:
# is set to True, full CUDA graphs will be used. logger.info("Cudagraph is disabled under eager mode")
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif envs.VLLM_USE_V1:
self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes() self._set_cudagraph_sizes()
...@@ -3566,12 +3577,6 @@ class VllmConfig: ...@@ -3566,12 +3577,6 @@ class VllmConfig:
"Disabling `torch.compile`.") "Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION self.compilation_config.level = CompilationLevel.NO_COMPILATION
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.info("full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
disable_chunked_prefill_reasons: list[str] = [] disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config: if self.model_config and self.model_config.pooler_config:
...@@ -3612,9 +3617,32 @@ class VllmConfig: ...@@ -3612,9 +3617,32 @@ class VllmConfig:
"to True to enable.") "to True to enable.")
current_platform.check_and_update_config(self) current_platform.check_and_update_config(self)
# final check of cudagraph mode after platform-specific update
if envs.VLLM_USE_V1:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
and self.model_config is not None and \
not self.model_config.disable_cascade_attn:
logger.info("CUDAGraphMode.FULL is not supported with "
"cascade attention currently. Disabling cascade "
"attention.")
self.model_config.disable_cascade_attn = True
if self.compilation_config.cudagraph_mode\
.requires_piecewise_compilation():
assert self.compilation_config.level == \
CompilationLevel.PIECEWISE, \
"Compilation level should be CompilationLevel.PIECEWISE "\
"when cudagraph_mode piecewise cudagraphs is used, "\
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
if not self.instance_id: if not self.instance_id:
self.instance_id = random_uuid()[:5] self.instance_id = random_uuid()[:5]
# Do this after all the updates to compilation_config.level
if envs.VLLM_USE_V1 and \
self.compilation_config.level == CompilationLevel.PIECEWISE:
self.compilation_config.set_splitting_ops_for_v1()
if (envs.VLLM_USE_V1 if (envs.VLLM_USE_V1
and not self.scheduler_config.disable_hybrid_kv_cache_manager): and not self.scheduler_config.disable_hybrid_kv_cache_manager):
# logger should only print warning message for hybrid models. As we # logger should only print warning message for hybrid models. As we
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import hashlib import hashlib
from collections import Counter from collections import Counter
from dataclasses import asdict, field from dataclasses import asdict, field
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
from pydantic import TypeAdapter from pydantic import TypeAdapter, field_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
...@@ -31,6 +32,40 @@ class CompilationLevel: ...@@ -31,6 +32,40 @@ class CompilationLevel:
PIECEWISE = 3 PIECEWISE = 3
class CUDAGraphMode(enum.Enum):
    """Cudagraph modes selectable in CompilationConfig.

    The scalar members `NONE`, `PIECEWISE` and `FULL` double as the concrete
    runtime modes used for cudagraph dispatching. The composite members carry
    a `(decode_mode, mixed_mode)` tuple of scalar values, i.e. a separate
    routine for decode batches and for mixed prefill-decode batches.
    """
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    # Composite modes: (value used for decode, value used for mixed batches).
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def separate_routine(self) -> bool:
        """Return True for composite modes, whose value is a tuple."""
        return isinstance(self.value, tuple)

    def decode_mode(self) -> 'CUDAGraphMode':
        """Return the scalar mode applied to decode batches."""
        if not self.separate_routine():
            return self
        return CUDAGraphMode(self.value[0])

    def mixed_mode(self) -> 'CUDAGraphMode':
        """Return the scalar mode applied to mixed prefill-decode batches."""
        if not self.separate_routine():
            return self
        return CUDAGraphMode(self.value[1])

    def max_cudagraph_mode(self) -> 'CUDAGraphMode':
        """Return the scalar mode with the largest value used by this mode."""
        if not self.separate_routine():
            return self
        return CUDAGraphMode(max(self.value))

    def has_full_cudagraphs(self) -> bool:
        """Return True when any routine of this mode is `FULL`."""
        return self.max_cudagraph_mode() == CUDAGraphMode.FULL

    def requires_piecewise_compilation(self) -> bool:
        """Return True when either routine of this mode is `PIECEWISE`."""
        routines = (self.decode_mode(), self.mixed_mode())
        return CUDAGraphMode.PIECEWISE in routines
@config @config
@dataclass @dataclass
class PassConfig: class PassConfig:
...@@ -91,6 +126,7 @@ class CompilationConfig: ...@@ -91,6 +126,7 @@ class CompilationConfig:
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture: - CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`] - [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes] [vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`] - [`cudagraph_num_of_warmups`]
...@@ -157,7 +193,7 @@ class CompilationConfig: ...@@ -157,7 +193,7 @@ class CompilationConfig:
By default, all custom ops are enabled when running without Inductor and By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops.""" Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] = field(default_factory=list) splitting_ops: Optional[list[str]] = None
"""A list of ops to split the full graph into subgraphs, used in piecewise """A list of ops to split the full graph into subgraphs, used in piecewise
compilation.""" compilation."""
...@@ -187,7 +223,43 @@ class CompilationConfig: ...@@ -187,7 +223,43 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation # CudaGraph compilation
use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
- FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE.
PIECEWISE mode: Build piecewise cudagraphs only, keeping the
cudagraph-incompatible ops (e.g. some attention ops) outside the cudagraph
for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends.
Generally for performance FULL_AND_PIECEWISE is better.
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
Mixed prefill-decode batches are run without cudagraphs. Can be good for
decode instances in a P/D setup where prefill is not as important so we
can save some memory.
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches.
This is likely the most performant mode for most models.
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation. """Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used. - False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires - True: cudagraph inside compilation is used. It requires
...@@ -197,8 +269,9 @@ class CompilationConfig: ...@@ -197,8 +269,9 @@ class CompilationConfig:
CompilationLevel.PIECEWISE (aka -O3). CompilationLevel.PIECEWISE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic Note that this is orthogonal to the cudagraph capture logic
outside of compilation. outside of compilation.
TODO: move outside cudagraph logic into compilation. Warning: This flag is deprecated and will be removed in the next major or
torch.compile will handle cudagraph capture logic in the future.""" minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
cudagraph_num_of_warmups: int = 0 cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph. """Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs. It means the first several runs will be treated as warmup runs.
...@@ -213,12 +286,17 @@ class CompilationConfig: ...@@ -213,12 +286,17 @@ class CompilationConfig:
cudagraph. If the caller can guarantee that the same input buffers cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.""" internally managed buffer. Default is False.
full_cuda_graph: bool = False Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: Optional[bool] = False
"""whether to use a full cuda graph for the entire forward pass rather than """whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.""" performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
pass_config: PassConfig = field(default_factory=PassConfig) pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details""" """Custom inductor passes, see PassConfig for more details"""
...@@ -253,6 +331,13 @@ class CompilationConfig: ...@@ -253,6 +331,13 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1.""" model code, e.g., Attention, FusedMOE when dp_size>1."""
# Attention ops; used for piecewise cudagraphs
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -297,13 +382,26 @@ class CompilationConfig: ...@@ -297,13 +382,26 @@ class CompilationConfig:
if pass_config_exclude: if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude exclude["pass_config"] = pass_config_exclude
return TypeAdapter(CompilationConfig).dump_json( # The cast to string is necessary because Pydantic is mocked in docs
self, # builds and sphinx-argparse doesn't know the return type of decode()
exclude=exclude, # type: ignore[arg-type] return str(
exclude_unset=True).decode() TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
__str__ = __repr__ __str__ = __repr__
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
    """Allow `cudagraph_mode` to be given as a member name string.

    Strings are upper-cased and looked up by name in `CUDAGraphMode`, so
    the match is case-insensitive. Non-string values are passed through
    unchanged for the regular field validation to handle.

    Raises:
        ValueError: if `value` is a string that names no `CUDAGraphMode`
            member.
    """
    if not isinstance(value, str):
        return value
    try:
        return CUDAGraphMode[value.upper()]
    except KeyError:
        # A bare KeyError would escape the validator as-is; ValueError is
        # the exception type validators are expected to raise for bad input.
        valid_names = ", ".join(mode.name for mode in CUDAGraphMode)
        raise ValueError(
            f"Invalid cudagraph_mode {value!r}; expected one of: "
            f"{valid_names}.") from None
def __post_init__(self) -> None: def __post_init__(self) -> None:
count_none = self.custom_ops.count("none") count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all") count_all = self.custom_ops.count("all")
...@@ -341,7 +439,26 @@ class CompilationConfig: ...@@ -341,7 +439,26 @@ class CompilationConfig:
if isinstance(self.pass_config, dict): if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config) self.pass_config = PassConfig(**self.pass_config)
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]: # migrate the deprecated flags
if not self.use_cudagraph:
logger.warning("use_cudagraph is deprecated, use "
"cudagraph_mode=NONE instead.")
if self.cudagraph_mode is not None:
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning("full_cuda_graph is deprecated, use "
"cudagraph_mode=FULL instead.")
if self.cudagraph_mode is not None:
raise ValueError("full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.") raise ValueError("No compilation level is set.")
...@@ -414,15 +531,34 @@ class CompilationConfig: ...@@ -414,15 +531,34 @@ class CompilationConfig:
self.max_capture_size] = self.max_capture_size self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
    """Resolve `splitting_ops` for piecewise compilation.

    Must only be called when `level` is CompilationLevel.PIECEWISE.

    - `splitting_ops` is None (unset): default to a copy of the attention
      ops listed in `_attention_ops`.
    - `splitting_ops` is an empty list: warn, and promote a PIECEWISE
      `cudagraph_mode` to FULL.
    - Any other value is left untouched.

    Raises:
        AssertionError: if `level` is not CompilationLevel.PIECEWISE.
    """
    assert self.level == CompilationLevel.PIECEWISE, (
        "set_splitting_ops_for_v1 should only be called when "
        "level is CompilationLevel.PIECEWISE")

    if self.splitting_ops is None:
        # NOTE: When using full cudagraph, instead of setting an empty
        # list and capture the full cudagraph inside the flattened fx
        # graph, we keep the piecewise fx graph structure but capture the
        # full cudagraph outside the fx graph. This reduces some cpu
        # overhead when the runtime batch_size is not cudagraph captured.
        # see https://github.com/vllm-project/vllm/pull/20059 for details.
        #
        # Copy the list: `_attention_ops` is a class-level list shared by
        # every instance, so aliasing it would let an in-place edit of one
        # config's `splitting_ops` silently change the default for all.
        self.splitting_ops = list(self._attention_ops)
    elif len(self.splitting_ops) == 0:
        logger.warning_once("Using piecewise compilation with empty "
                            "splitting_ops.")
        if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
            logger.warning_once(
                "When compilation level is piecewise with empty "
                "splitting_ops, PIECEWISE cudagraph_mode will be "
                "treated as FULL cudagraph_mode. Please ensure you are "
                "using attention backends that support cudagraph or set "
                "cudagraph_mode to NONE explicitly if encountering "
                "any problems.")
            self.cudagraph_mode = CUDAGraphMode.FULL
        self.splitting_ops = []
def splitting_ops_contain_attention(self) -> bool:
    """Return True when every op in `_attention_ops` is in `splitting_ops`.

    Returns False when `splitting_ops` has not been set (is None).
    """
    configured_ops = self.splitting_ops
    if configured_ops is None:
        return False
    for attention_op in self._attention_ops:
        if attention_op not in configured_ops:
            return False
    return True
...@@ -5,13 +5,13 @@ import time ...@@ -5,13 +5,13 @@ import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL ...@@ -26,6 +26,27 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list) batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
    """
    Key describing a padded batch for cudagraph dispatching.

    Keep the number of fields as small as possible while still describing
    the padded batch uniquely.
    """
    num_tokens: int
    uniform_decode: bool = False
    """
    False can also be used for a uniform decode batch in order to dispatch
    to the cudagraph that supports non-uniform batches.
    """

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a descriptor with the same `num_tokens` and
        `uniform_decode=False`.
        """
        return BatchDescriptor(num_tokens=self.num_tokens,
                               uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
max_num_tokens: int, max_num_tokens: int,
chunk_idx: int) -> list[int]: chunk_idx: int) -> list[int]:
...@@ -152,7 +173,15 @@ class ForwardContext: ...@@ -152,7 +173,15 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass # set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
# by default NONE, no cudagraph is used.
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
batch_descriptor: Optional[BatchDescriptor] = None
def __post_init__(self):
    """Check that `cudagraph_runtime_mode` is NONE, PIECEWISE or FULL.

    Raises:
        AssertionError: for any other CUDAGraphMode member.
    """
    concrete_runtime_modes = (
        CUDAGraphMode.NONE,
        CUDAGraphMode.PIECEWISE,
        CUDAGraphMode.FULL,
    )
    assert self.cudagraph_runtime_mode in concrete_runtime_modes, \
        f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
_forward_context: Optional[ForwardContext] = None _forward_context: Optional[ForwardContext] = None
...@@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext: ...@@ -168,13 +197,13 @@ def get_forward_context() -> ForwardContext:
@contextmanager @contextmanager
def set_forward_context( def set_forward_context(
attn_metadata: Any, attn_metadata: Any,
vllm_config: VllmConfig, vllm_config: VllmConfig,
virtual_engine: int = 0, virtual_engine: int = 0,
num_tokens: Optional[int] = None, num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None, num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
): batch_descriptor: Optional[BatchDescriptor] = None):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
Here we can inject common logic for every model forward pass. Here we can inject common logic for every model forward pass.
...@@ -198,7 +227,8 @@ def set_forward_context( ...@@ -198,7 +227,8 @@ def set_forward_context(
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs, cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
) )
try: try:
......
...@@ -177,17 +177,20 @@ class CudaPlatformBase(Platform): ...@@ -177,17 +177,20 @@ class CudaPlatformBase(Platform):
logger.info("Forcing kv cache block size to 128 for " logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.") "CUTLASS_MLA backend.")
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1 and parallel_config.data_parallel_size > 1
and compilation_config.use_cudagraph): and compilation_config.cudagraph_mode != CUDAGraphMode.NONE):
logger.info( logger.info(
"Data Parallel: Forcing enforce eager to be True since DP " "Data Parallel: disabling cudagraphs since DP "
"with DeepEP high-throughput kernels are not CUDA Graph " "with DeepEP high-throughput kernels are not CUDA Graph "
"compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph "
"compatible. Set the all_to_all backend to deepep_low_latency " "compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.") "to use those kernels instead.")
compilation_config.use_cudagraph = False compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if model_config is not None: if model_config is not None:
model_config.enforce_eager = True model_config.enforce_eager = True
...@@ -454,8 +457,8 @@ class CudaPlatformBase(Platform): ...@@ -454,8 +457,8 @@ class CudaPlatformBase(Platform):
return True return True
@classmethod @classmethod
def get_piecewise_backend_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod @classmethod
def stateless_init_device_torch_dist_pg( def stateless_init_device_torch_dist_pg(
......
...@@ -7,7 +7,7 @@ import random ...@@ -7,7 +7,7 @@ import random
import sys import sys
from datetime import timedelta from datetime import timedelta
from platform import uname from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Union from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -137,6 +137,8 @@ class Platform: ...@@ -137,6 +137,8 @@ class Platform:
additional_env_vars: list[str] = [] additional_env_vars: list[str] = []
_global_graph_pool: Optional[Any] = None
@property @property
def supported_dtypes(self) -> list[torch.dtype]: def supported_dtypes(self) -> list[torch.dtype]:
"""Returns the supported dtypes for the current platform.""" """Returns the supported dtypes for the current platform."""
...@@ -522,6 +524,15 @@ class Platform: ...@@ -522,6 +524,15 @@ class Platform:
" attribute.", self.device_type, key) " attribute.", self.device_type, key)
return None return None
def get_global_graph_pool(self) -> Any:
    """
    Return the global graph pool for this platform.

    The pool is created with `self.graph_pool_handle()` the first time it
    is requested and then stored on the instance's class in
    `_global_graph_pool`; later calls return the stored value.
    """
    platform_cls = self.__class__
    if platform_cls._global_graph_pool is not None:
        return platform_cls._global_graph_pool
    platform_cls._global_graph_pool = self.graph_pool_handle()
    return platform_cls._global_graph_pool
@classmethod @classmethod
def get_cu_count(cls, device_id: int = 0) -> int: def get_cu_count(cls, device_id: int = 0) -> int:
""" """
...@@ -530,11 +541,11 @@ class Platform: ...@@ -530,11 +541,11 @@ class Platform:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def get_piecewise_backend_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
""" """
Get piecewise backend class for piecewise graph. Get static graph wrapper class for static graph.
""" """
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper"
@classmethod @classmethod
def stateless_init_device_torch_dist_pg( def stateless_init_device_torch_dist_pg(
......
...@@ -421,8 +421,8 @@ class RocmPlatform(Platform): ...@@ -421,8 +421,8 @@ class RocmPlatform(Platform):
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
@classmethod @classmethod
def get_piecewise_backend_cls(cls) -> str: def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod @classmethod
def stateless_init_device_torch_dist_pg( def stateless_init_device_torch_dist_pg(
......
...@@ -99,7 +99,7 @@ class TpuPlatform(Platform): ...@@ -99,7 +99,7 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel from vllm.config import CompilationLevel, CUDAGraphMode
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
# For v0, the default block size is 16. # For v0, the default block size is 16.
...@@ -109,9 +109,17 @@ class TpuPlatform(Platform): ...@@ -109,9 +109,17 @@ class TpuPlatform(Platform):
# TPU only supports DYNAMO_ONCE compilation level # TPU only supports DYNAMO_ONCE compilation level
if compilation_config.level != CompilationLevel.DYNAMO_ONCE: if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
logger.info("[TPU] Forcing DYNAMO_ONCE compilation level") logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and "
"disabling cudagraph.")
compilation_config.level = CompilationLevel.DYNAMO_ONCE compilation_config.level = CompilationLevel.DYNAMO_ONCE
if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \
!= CUDAGraphMode.NONE:
logger.info("[TPU] CUDA graph is not supported on TPU, "
"disabling cudagraphs.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if compilation_config.backend == "": if compilation_config.backend == "":
compilation_config.backend = "openxla" compilation_config.backend = "openxla"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment