Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
61f67d8a
Unverified
Commit
61f67d8a
authored
Aug 10, 2025
by
Thomas Parnell
Committed by
GitHub
Aug 09, 2025
Browse files
[V1] [Hybrid] Enable Full CUDA Graph (decode-only) for Mamba layers (#21401)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
42172ad1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
103 additions
and
1 deletion
+103
-1
tests/models/language/generation/test_hybrid.py
tests/models/language/generation/test_hybrid.py
+60
-0
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+43
-1
No files found.
tests/models/language/generation/test_hybrid.py
View file @
61f67d8a
...
@@ -384,3 +384,63 @@ def test_distributed_correctness(
...
@@ -384,3 +384,63 @@ def test_distributed_correctness(
name_0
=
"vllm_tp_1"
,
name_0
=
"vllm_tp_1"
,
name_1
=
"vllm_tp_2"
,
name_1
=
"vllm_tp_2"
,
)
)
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that a V1 run with full CUDA graphs matches the reference.

    Three sets of greedy logprobs are generated for ``example_prompts``:

    1. HF outputs (skipped when ``model`` is in ``HF_UNSUPPORTED_MODELS``),
    2. a vLLM run with the default environment (labelled ``"vllm-v0"``),
    3. a vLLM run with ``VLLM_USE_V1=1`` and
       ``compilation_config={'full_cuda_graph': True}``.

    The HF outputs, when available, are compared against run 2 and then used
    as the reference for run 3; otherwise run 2 is the reference for run 3.
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        # Deliberately best-effort: a ValueError from the registry lookup or
        # checks must not stop the test from running.
        pass

    # Every generate call below receives the same positional arguments.
    generation_args = (example_prompts, max_tokens, num_logprobs)

    hf_outputs = None
    with hf_runner(model) as hf_model:
        if model not in HF_UNSUPPORTED_MODELS:
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                *generation_args)

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
            *generation_args)

    with monkeypatch.context() as env:
        env.setenv("VLLM_USE_V1", "1")
        if model in HYBRID_MODELS:
            # required due to reorder_batch behaviour
            env.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

        full_cuda_graph_runner = vllm_runner(
            model,
            max_num_seqs=MAX_NUM_SEQS,
            compilation_config={'full_cuda_graph': True},
            enable_prefix_caching=False,
        )
        with full_cuda_graph_runner as vllm_model:
            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                *generation_args)

    if hf_outputs is None:
        # No HF reference available: fall back to the first vLLM run.
        ref_outputs, ref_name = vllm_v0_outputs, "vllm-v0"
    else:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_v0_outputs,
            name_0="hf",
            name_1="vllm-v0",
        )
        ref_outputs, ref_name = hf_outputs, "hf"

    check_logprobs_close(
        outputs_0_lst=ref_outputs,
        outputs_1_lst=vllm_v1_outputs,
        name_0=ref_name,
        name_1="vllm-v1",
    )
vllm/v1/attention/backends/mamba_attn.py
View file @
61f67d8a
...
@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
...
@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
split_decodes_and_prefills
)
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
MambaSpec
...
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
...
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
class
Mamba2AttentionMetadataBuilder
(
class
Mamba2AttentionMetadataBuilder
(
AttentionMetadataBuilder
[
Mamba2AttentionMetadata
]):
AttentionMetadataBuilder
[
Mamba2AttentionMetadata
]):
attn_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
\
AttentionCGSupport
.
PURE_DECODE_ONLY
reorder_batch_threshold
:
ClassVar
[
int
]
=
1
reorder_batch_threshold
:
ClassVar
[
int
]
=
1
...
@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
self
.
chunk_size
=
vllm_config
.
model_config
.
get_mamba_chunk_size
()
self
.
chunk_size
=
vllm_config
.
model_config
.
get_mamba_chunk_size
()
self
.
vllm_config
=
vllm_config
self
.
compilation_config
=
vllm_config
.
compilation_config
assert
self
.
chunk_size
is
not
None
,
(
assert
self
.
chunk_size
is
not
None
,
(
"chunk_size needs to be set in the model config for Mamba2 models"
)
"chunk_size needs to be set in the model config for Mamba2 models"
)
self
.
decode_cudagraph_max_bs
=
min
(
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
,
self
.
compilation_config
.
max_capture_size
)
self
.
state_indices_tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,
),
dtype
=
torch
.
int32
,
device
=
device
,
)
def
build
(
self
,
def
build
(
self
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
...
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p
,
self
.
chunk_size
,
query_start_loc_p
,
self
.
chunk_size
,
num_prefill_tokens
))
num_prefill_tokens
))
elif
num_decodes
<=
self
.
decode_cudagraph_max_bs
:
# Pad state tensor for CUDA graph
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_decodes
)
self
.
state_indices_tensor
[:
num_decodes
].
copy_
(
state_indices_tensor
,
non_blocking
=
True
)
state_indices_tensor
=
self
.
state_indices_tensor
[:
num_input_tokens
]
state_indices_tensor
[
num_decodes
:]
=
PAD_SLOT_ID
attn_metadata
=
Mamba2AttentionMetadata
(
attn_metadata
=
Mamba2AttentionMetadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
...
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor
=
state_indices_tensor
,
state_indices_tensor
=
state_indices_tensor
,
)
)
return
attn_metadata
return
attn_metadata
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata):
    """
    Build the attention metadata used while capturing a full CUDA graph.

    Currently, only decode is supported for full cudagraphs with Mamba, so
    the batch must hold exactly one token per request.  ``max_query_len`` on
    ``common_attn_metadata`` is overwritten with 1 (in place) before the call
    is forwarded to ``self.build`` with a common prefix length of 0.

    Raises:
        AssertionError: if ``num_reqs`` differs from ``num_actual_tokens``.
    """
    metadata = common_attn_metadata
    is_decode_only = metadata.num_reqs == metadata.num_actual_tokens
    assert is_decode_only, (
        "Mamba only supports decode-only full CUDAGraph capture. "
        "Make sure all cudagraph capture sizes <= max_num_seq.")

    # decode-only
    metadata.max_query_len = 1
    return self.build(0, metadata)
def can_run_in_cudagraph(
        self, common_attn_metadata: CommonAttentionMetadata) -> bool:
    """Return True when the batch's ``max_query_len`` is exactly 1."""
    longest_query = common_attn_metadata.max_query_len
    return longest_query == 1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment