Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23322431
Unverified
Commit
23322431
authored
Aug 02, 2025
by
fhl2000
Committed by
GitHub
Aug 01, 2025
Browse files
[V1][CUDA] Full cudagraph support for FlashInfer (#21367)
parent
3654847d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
377 additions
and
48 deletions
+377
-48
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-2
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+323
-34
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+3
-1
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+3
-1
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+4
-2
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+17
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-7
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+5
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
23322431
...
...
@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
get_flash_attn_version
()
==
3
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
NEVER
if
get_flash_attn_version
()
==
2
\
else
AttentionCGSupport
.
ALWAYS
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
23322431
...
...
@@ -4,26 +4,28 @@
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Optional
,
Union
import
torch
from
flashinfer
import
(
BatchDecodeWithPagedKVCacheWrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
MultiLevelCascadeAttentionWrapper
)
from
flashinfer.decode
import
trtllm_batch_decode_with_kv_cache
from
flashinfer.decode
import
(
_get_range_buf
,
get_seq_lens
,
trtllm_batch_decode_with_kv_cache
)
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionType
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
,
is_pin_memory_available
from
vllm.utils.flashinfer
import
use_trtllm_decode_attention
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
get_per_layer_parameters
,
infer_global_hyperparameters
,
reorder_batch_to_split_decodes_and_prefills
,
split_decodes_and_prefills
)
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
get_per_layer_parameters
,
infer_global_hyperparameters
,
reorder_batch_to_split_decodes_and_prefills
,
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
TYPE_CHECKING
:
...
...
@@ -174,26 +176,66 @@ class FlashInferMetadata:
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
PURE_DECODE_ONLY
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
self
.
device
=
device
self
.
vllm_config
=
vllm_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
kv_cache_spec
=
kv_cache_spec
self
.
_workspace_buffer
=
None
self
.
_prefill_wrapper
=
None
# Wrapper for prefill/append
self
.
_decode_wrapper
=
None
# Wrapper for decode
self
.
_decode_wrapper
=
None
# Wrapper for decode (general shape)
self
.
compilation_config
=
vllm_config
.
compilation_config
max_num_pages_per_req
=
cdiv
(
vllm_config
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
)
max_num_reqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_pages
=
max_num_reqs
*
max_num_pages_per_req
self
.
enable_cuda_graph
=
self
.
compilation_config
.
full_cuda_graph
if
self
.
enable_cuda_graph
:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self
.
_decode_wrappers_cudagraph
:
dict
[
int
,
BatchDecodeWithPagedKVCacheWrapper
]
=
{}
self
.
_decode_cudagraph_max_bs
=
min
(
max_num_reqs
,
self
.
compilation_config
.
max_capture_size
)
self
.
_cascade_wrapper
=
None
# Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
=
infer_global_hyperparameters
(
get_per_layer_parameters
(
vllm_config
,
layer_names
,
FlashInferImpl
))
self
.
vllm_config
=
vllm_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
kv_cache_spec
=
kv_cache_spec
max_num_blocks_per_request
=
cdiv
(
vllm_config
.
model_config
.
max_model_len
,
self
.
kv_cache_spec
.
block_size
)
self
.
block_table_arange
=
torch
.
arange
(
max_num_blocks_per_request
,
# Preparing persistent buffers (device-side)
self
.
paged_kv_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
paged_kv_indices
=
torch
.
zeros
(
max_num_pages
,
# max num pages possible
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
paged_kv_last_page_len
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# host-side buffer
pin_memory
=
is_pin_memory_available
()
self
.
paged_kv_indptr_cpu
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_indices_cpu
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_last_page_len_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
block_table_arange
=
torch
.
arange
(
max_num_pages_per_req
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
...
...
@@ -217,8 +259,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
())
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
if
self
.
_decode_wrapper
is
None
:
def
_get_decode_wrapper
(
self
,
batch_size
:
int
,
use_cudagraph
:
bool
=
False
):
if
use_cudagraph
:
decode_wrapper
=
self
.
_decode_wrappers_cudagraph
.
get
(
batch_size
,
None
)
else
:
decode_wrapper
=
self
.
_decode_wrapper
if
decode_wrapper
is
None
:
num_qo_heads
=
(
self
.
vllm_config
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
))
...
...
@@ -226,11 +276,32 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
vllm_config
.
parallel_config
)
use_tensor_cores
=
envs
.
VLLM_FLASHINFER_FORCE_TENSOR_CORES
or
(
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
if
use_cudagraph
:
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
paged_kv_indices
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
batch_size
]
else
:
paged_kv_indptr
=
None
paged_kv_indices
=
None
paged_kv_last_page_len
=
None
decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
get_kv_cache_layout
(),
use_cuda_graph
=
use_cudagraph
,
paged_kv_indptr_buffer
=
paged_kv_indptr
,
paged_kv_indices_buffer
=
paged_kv_indices
,
paged_kv_last_page_len_buffer
=
paged_kv_last_page_len
,
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
# save the decode wrapper
if
use_cudagraph
:
self
.
_decode_wrappers_cudagraph
[
batch_size
]
=
decode_wrapper
else
:
self
.
_decode_wrapper
=
decode_wrapper
return
decode_wrapper
def
_get_cascade_wrapper
(
self
):
if
self
.
_cascade_wrapper
is
None
:
...
...
@@ -308,16 +379,44 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
if
num_decodes
>
0
:
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
()
pure_decode
=
num_prefills
==
0
# possible required padding for cudagraph replay
use_cudagraph
=
(
self
.
enable_cuda_graph
and
pure_decode
and
num_decodes
<=
self
.
_decode_cudagraph_max_bs
)
if
use_cudagraph
:
num_input_tokens
=
(
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
))
# Carefully fill the padding region with reasonable values
# on the CPU.
# Make sure paged_kv_indptr_cpu is not decreasing
self
.
paged_kv_indptr_cpu
[
1
+
num_decodes
:
1
+
num_input_tokens
].
fill_
(
attn_metadata
.
paged_kv_indptr_cpu
[
-
1
])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self
.
paged_kv_last_page_len_cpu
[
num_decodes
:
num_input_tokens
].
fill_
(
1
)
else
:
num_input_tokens
=
num_decodes
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
(
num_input_tokens
,
use_cudagraph
)
if
not
use_trtllm_decode_attention
(
num_decodes
,
attn_metadata
.
max_seq_len
,
self
.
cache_config
.
cache_dtype
,
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
):
attn_metadata
.
decode_wrapper
.
plan
(
attn_metadata
.
paged_kv_indptr_cpu
[:
num_decodes
+
1
],
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# in attn_metadata when using cudagraph.
fast_plan_decode
(
attn_metadata
.
decode_wrapper
,
self
.
paged_kv_indptr_cpu
[:
num_input_tokens
+
1
],
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_len_cpu
[:
num_
decode
s
],
self
.
paged_kv_last_page_len_cpu
[:
num_
input_token
s
],
attn_metadata
.
num_qo_heads
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
,
...
...
@@ -336,6 +435,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
)
->
FlashInferMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
\
split_decodes_and_prefills
(
common_attn_metadata
)
...
...
@@ -381,18 +481,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
non_blocking
=
True
)
mask
=
(
self
.
block_table_arange
[:
max_num_blocks
].
unsqueeze
(
0
)
<
block_table_bounds
.
unsqueeze
(
1
))
paged_kv_indices
=
block_table_tensor
[:,
:
max_num_blocks
][
mask
]
paged_kv_indptr_cpu
=
torch
.
zeros
(
len
(
block_table_bounds_cpu
)
+
1
,
dtype
=
torch
.
int32
,
device
=
'cpu'
)
paged_kv_indptr_cpu
[
1
:]
=
block_table_bounds_cpu
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
# write self.paged_kv_indices inplace
num_actual_pages
=
torch
.
sum
(
mask
)
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_actual_pages
]
torch
.
masked_select
(
block_table_tensor
[:,
:
max_num_blocks
],
mask
,
out
=
paged_kv_indices
)
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
torch
.
cumsum
(
block_table_bounds_cpu
,
dim
=
0
,
dtype
=
torch
.
int32
,
out
=
self
.
paged_kv_indptr_cpu
[
1
:
1
+
num_reqs
])
paged_kv_last_page_len_cpu
=
seq_lens_cpu
%
page_size
paged_kv_last_page_len_cpu
=
torch
.
where
(
paged_kv_last_page_len_cpu
==
0
,
page_size
,
paged_kv_last_page_len_cpu
)
# write self.paged_kv_last_page_len_cpu inplace
torch
.
where
(
paged_kv_last_page_len_cpu
==
0
,
torch
.
tensor
(
page_size
),
paged_kv_last_page_len_cpu
,
out
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
])
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
...
@@ -402,9 +510,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
,
paged_kv_indptr_cpu
=
paged_kv_indptr_cpu
,
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr_cpu
[:
1
+
num_reqs
]
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len_cpu
=
paged_kv_last_page_len_cpu
,
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
],
num_qo_heads
=
self
.
vllm_config
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
),
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
,
...
...
@@ -431,6 +540,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return
attn_metadata
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata):
    """Build the attention metadata used for full cudagraph capture.

    Currently, only decode is supported for full cudagraphs with FlashInfer,
    so the batch must contain exactly one token per request.

    Args:
        common_attn_metadata: Backend-agnostic metadata for the capture
            batch. NOTE: its ``max_query_len`` is overwritten in place
            with 1 before it is handed to ``build``.

    Returns:
        Whatever ``self.build(0, common_attn_metadata)`` returns, i.e. the
        metadata built with a common prefix length of 0.

    Raises:
        AssertionError: If ``num_reqs != num_actual_tokens``, meaning the
            batch is not decode-only.
    """
    m = common_attn_metadata
    # The option is spelled `max_num_seqs` in the scheduler config; keep the
    # hint in the message consistent with the real name.
    assert m.num_reqs == m.num_actual_tokens, \
        "FlashInfer only supports decode-only full CUDAGraph capture. " \
        "Make sure all cudagraph capture sizes <= max_num_seqs."

    m.max_query_len = 1  # decode-only

    return self.build(0, m)
def can_run_in_cudagraph(
        self, common_attn_metadata: CommonAttentionMetadata) -> bool:
    """Tell whether this batch may run inside a captured cudagraph.

    The batch qualifies only when its maximum query length is exactly 1.
    """
    longest_query = common_attn_metadata.max_query_len
    return longest_query == 1
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
if
self
.
kv_cache_spec
.
dtype
!=
self
.
vllm_config
.
model_config
.
dtype
:
# TODO: The cascade wrapper currently does not support setting
...
...
@@ -638,3 +767,163 @@ class FlashInferImpl(AttentionImpl):
out
=
output
[:
num_decode_tokens
],
)
return
output_padded
def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
) -> None:
    """Plan a FlashInfer decode wrapper with reduced copy overhead.

    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay. When the wrapper is not cudagraph-enabled, or
    on the first call for a given wrapper, this falls back to the wrapper's
    own ``plan``.

    Modifications for cudagraph:
    - only host-to-device copy of the indptr and last_page_len buffers.
    - avoid device-to-device copy of the indices buffer.

    Parts of the code are inspired by the original plan in the FlashInfer
    repo and by the implementation of fast_decode_plan for FlashInfer in the
    SGLang repo.

    Args:
        self: The decode wrapper being planned (this is a free function, so
            the wrapper is passed explicitly).
        indptr_cpu: Host-side paged-KV indptr; copied into the wrapper's
            ``_paged_kv_indptr_buf``.
        indices: Paged-KV page indices. On the fast path only its length is
            checked against the wrapper's indices buffer; it is not copied.
        last_page_len_cpu: Host-side last-page lengths; its length defines
            the batch size and it is copied into
            ``_paged_kv_last_page_len_buf``.
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        head_dim: Head dimension (passed for both head-dim plan arguments).
        page_size: KV-cache page size.
        pos_encoding_mode: Stored on the wrapper as ``_pos_encoding_mode``.
        window_left: Stored on the wrapper and passed to the standard plan.
        logits_soft_cap: ``None`` is treated as ``0.0`` on the fast path.
        q_data_type: Query dtype, as a ``torch`` attribute name or dtype.
        kv_data_type: KV dtype; defaults to the resolved ``q_data_type``.
        data_type: Fallback for whichever of ``q_data_type`` /
            ``kv_data_type`` is ``None``.
        sm_scale: Stored on the wrapper as ``_sm_scale``.
        rope_scale: Stored on the wrapper as ``_rope_scale``.
        rope_theta: Stored on the wrapper as ``_rope_theta``.
        non_blocking: Forwarded to the original ``plan`` on the fallback
            path and to the host-to-device ``copy_`` calls on the fast path.

    Raises:
        ValueError: If the batch size differs from the wrapper's fixed batch
            size, or ``indices`` is longer than the allocated indices buffer.
        RuntimeError: If the underlying cached-module ``plan`` call fails.
    """
    # Warm up with the original plan if it is first call, and always run the
    # original plan if we run for dynamic shape. For fixed shape (cudagraph),
    # this warm up is to generate the _cached_module for the decode wrapper.
    if not self.is_cuda_graph_enabled or \
            getattr(self, "vllm_first_call", True):
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
    q_data_type = getattr(torch, q_data_type) if isinstance(
        q_data_type, str) else q_data_type
    kv_data_type = getattr(torch, kv_data_type) if isinstance(
        kv_data_type, str) else kv_data_type

    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
            "initialization {}".format(batch_size, self._fixed_batch_size))
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
            "The size of indices should be less than or equal to the "
            "allocated buffer")

    # host-to-device copy for the indptr buffer; honour the caller's
    # `non_blocking` choice, as the fallback path above already does.
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=non_blocking)
    # host-to-device copy for the last_page_len buffer
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                           non_blocking=non_blocking)

    indptr_host = indptr_cpu
    last_page_len_host = last_page_len_cpu

    if self.use_tensor_cores:
        # Only the tensor-core plan needs the query indptr and KV lengths.
        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
                                        page_size)

        try:
            # Make sure we pass exactly 15 arguments for tensor core version
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                qo_indptr_host,
                indptr_host,
                kv_lens_arr_host,
                batch_size,  # total_num_rows
                batch_size,
                num_qo_heads,
                num_kv_heads,
                page_size,
                self.is_cuda_graph_enabled,
                head_dim,
                head_dim,
                False,  # causal
            )
        except Exception as e:
            raise RuntimeError(f"Error in tensor core plan: {e}") from e
    else:
        try:
            # Make sure we pass exactly 15 arguments for standard version
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
                self._pin_memory_int_workspace_buffer,
                indptr_host,
                batch_size,
                num_qo_heads,
                num_kv_heads,
                page_size,
                self.is_cuda_graph_enabled,
                window_left,
                logits_soft_cap,
                head_dim,
                head_dim,
                torch.empty(0, dtype=q_data_type),
                torch.empty(0, dtype=kv_data_type),
            )
        except Exception as e:
            raise RuntimeError(f"Error in standard plan: {e}") from e

    # Record the attention parameters on the wrapper.
    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
vllm/v1/attention/backends/mla/flashmla.py
View file @
23322431
...
...
@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
...
...
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
# Decode-only
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
PURE_DECODE_ONLY
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
23322431
...
...
@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
# yapf: enable
...
...
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
# decode only
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
PURE_DECODE_ONLY
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
23322431
...
...
@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
ALWAYS
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
):
...
...
vllm/v1/attention/backends/utils.py
View file @
23322431
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
abc
import
enum
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
,
make_dataclass
...
...
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M
=
TypeVar
(
"M"
)
class AttentionCGSupport(enum.Enum):
    """Levels of cudagraph support an attention backend can declare.

    Cascade attention is not considered here, as currently it is never
    cudagraph supported.
    """

    # No cudagraph support.
    NEVER = 0

    # Cudagraph supported for pure decode; mixed prefill-decode batches
    # need to run without cudagraph.
    PURE_DECODE_ONLY = 1

    # Cudagraph always supported.
    ALWAYS = 2
class
AttentionMetadataBuilder
(
abc
.
ABC
,
Generic
[
M
]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported
:
ClassVar
[
bool
]
=
False
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
NEVER
@
abstractmethod
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
23322431
...
...
@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available
,
round_up
,
supports_dynamo
)
from
vllm.v1.attention.backends.mamba_selectors
import
get_mamba_attn_backend
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_kv_sharing_fast_prefill_attention_metadata
,
make_local_attention_virtual_batches
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
...
...
@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
device
,
)
if
(
self
.
full_cuda_graph
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
raise
ValueError
(
f
"Full CUDAGraph not supported for "
f
"
{
attn_backend_i
.
__name__
}
. Turn off CompilationConfig."
f
"full_cuda_graph or use a different attention backend."
)
if
self
.
full_cuda_graph
:
if
attn_metadata_builder_i
.
attn_cudagraph_support
==
\
AttentionCGSupport
.
NEVER
:
raise
ValueError
(
f
"Full CUDAGraph not supported for "
f
"
{
attn_backend_i
.
__name__
}
. Turn off "
f
"CompilationConfig.full_cuda_graph or use a "
f
" different attention backend."
)
if
attn_metadata_builder_i
.
attn_cudagraph_support
==
\
AttentionCGSupport
.
PURE_DECODE_ONLY
:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self
.
cudagraph_batch_sizes
=
[
size
for
size
in
self
.
cudagraph_batch_sizes
if
size
<=
self
.
scheduler_config
.
max_num_seqs
]
return
attn_backend_i
,
attn_metadata_builder_i
def
initialize_attn_backend
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/gpu_worker.py
View file @
23322431
...
...
@@ -321,11 +321,16 @@ class Worker(WorkerBase):
if
get_pp_group
().
is_last_rank
:
max_num_reqs
=
min
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph replay.
attn_cudagraph
=
self
.
compilation_config
.
full_cuda_graph
and
\
not
self
.
model_config
.
enforce_eager
# We skip EPLB here since we don't want to record dummy metrics
hidden_states
,
last_hidden_states
=
\
self
.
model_runner
.
_dummy_run
(
num_tokens
=
max_num_reqs
,
capture_attn_cudagraph
=
attn_cudagraph
,
skip_eplb
=
True
,
)
if
self
.
model_runner
.
is_pooling_model
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment