Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7d6b0338
Unverified
Commit
7d6b0338
authored
Oct 04, 2025
by
Huamin Li
Committed by
GitHub
Oct 04, 2025
Browse files
[CI Failure] fix_test_auto_prefix_cache_support (#26053)
Signed-off-by:
Huamin Li
<
3ericli@gmail.com
>
parent
7c2e91c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
7 deletions
+14
-7
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+2
-2
vllm/config/vllm.py
vllm/config/vllm.py
+12
-5
No files found.
tests/v1/core/test_scheduler.py
View file @
7d6b0338
...
@@ -1917,7 +1917,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
...
@@ -1917,7 +1917,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
def
test_chunked_prefill_disabled_for_encoder_decoder
(
def
test_chunked_prefill_disabled_for_encoder_decoder
(
enable_chunked_prefill
:
bool
,
is_encoder_decoder
:
bool
,
enable_chunked_prefill
:
bool
,
is_encoder_decoder
:
bool
,
expect_enabled
:
bool
)
->
None
:
expect_enabled
:
bool
)
->
None
:
"""Validate that chunked prefill is appropriately disabled for
"""Validate that chunked prefill is appropriately disabled for
encoder-decoder models."""
encoder-decoder models."""
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
enable_chunked_prefill
=
enable_chunked_prefill
,
enable_chunked_prefill
=
enable_chunked_prefill
,
...
@@ -1942,7 +1942,7 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
...
@@ -1942,7 +1942,7 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
def
_validate_chunked_prefill_settings_for_encoder_decoder
(
def
_validate_chunked_prefill_settings_for_encoder_decoder
(
scheduler_config
:
SchedulerConfig
,
is_encoder_decoder
:
bool
,
scheduler_config
:
SchedulerConfig
,
is_encoder_decoder
:
bool
,
expect_enabled
:
bool
)
->
None
:
expect_enabled
:
bool
)
->
None
:
"""Validate chunked prefill settings in the scheduler config for
"""Validate chunked prefill settings in the scheduler config for
encoder-decoder models."""
encoder-decoder models."""
assert
scheduler_config
.
chunked_prefill_enabled
is
expect_enabled
assert
scheduler_config
.
chunked_prefill_enabled
is
expect_enabled
assert
scheduler_config
.
enable_chunked_prefill
is
expect_enabled
assert
scheduler_config
.
enable_chunked_prefill
is
expect_enabled
...
...
vllm/config/vllm.py
View file @
7d6b0338
...
@@ -396,10 +396,17 @@ class VllmConfig:
...
@@ -396,10 +396,17 @@ class VllmConfig:
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
"to 'spawn'."
)
# Disable prefix caching only if chunked prefill is explicitly disabled
# Final off-switch for CP/APC:
# (and not merely unset)
# Disable for (a) collected blockers, (b) encoder–decoder, or
if
(
self
.
scheduler_config
.
chunked_prefill_enabled
is
False
# (c) explicit CP=False when APC wasn't requested.
or
disable_chunked_prefill_reasons
):
# Do NOT disable merely because the resolved CP flag is False.
apc_requested
=
(
self
.
cache_config
is
not
None
and
self
.
cache_config
.
enable_prefix_caching
)
if
(
disable_chunked_prefill_reasons
or
(
self
.
model_config
is
not
None
and
self
.
model_config
.
is_encoder_decoder
)
or
(
self
.
scheduler_config
.
enable_chunked_prefill
is
False
and
not
apc_requested
)):
for
reason
in
disable_chunked_prefill_reasons
:
for
reason
in
disable_chunked_prefill_reasons
:
logger
.
info
(
reason
)
logger
.
info
(
reason
)
self
.
scheduler_config
.
chunked_prefill_enabled
=
False
self
.
scheduler_config
.
chunked_prefill_enabled
=
False
...
@@ -668,7 +675,7 @@ class VllmConfig:
...
@@ -668,7 +675,7 @@ class VllmConfig:
f
"Model:
{
self
.
model_config
.
model
}
"
)
f
"Model:
{
self
.
model_config
.
model
}
"
)
def
compile_debug_dump_path
(
self
)
->
Optional
[
Path
]:
def
compile_debug_dump_path
(
self
)
->
Optional
[
Path
]:
"""Returns a rank-aware path for dumping
"""Returns a rank-aware path for dumping
torch.compile debug information.
torch.compile debug information.
"""
"""
if
self
.
compilation_config
.
debug_dump_path
is
None
:
if
self
.
compilation_config
.
debug_dump_path
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment