Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8da1c576
Commit
8da1c576
authored
Jul 24, 2025
by
zhuwenwen
Browse files
support full_cuda_graph
parent
f9408aff
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
4 deletions
+5
-4
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+3
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-1
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
8da1c576
...
@@ -183,7 +183,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -183,7 +183,7 @@ class FlashAttentionMetadataBuilder(
self
.
max_num_splits
=
0
# No upper bound on the number of splits.
self
.
max_num_splits
=
0
# No upper bound on the number of splits.
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
self
.
use_full_cuda_graph
=
compilation_config
.
full_cuda_graph
if
self
.
use_full_cuda_graph
:
if
not
current_platform
.
is_rocm
()
and
self
.
use_full_cuda_graph
:
if
not
self
.
aot_schedule
:
if
not
self
.
aot_schedule
:
raise
ValueError
(
raise
ValueError
(
"AoT scheduling is required for full cuda graph."
)
"AoT scheduling is required for full cuda graph."
)
...
@@ -361,7 +361,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -361,7 +361,7 @@ class FlashAttentionMetadataBuilder(
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
causal
=
True
)
causal
=
True
)
if
self
.
use_full_cuda_graph
:
if
not
current_platform
.
is_rocm
()
and
self
.
use_full_cuda_graph
:
assert
scheduler_metadata
is
not
None
assert
scheduler_metadata
is
not
None
n
=
scheduler_metadata
.
shape
[
0
]
n
=
scheduler_metadata
.
shape
[
0
]
self
.
scheduler_metadata
[:
n
]
=
scheduler_metadata
self
.
scheduler_metadata
[:
n
]
=
scheduler_metadata
...
@@ -373,7 +373,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -373,7 +373,7 @@ class FlashAttentionMetadataBuilder(
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
scheduler_metadata
=
self
.
scheduler_metadata
[:
n
]
max_num_splits
=
0
max_num_splits
=
0
if
(
self
.
use_full_cuda_graph
if
(
not
current_platform
.
is_rocm
()
and
self
.
use_full_cuda_graph
and
num_actual_tokens
<=
self
.
max_cudagraph_size
):
and
num_actual_tokens
<=
self
.
max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# usage, because the intermediate buffers of size [num_splits,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
8da1c576
...
@@ -67,6 +67,7 @@ from vllm.v1.utils import bind_kv_cache
...
@@ -67,6 +67,7 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
from
..sample.logits_processor
import
LogitsProcessorManager
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
...
@@ -2371,7 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2371,7 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_table_i
,
block_table_i
,
)
)
if
(
self
.
full_cuda_graph
if
(
not
current_platform
.
is_rocm
()
and
self
.
full_cuda_graph
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
and
not
attn_metadata_builder_i
.
full_cudagraph_supported
):
raise
ValueError
(
raise
ValueError
(
f
"Full CUDAGraph not supported for "
f
"Full CUDAGraph not supported for "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment