Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64e39d66
Unverified
Commit
64e39d66
authored
Nov 17, 2025
by
Lucas Wilkinson
Committed by
GitHub
Nov 17, 2025
Browse files
[BugFix] Temporary fix for IMA with MTP = 2 and full-cg (#28315)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
1b82fb0a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
13 deletions
+80
-13
vllm/config/compilation.py
vllm/config/compilation.py
+64
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+16
-0
No files found.
vllm/config/compilation.py
View file @
64e39d66
...
...
@@ -18,6 +18,7 @@ from vllm.config.utils import config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
if
TYPE_CHECKING
:
...
...
@@ -773,19 +774,8 @@ class CompilationConfig:
if
self
.
cudagraph_capture_sizes
:
assert
self
.
cudagraph_capture_sizes
[
-
1
]
==
self
.
max_cudagraph_capture_size
# pre-compute the mapping from batch size to padded graph size
self
.
bs_to_padded_graph_size
=
[
0
for
i
in
range
(
self
.
max_cudagraph_capture_size
+
1
)
]
for
end
,
start
in
zip
(
self
.
cudagraph_capture_sizes
+
[
self
.
max_cudagraph_capture_size
+
1
],
[
0
]
+
self
.
cudagraph_capture_sizes
,
):
for
bs
in
range
(
start
,
end
):
if
bs
==
start
:
self
.
bs_to_padded_graph_size
[
bs
]
=
start
else
:
self
.
bs_to_padded_graph_size
[
bs
]
=
end
# May get recomputed in the model runner if adjustment is needed for spec-decode
self
.
compute_bs_to_padded_graph_size
()
def
set_splitting_ops_for_v1
(
self
):
# NOTE: this function needs to be called only when mode is
...
...
@@ -922,3 +912,64 @@ class CompilationConfig:
enable_str
,
op
,
)
def adjust_cudagraph_sizes_for_spec_decode(
    self, uniform_decode_query_len: int, tensor_parallel_size: int
):
    """Round the cudagraph capture sizes up to a common granularity.

    The granularity starts as ``uniform_decode_query_len``. When
    ``tensor_parallel_size`` is greater than 1 it becomes the larger of the
    two values, and that larger value must be divisible by both of them.

    Every entry of ``self.cudagraph_capture_sizes`` is rounded up to the
    granularity; results above ``self.max_cudagraph_capture_size`` are
    dropped and duplicates are merged. If anything survives, the capture
    sizes and the maximum capture size are replaced and the batch-size
    padding table is rebuilt. Otherwise a warning is logged and the
    configuration is left as it was.

    Args:
        uniform_decode_query_len: Query length of a uniform decode batch
            (the error text describes it as ``num_speculative_tokens + 1``).
        tensor_parallel_size: Tensor-parallel world size (the error text
            ties this constraint to sequence parallelism).

    Raises:
        ValueError: If ``tensor_parallel_size > 1`` and the larger of the
            two arguments is not a multiple of the smaller one.
    """
    granularity = uniform_decode_query_len
    if tensor_parallel_size > 1:
        granularity = max(uniform_decode_query_len, tensor_parallel_size)
        fits_both = (
            granularity % uniform_decode_query_len == 0
            and granularity % tensor_parallel_size == 0
        )
        if not fits_both:
            raise ValueError(
                "Can't determine cudagraph shapes that are both a "
                f"multiple of {uniform_decode_query_len} "
                "(num_speculative_tokens + 1) required by spec-decode "
                f"and {tensor_parallel_size} (tensor_parallel_size) "
                "required by sequence parallelism please adjust "
                "num_speculative_tokens or disable sequence parallelism"
            )

    # Nothing to adjust: no capture sizes configured, or every size already
    # is a multiple of the (trivial) granularity.
    if not self.cudagraph_capture_sizes or granularity <= 1:
        return

    assert self.max_cudagraph_capture_size is not None
    size_limit = self.max_cudagraph_capture_size

    # Collect the rounded sizes that still fit under the current maximum;
    # the set merges sizes that round to the same value.
    kept = set()
    for capture_size in self.cudagraph_capture_sizes:
        snapped = round_up(capture_size, granularity)
        if snapped <= size_limit:
            kept.add(snapped)
    rounded_sizes = sorted(kept)

    if not rounded_sizes:
        logger.warning(
            "No valid cudagraph sizes after rounding to multiple of "
            " num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
            " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
            granularity,
        )
        return

    self.max_cudagraph_capture_size = rounded_sizes[-1]
    self.cudagraph_capture_sizes = rounded_sizes

    # The padding table depends on both attributes updated above.
    self.compute_bs_to_padded_graph_size()
def compute_bs_to_padded_graph_size(self):
    """Rebuild ``self.bs_to_padded_graph_size`` from the capture sizes.

    The table has ``self.max_cudagraph_capture_size + 1`` entries, indexed
    by batch size. Consecutive capture sizes (with 0 prepended and
    ``max_cudagraph_capture_size + 1`` appended) delimit half-open ranges
    ``[lower, upper)``: the entry at ``lower`` is ``lower`` itself, and
    every other entry in the range is ``upper``.

    The caller is expected to keep ``self.cudagraph_capture_sizes`` sorted
    ascending with its last element equal to
    ``self.max_cudagraph_capture_size``.
    """
    table_len = self.max_cudagraph_capture_size + 1
    # Publish the table first and fill it in place afterwards.
    padded = self.bs_to_padded_graph_size = [0] * table_len

    upper_bounds = self.cudagraph_capture_sizes + [table_len]
    lower_bounds = [0] + self.cudagraph_capture_sizes
    for lower, upper in zip(lower_bounds, upper_bounds):
        if lower < upper:
            # A batch size that is exactly a boundary needs no padding.
            padded[lower] = lower
        for batch_size in range(lower + 1, upper):
            # Everything strictly inside the range pads up to its upper end.
            padded[batch_size] = upper
vllm/v1/worker/gpu_model_runner.py
View file @
64e39d66
...
...
@@ -4332,6 +4332,22 @@ class GPUModelRunner(
"and make sure compilation mode is VLLM_COMPILE"
)
# if we have dedicated decode cudagraphs, and spec-decode is enabled,
# we need to adjust the cudagraph sizes to be a multiple of the uniform
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
# Will be removed in the near future when we have separate cudagraph capture
# sizes for decode and mixed prefill-decode.
if
(
cudagraph_mode
.
decode_mode
()
==
CUDAGraphMode
.
FULL
and
cudagraph_mode
.
separate_routine
()
and
self
.
uniform_decode_query_len
>
1
):
self
.
compilation_config
.
adjust_cudagraph_sizes_for_spec_decode
(
self
.
uniform_decode_query_len
,
self
.
parallel_config
.
tensor_parallel_size
)
self
.
cudagraph_batch_sizes
=
self
.
compilation_config
.
cudagraph_capture_sizes
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment