Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e5314a4
Unverified
Commit
8e5314a4
authored
Apr 08, 2025
by
Michael Goin
Committed by
GitHub
Apr 07, 2025
Browse files
[V1] Add `disable_chunked_mm_input` arg to disable partial mm input prefill (#15837)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
87918e40
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
0 deletions
+80
-0
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+45
-0
vllm/config.py
vllm/config.py
+8
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+16
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+11
-0
No files found.
tests/v1/core/test_scheduler.py
View file @
8e5314a4
...
...
@@ -24,6 +24,7 @@ def create_scheduler(
max_num_batched_tokens
:
int
=
8192
,
enable_prefix_caching
:
Optional
[
bool
]
=
None
,
long_prefill_token_threshold
:
int
=
0
,
disable_chunked_mm_input
:
bool
=
False
,
)
->
Scheduler
:
'''Create scheduler under test.
...
...
@@ -43,6 +44,7 @@ def create_scheduler(
max_num_batched_tokens
=
max_num_batched_tokens
,
max_model_len
=
max_num_batched_tokens
,
long_prefill_token_threshold
=
long_prefill_token_threshold
,
disable_chunked_mm_input
=
disable_chunked_mm_input
,
)
model_config
=
ModelConfig
(
model
=
model
,
...
...
@@ -278,6 +280,49 @@ def test_schedule_partial_requests():
assert
requests
[
2
].
request_id
not
in
output
.
num_scheduled_tokens
def test_no_mm_input_chunking():
    """Check that a multimodal item is not split when chunking is disabled.

    A single 1200-token request carries one multimodal placeholder covering
    800 tokens starting at offset 400, and the scheduler is limited to 1024
    batched tokens with ``disable_chunked_mm_input=True``. The first step is
    expected to schedule only the 400 leading text tokens, and the second
    step the whole 800-token multimodal item.
    """
    # Disable multimodal input chunking.
    scheduler = create_scheduler(
        model="llava-hf/llava-1.5-7b-hf",
        max_num_batched_tokens=1024,
        disable_chunked_mm_input=True,
    )

    image_span = PlaceholderRange(offset=400, length=800)
    reqs = create_requests(
        num_requests=1,
        num_tokens=1200,
        mm_positions=[[image_span]],
    )
    for req in reqs:
        scheduler.add_request(req)

    # Step 1: the request is new and nothing has finished yet.
    first_step = scheduler.schedule()
    assert len(first_step.scheduled_new_reqs) == 1
    assert len(first_step.scheduled_cached_reqs) == 0
    assert len(first_step.finished_req_ids) == 0
    # We want to only see the 400 text tokens at the start scheduled
    assert first_step.num_scheduled_tokens[reqs[0].request_id] == 400

    # Report the step back to the scheduler with no sampled tokens, since the
    # prompt has only been partially prefilled.
    req_ids = [req.request_id for req in reqs]
    runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={req_id: idx for idx, req_id in enumerate(req_ids)},
        sampled_token_ids=[[] for _ in reqs],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(first_step, runner_output)

    # Step 2: the same request continues as a cached request and receives the
    # entire 800-token multimodal item in one go.
    second_step = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(second_step.scheduled_new_reqs) == 0
    assert len(second_step.scheduled_cached_reqs) == 1
    assert len(second_step.finished_req_ids) == 0
    assert second_step.num_scheduled_tokens[reqs[0].request_id] == 800
@
pytest
.
mark
.
parametrize
(
"enable_prefix_caching"
,
[
True
,
False
])
def
test_schedule_concurrent_partial_requests
(
enable_prefix_caching
:
bool
):
"""Test scheduling behavior with concurrent partial requests.
...
...
vllm/config.py
View file @
8e5314a4
...
...
@@ -1721,6 +1721,14 @@ class SchedulerConfig:
chunked_prefill_enabled
:
bool
=
field
(
init
=
False
)
# If set to true and chunked prefill is enabled, we do not want to
# partially schedule a multimodal item. Only used in V1.
# This ensures that if a request has a mixed prompt
# (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
# some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
# it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
disable_chunked_mm_input
:
bool
=
False
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls
:
Union
[
str
,
type
[
object
]]
=
"vllm.core.scheduler.Scheduler"
...
...
vllm/engine/arg_utils.py
View file @
8e5314a4
...
...
@@ -179,6 +179,7 @@ class EngineArgs:
scheduler_delay_factor
:
float
=
0.0
enable_chunked_prefill
:
Optional
[
bool
]
=
None
disable_chunked_mm_input
:
bool
=
False
guided_decoding_backend
:
str
=
'xgrammar'
logits_processor_pattern
:
Optional
[
str
]
=
None
...
...
@@ -1017,6 +1018,20 @@ class EngineArgs:
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial."
)
# With nargs="?", argparse substitutes `const` when the flag is given without
# a value. A bare `--disable-chunked-mm-input` must therefore resolve to
# "True"; the previous const of "False" made the bare flag a no-op.
# NOTE(review): assumes StoreBoolean parses the const string the same way it
# parses an explicit "True"/"False" value - confirm against its definition.
parser.add_argument(
    "--disable-chunked-mm-input",
    action=StoreBoolean,
    default=EngineArgs.disable_chunked_mm_input,
    nargs="?",
    const="True",
    help="Disable multimodal input chunking attention for V1. "
    "If set to true and chunked prefill is enabled, we do not want to"
    " partially schedule a multimodal item. This ensures that if a "
    "request has a mixed prompt (like text tokens TTTT followed by "
    "image tokens IIIIIIIIII) where only some image tokens can be "
    "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
    "as TTTT in one step and IIIIIIIIII in the next.")
return
parser
@
classmethod
...
...
@@ -1261,6 +1276,7 @@ class EngineArgs:
num_lookahead_slots
=
num_lookahead_slots
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
disable_chunked_mm_input
=
self
.
disable_chunked_mm_input
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
8e5314a4
...
...
@@ -522,6 +522,17 @@ class Scheduler(SchedulerInterface):
if
self
.
encoder_cache_manager
.
has_cache
(
request
,
i
):
# The encoder input is already computed and cached.
continue
# If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would
# only cover part of the mm input, roll back to before the mm item.
if
(
self
.
scheduler_config
.
disable_chunked_mm_input
and
num_computed_tokens
<
start_pos
and
(
num_computed_tokens
+
num_new_tokens
)
<
(
start_pos
+
num_encoder_tokens
)):
num_new_tokens
=
start_pos
-
num_computed_tokens
break
if
(
not
self
.
encoder_cache_manager
.
can_allocate
(
request
,
i
)
or
num_encoder_tokens
>
encoder_budget
):
# The encoder cache is full or the encoder budget is exhausted.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment