Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8b141ed8
Unverified
Commit
8b141ed8
authored
Apr 02, 2026
by
shunting314
Committed by
GitHub
Apr 02, 2026
Browse files
full cudagraph for flex-attn (#36298)
Signed-off-by:
shunting314
<
shunting@meta.com
>
parent
2ad7c033
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
145 additions
and
11 deletions
+145
-11
tests/compile/fullgraph/test_full_cudagraph.py
tests/compile/fullgraph/test_full_cudagraph.py
+0
-11
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+53
-0
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+91
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-0
No files found.
tests/compile/fullgraph/test_full_cudagraph.py
View file @
8b141ed8
...
...
@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
piecewise_res
.
outputs
[
0
].
text
.
lower
()
==
full_res
.
outputs
[
0
].
text
.
lower
()
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
def
test_full_cudagraph_with_invalid_backend
():
# Flex_Attention is not supported with full cuda graph
with
pytest
.
raises
(
RuntimeError
):
LLM
(
model
=
"Qwen/Qwen2-1.5B-Instruct"
,
compilation_config
=
CompilationConfig
(
cudagraph_mode
=
"FULL"
),
attention_config
=
{
"backend"
:
"FLEX_ATTENTION"
},
)
tests/kernels/test_flex_attention.py
View file @
8b141ed8
...
...
@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION
=
version
.
parse
(
"2.9.dev0"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
)
def
test_flex_attention_full_cudagraphs
(
vllm_runner
):
"""Test the numerics for flex attention full cudagraphs support."""
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
seed
=
42
max_tokens
=
24
num_logprobs
=
5
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
]
# Run with flex attention eager
set_random_seed
(
seed
)
with
vllm_runner
(
model_name
,
runner
=
"generate"
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
True
,
attention_config
=
{
"backend"
:
"FLEX_ATTENTION"
},
)
as
llm_flex
:
output_eager
=
llm_flex
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
)
# Run with flex attention compiled
set_random_seed
(
seed
)
with
vllm_runner
(
model_name
,
runner
=
"generate"
,
tensor_parallel_size
=
1
,
num_gpu_blocks_override
=
128
,
enforce_eager
=
False
,
gpu_memory_utilization
=
0.85
,
attention_config
=
{
"backend"
:
"FLEX_ATTENTION"
},
)
as
llm_default
:
output_compile
=
llm_default
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
output_eager
,
outputs_1_lst
=
output_compile
,
name_0
=
"eager"
,
name_1
=
"compile"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
TORCH_VERSION
<
MINIMUM_TORCH_VERSION
,
reason
=
"CUDA not available or PyTorch version < 2.7"
,
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
8b141ed8
...
...
@@ -30,6 +30,7 @@ from vllm.utils.math_utils import cdiv
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
,
is_torch_equal_or_newer
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
AttentionImpl
,
AttentionMetadataBuilder
,
AttentionType
,
...
...
@@ -315,6 +316,18 @@ class BlockSparsityHint(NamedTuple):
hint_fn
:
_block_sparsity_hint_signature
def
copy_to_persistent
(
dst
,
src
):
try
:
dst
=
dst
.
as_strided
(
src
.
shape
,
src
.
stride
())
except
RuntimeError
as
e
:
raise
RuntimeError
(
f
"Fail to re-stride a persistent tensor of shape
{
dst
.
shape
}
"
f
"for a tensor of shape
{
src
.
shape
}
"
)
from
e
dst
.
copy_
(
src
)
return
dst
@
dataclass
class
FlexAttentionMetadata
:
causal
:
bool
...
...
@@ -340,6 +353,9 @@ class FlexAttentionMetadata:
physical_to_logical
:
torch
.
Tensor
decode_offset
:
torch
.
Tensor
num_blocks_per_seq
:
torch
.
Tensor
persistent_kv_indices
:
torch
.
Tensor
persistent_kv_num_blocks
:
torch
.
Tensor
persistent_doc_ids
:
torch
.
Tensor
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
...
...
@@ -656,8 +672,11 @@ class FlexAttentionMetadata:
kv_indices
=
unique_static_unsorted
(
(
used_pages_padded
.
long
()),
M
=
self
.
num_blocks
).
to
(
torch
.
int32
)
kv_indices
=
copy_to_persistent
(
self
.
persistent_kv_indices
,
kv_indices
)
kv_num_blocks
=
(
kv_indices
>=
0
).
sum
(
dim
=-
1
).
to
(
torch
.
int32
)
kv_num_blocks
=
copy_to_persistent
(
self
.
persistent_kv_num_blocks
,
kv_num_blocks
)
block_mask_kwargs
=
{
"seq_lengths"
:
(
self
.
num_actual_tokens
,
self
.
total_cache_tokens
),
"kv_num_blocks"
:
kv_num_blocks
[
None
,
None
],
...
...
@@ -694,6 +713,7 @@ class FlexAttentionMetadata:
assert
self
.
suffix_kv_lens
is
None
,
"Not implemented yet."
# Create a lookup mapping from query indices -> request number
self
.
doc_ids
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
self
.
doc_ids
=
copy_to_persistent
(
self
.
persistent_doc_ids
,
self
.
doc_ids
)
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
self
.
mask_mod
=
self
.
get_mask_mod
()
...
...
@@ -701,6 +721,8 @@ class FlexAttentionMetadata:
class
FlexAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlexAttentionMetadata
]):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
ALWAYS
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
...
...
@@ -726,6 +748,38 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self
.
q_block_size
:
int
=
16
if
supports_small_blocks
else
128
self
.
kv_block_size
:
int
=
self
.
block_size
if
supports_small_blocks
else
128
self
.
max_model_len
=
self
.
model_config
.
max_model_len
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_batched_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
max_num_q_block
=
(
self
.
max_model_len
+
self
.
q_block_size
-
1
)
//
self
.
q_block_size
self
.
persistent_kv_num_blocks
=
torch
.
empty
(
self
.
max_num_q_block
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
persistent_offset_tensor
=
torch
.
empty
(
max_num_seqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
persistent_doc_ids
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
# initialize later when we can access block_table
self
.
persistent_physical_to_logical
=
None
self
.
persistent_kv_indices
=
None
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
FlexAttentionMetadata
:
# Use actual max_seq_len instead of max_model_len to avoid
# torch.compile recompilation during CUDA graph capture.
common_attn_metadata
.
max_seq_len
=
(
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
()
)
return
self
.
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
def
build
(
self
,
common_prefix_len
:
int
,
...
...
@@ -765,8 +819,32 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
inverse_block_table
=
physical_to_logical_mapping
(
block_table_tensor
,
seq_lens
,
block_size
,
num_gpu_blocks
)
if
self
.
persistent_physical_to_logical
is
None
:
max_num_seqs
=
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
persistent_physical_to_logical
=
torch
.
empty
(
max_num_seqs
,
num_gpu_blocks
,
dtype
=
torch
.
long
,
device
=
self
.
device
,
)
if
self
.
persistent_kv_indices
is
None
:
max_num_kv_block
=
(
self
.
max_model_len
+
self
.
kv_block_size
-
1
)
//
self
.
kv_block_size
self
.
persistent_kv_indices
=
torch
.
empty
(
self
.
max_model_len
,
max_num_kv_block
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
inverse_block_table
=
copy_to_persistent
(
self
.
persistent_physical_to_logical
,
inverse_block_table
)
offset_tensor
=
common_attn_metadata
.
compute_num_computed_tokens
()
offset_tensor
=
copy_to_persistent
(
self
.
persistent_offset_tensor
,
offset_tensor
)
out
=
FlexAttentionMetadata
(
causal
=
common_attn_metadata
.
causal
,
...
...
@@ -795,7 +873,20 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
direct_build
=
(
self
.
direct_build
and
common_attn_metadata
.
causal
),
q_block_size
=
self
.
q_block_size
,
kv_block_size
=
self
.
kv_block_size
,
persistent_kv_indices
=
self
.
persistent_kv_indices
,
persistent_kv_num_blocks
=
self
.
persistent_kv_num_blocks
,
persistent_doc_ids
=
self
.
persistent_doc_ids
,
)
# Pre-build block_mask so it is ready before CUDA graph capture.
# Without this, the lazy build in forward() would run non-graph-safe
# ops (e.g. torch.nonzero) inside capture.
if
out
.
block_mask
is
None
:
if
out
.
direct_build
:
out
.
block_mask
=
out
.
_build_block_mask_direct
()
else
:
out
.
block_mask
=
out
.
build_block_mask
()
return
out
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
8b141ed8
...
...
@@ -6077,6 +6077,7 @@ class GPUModelRunner(
skip_eplb
=
True
,
remove_lora
=
False
,
num_active_loras
=
desc
.
num_active_loras
,
profile_seq_lens
=
profile_seq_lens
,
)
self
.
_dummy_run
(
desc
.
num_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment