Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59f93530
Unverified
Commit
59f93530
authored
Jul 19, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jul 19, 2025
Browse files
[BugFix] Fix potential cuda-graph IMA (#21196)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
18e519ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+0
-5
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-1
No files found.
vllm/v1/attention/backends/utils.py
View file @
59f93530
...
...
@@ -59,11 +59,6 @@ class CommonAttentionMetadata:
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
def
__post_init__
(
self
):
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
self
.
slot_mapping
[
self
.
num_actual_tokens
:].
fill_
(
-
1
)
M
=
TypeVar
(
"M"
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
59f93530
...
...
@@ -684,7 +684,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
# Fill unused with
-1. Needed for reshape_and_cache
# Fill unused with
0 for full cuda graph mode.
self
.
seq_lens
[
num_reqs
:].
fill_
(
0
)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
...
...
@@ -704,6 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
blk_table
=
self
.
input_batch
.
block_table
[
kv_cache_group_id
]
blk_table_tensor
=
blk_table
.
get_device_tensor
()[:
num_reqs
]
slot_mapping
=
blk_table
.
slot_mapping
[:
total_num_scheduled_tokens
]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table
.
slot_mapping
[
total_num_scheduled_tokens
:].
fill_
(
-
1
)
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment