Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23322431
Unverified
Commit
23322431
authored
Aug 02, 2025
by
fhl2000
Committed by
GitHub
Aug 01, 2025
Browse files
[V1][CUDA] Full cudagraph support for FlashInfer (#21367)
parent
3654847d
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
377 additions
and
48 deletions
+377
-48
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-2
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+323
-34
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+3
-1
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+3
-1
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+4
-2
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+17
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-7
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+5
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
23322431
...
@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
...
@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
get_kv_cache_layout
)
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(
...
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(
class
FlashAttentionMetadataBuilder
(
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
get_flash_attn_version
()
==
3
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
NEVER
if
get_flash_attn_version
()
==
2
\
else
AttentionCGSupport
.
ALWAYS
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
23322431
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/mla/flashmla.py
View file @
23322431
...
@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...
@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
MLACommonMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
...
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
# Decode-only
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
PURE_DECODE_ONLY
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
23322431
...
@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...
@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
MLACommonMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
# yapf: enable
# yapf: enable
...
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
...
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
# decode only
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
PURE_DECODE_ONLY
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
23322431
...
@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
...
@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:
...
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:
class
TritonAttentionMetadataBuilder
(
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
ALWAYS
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/utils.py
View file @
23322431
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
import
abc
import
enum
import
functools
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
,
make_dataclass
from
dataclasses
import
dataclass
,
make_dataclass
...
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
...
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
)
class
AttentionCGSupport
(
enum
.
Enum
):
""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""
NEVER
=
0
"""NO cudagraph support"""
PURE_DECODE_ONLY
=
1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS
=
2
"""Cudagraph always supported"""
class
AttentionMetadataBuilder
(
abc
.
ABC
,
Generic
[
M
]):
class
AttentionMetadataBuilder
(
abc
.
ABC
,
Generic
[
M
]):
# Does this backend/builder support CUDA Graphs for attention.
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported
:
ClassVar
[
bool
]
=
False
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
NEVER
@
abstractmethod
@
abstractmethod
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
23322431
...
@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...
@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available
,
round_up
,
supports_dynamo
)
is_pin_memory_available
,
round_up
,
supports_dynamo
)
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
make_local_attention_virtual_batches
)
make_local_attention_virtual_batches
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
device
,
self
.
device
,
)
)
if
(
self
.
full_cuda_graph
if
self
.
full_cuda_graph
:
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
if
attn_metadata_builder_i
.
attn_cudagraph_support
==
\
raise
ValueError
(
AttentionCGSupport
.
NEVER
:
f
"Full CUDAGraph not supported for "
raise
ValueError
(
f
"Full CUDAGraph not supported for "
f
"
{
attn_backend_i
.
__name__
}
. Turn off CompilationConfig."
f
"
{
attn_backend_i
.
__name__
}
. Turn off "
f
"full_cuda_graph or use a different attention backend."
)
f
"CompilationConfig.full_cuda_graph or use a "
f
" different attention backend."
)
if
attn_metadata_builder_i
.
attn_cudagraph_support
==
\
AttentionCGSupport
.
PURE_DECODE_ONLY
:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self
.
cudagraph_batch_sizes
=
[
size
for
size
in
self
.
cudagraph_batch_sizes
if
size
<=
self
.
scheduler_config
.
max_num_seqs
]
return
attn_backend_i
,
attn_metadata_builder_i
return
attn_backend_i
,
attn_metadata_builder_i
def
initialize_attn_backend
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_attn_backend
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/gpu_worker.py
View file @
23322431
...
@@ -321,11 +321,16 @@ class Worker(WorkerBase):
...
@@ -321,11 +321,16 @@ class Worker(WorkerBase):
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
max_num_reqs
=
min
(
self
.
scheduler_config
.
max_num_seqs
,
max_num_reqs
=
min
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph relay.
attn_cudagraph
=
self
.
compilation_config
.
full_cuda_graph
and
\
not
self
.
model_config
.
enforce_eager
# We skip EPLB here since we don't want to record dummy metrics
# We skip EPLB here since we don't want to record dummy metrics
hidden_states
,
last_hidden_states
=
\
hidden_states
,
last_hidden_states
=
\
self
.
model_runner
.
_dummy_run
(
self
.
model_runner
.
_dummy_run
(
num_tokens
=
max_num_reqs
,
num_tokens
=
max_num_reqs
,
capture_attn_cudagraph
=
attn_cudagraph
,
skip_eplb
=
True
,
skip_eplb
=
True
,
)
)
if
self
.
model_runner
.
is_pooling_model
:
if
self
.
model_runner
.
is_pooling_model
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment