Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0377802c
Unverified
Commit
0377802c
authored
Sep 12, 2025
by
Flora Feng
Committed by
GitHub
Sep 12, 2025
Browse files
[Multimodal] Remove legacy multimodal fields in favor of MultiModalFeatureSpec (#24548)
Signed-off-by:
sfeng33
<
4florafeng@gmail.com
>
parent
72fc8aa4
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
102 additions
and
116 deletions
+102
-116
tests/v1/core/test_encoder_cache_manager.py
tests/v1/core/test_encoder_cache_manager.py
+11
-1
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+1
-3
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+1
-3
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+1
-3
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+22
-19
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+5
-5
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+16
-20
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+5
-13
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+11
-11
vllm/v1/request.py
vllm/v1/request.py
+2
-7
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+4
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+15
-13
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+8
-12
No files found.
tests/v1/core/test_encoder_cache_manager.py
View file @
0377802c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
PlaceholderRange
from
vllm.v1.core.encoder_cache_manager
import
EncoderCacheManager
...
...
@@ -9,8 +10,17 @@ class MockRequest:
def __init__(self, request_id, mm_hashes, token_counts):
    """Build a mock request exposing one image feature per multimodal hash.

    Args:
        request_id: Identifier stored on the mock as ``request_id``.
        mm_hashes: Multimodal hashes; each one becomes the ``identifier``
            of a generated feature.
        token_counts: Placeholder lengths, indexed in step with
            ``mm_hashes``.
    """
    self.request_id = request_id
    self.mm_hashes = mm_hashes
    self._token_counts = token_counts
    # Every placeholder starts at offset 0; only the identifier and the
    # placeholder length differ between the generated features.
    self.mm_features = [
        MultiModalFeatureSpec(
            data=None,
            modality="image",
            identifier=identifier,
            mm_position=PlaceholderRange(
                offset=0, length=self._token_counts[index]),
        )
        for index, identifier in enumerate(mm_hashes)
    ]
def get_num_encoder_tokens(self, input_id: int) -> int:
    """Return the token count recorded for the encoder input ``input_id``."""
    token_counts = self._token_counts
    return token_counts[input_id]
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
0377802c
...
...
@@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
mm_kwargs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_features
=
[],
sampling_params
=
SamplingParams
(),
pooling_params
=
PoolingParams
(),
block_ids
=
([
0
],
),
# block_ids should be tuple[list[int]]
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
0377802c
...
...
@@ -203,9 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
_create_sampling_params
(),
pooling_params
=
None
,
mm_kwargs
=
[],
mm_positions
=
[],
mm_hashes
=
[],
mm_features
=
[],
block_ids
=
([],
),
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
0377802c
...
...
@@ -118,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
mm_kwargs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_features
=
[],
sampling_params
=
SamplingParams
(),
pooling_params
=
None
,
block_ids
=
([
0
],
),
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
0377802c
...
...
@@ -300,11 +300,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
total_need_load
=
0
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
if
new_req
.
req_id
in
self
.
_requests_need_load
:
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
block_ids
=
new_req
.
block_ids
[
0
],
block_size
=
self
.
_block_size
,
is_store
=
False
,
mm_hashes
=
new_req
.
mm_
hash
es
)
mm_hashes
=
[
f
.
identifier
for
f
in
new_req
.
mm_
featur
es
]
)
total_need_load
+=
1
else
:
# NOTE: here, we set the store and load being exclusive,
...
...
@@ -312,11 +313,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if
not
self
.
_found_match_for_request
(
new_req
):
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
meta
.
add_request
(
token_ids
=
new_req
.
prompt_token_ids
,
block_ids
=
new_req
.
block_ids
[
0
],
block_size
=
self
.
_block_size
,
is_store
=
True
,
mm_hashes
=
new_req
.
mm_
hash
es
)
mm_hashes
=
[
f
.
identifier
for
f
in
new_req
.
mm_
featur
es
]
)
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
...
...
@@ -341,11 +343,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
# of the block_ids for the request.
block_ids
=
new_block_ids
[
0
]
meta
.
add_request
(
token_ids
=
token_ids
,
meta
.
add_request
(
token_ids
=
token_ids
,
block_ids
=
block_ids
,
block_size
=
self
.
_block_size
,
is_store
=
False
,
mm_hashes
=
request
.
mm_
hash
es
)
mm_hashes
=
[
f
.
identifier
for
f
in
request
.
mm_
featur
es
]
)
total_need_load
+=
1
assert
total_need_load
==
len
(
self
.
_requests_need_load
)
...
...
@@ -364,9 +367,9 @@ class SharedStorageConnector(KVConnectorBase_V1):
"""
num_tokens_to_check
=
align_to_block_size
(
len
(
request
.
prompt_token_ids
)
-
1
,
self
.
_block_size
)
foldername
=
self
.
_generate_foldername_debug
(
torch
.
tensor
(
request
.
prompt_token_ids
)[:
num_tokens_to_check
],
request
.
mm_
hash
es
,
foldername
=
self
.
_generate_foldername_debug
(
torch
.
tensor
(
request
.
prompt_token_ids
)[:
num_tokens_to_check
],
[
f
.
identifier
for
f
in
request
.
mm_
featur
es
]
,
create_folder
=
False
)
return
os
.
path
.
exists
(
foldername
)
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
0377802c
...
...
@@ -86,7 +86,7 @@ class EncoderCacheManager:
Returns:
True if the encoder output for this input is already cached
"""
mm_hash
=
request
.
mm_
hash
es
[
input_id
]
mm_hash
=
request
.
mm_
featur
es
[
input_id
]
.
identifier
# Not cached at all
if
mm_hash
not
in
self
.
cached
:
return
False
...
...
@@ -167,7 +167,7 @@ class EncoderCacheManager:
This method assumes can_allocate() returned True for the same input.
"""
mm_hash
=
request
.
mm_
hash
es
[
input_id
]
mm_hash
=
request
.
mm_
featur
es
[
input_id
]
.
identifier
request_id
=
request
.
request_id
if
mm_hash
not
in
self
.
cached
:
self
.
cached
[
mm_hash
]
=
set
()
...
...
@@ -193,8 +193,8 @@ class EncoderCacheManager:
"""
return
{
input_id
for
input_id
in
range
(
len
(
request
.
mm_
hash
es
))
if
request
.
mm_
hash
es
[
input_id
]
in
self
.
cached
for
input_id
in
range
(
len
(
request
.
mm_
featur
es
))
if
request
.
mm_
featur
es
[
input_id
]
.
identifier
in
self
.
cached
}
def
free_encoder_input
(
self
,
request
:
Request
,
input_id
:
int
)
->
None
:
...
...
@@ -208,7 +208,7 @@ class EncoderCacheManager:
`can_allocate`).
"""
req_id
=
request
.
request_id
mm_hash
=
request
.
mm_
hash
es
[
input_id
]
mm_hash
=
request
.
mm_
featur
es
[
input_id
]
.
identifier
# The mm_hash not in cache or the req_id set is empty
if
not
self
.
cached
.
get
(
mm_hash
,
None
):
return
...
...
vllm/v1/core/kv_cache_utils.py
View file @
0377802c
...
...
@@ -418,7 +418,7 @@ def need_extra_keys(request: Request) -> bool:
# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
# Request with provided cache salt need to include the salt.
return
bool
(
request
.
mm_
hash
es
)
or
(
request
.
lora_request
return
bool
(
request
.
mm_
featur
es
)
or
(
request
.
lora_request
is
not
None
)
or
(
request
.
cache_salt
is
not
None
)
...
...
@@ -442,32 +442,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
"""
extra_keys
:
list
[
Any
]
=
[]
mm_
positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hash
es
if
not
mm_
position
s
:
mm_
features
=
request
.
mm_featur
es
if
not
mm_
feature
s
:
return
extra_keys
,
start_mm_idx
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match. This "
"is likely because you did not enable MM hashing. "
"Please set `mm_processor_cache_gb > 0`."
)
# Note that we assume mm_positions is sorted by offset.
# Note that we assume mm_features are sorted by mm_position.offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
].
offset
+
mm_positions
[
-
1
].
length
<
start_token_idx
:
last_pos
=
mm_features
[
-
1
].
mm_position
if
last_pos
.
offset
+
last_pos
.
length
<
start_token_idx
:
return
extra_keys
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if
start_mm_idx
<
0
:
assert
-
start_mm_idx
<=
len
(
mm_
position
s
)
start_mm_idx
=
len
(
mm_
position
s
)
+
start_mm_idx
assert
-
start_mm_idx
<=
len
(
mm_
feature
s
)
start_mm_idx
=
len
(
mm_
feature
s
)
+
start_mm_idx
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
].
offset
length
=
mm_positions
[
curr_mm_idx
].
length
while
mm_features
and
curr_mm_idx
<
len
(
mm_features
):
mm_feature
=
mm_features
[
curr_mm_idx
]
assert
mm_feature
.
identifier
is
not
None
offset
=
mm_feature
.
mm_position
.
offset
length
=
mm_feature
.
mm_position
.
length
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
...
...
@@ -475,7 +471,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
continue
# The block contains the current mm input.
extra_keys
.
append
(
mm_
hashes
[
curr_mm_idx
]
)
extra_keys
.
append
(
mm_
feature
.
identifier
)
if
end_token_idx
>=
offset
+
length
:
# If this block contains the end of the current mm input,
...
...
vllm/v1/core/sched/output.py
View file @
0377802c
...
...
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorMetadata
)
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModal
KwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModal
FeatureSpec
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.request
import
Request
...
...
@@ -27,9 +27,7 @@ class NewRequestData:
req_id
:
str
prompt_token_ids
:
list
[
int
]
mm_kwargs
:
list
[
MultiModalKwargsItem
]
mm_hashes
:
list
[
str
]
mm_positions
:
list
[
PlaceholderRange
]
mm_features
:
list
[
MultiModalFeatureSpec
]
sampling_params
:
Optional
[
SamplingParams
]
pooling_params
:
Optional
[
PoolingParams
]
block_ids
:
tuple
[
list
[
int
],
...]
...
...
@@ -45,9 +43,7 @@ class NewRequestData:
return
cls
(
req_id
=
request
.
request_id
,
prompt_token_ids
=
request
.
prompt_token_ids
,
mm_kwargs
=
request
.
mm_kwargs
,
mm_hashes
=
request
.
mm_hashes
,
mm_positions
=
request
.
mm_positions
,
mm_features
=
request
.
mm_features
,
sampling_params
=
request
.
sampling_params
,
pooling_params
=
request
.
pooling_params
,
block_ids
=
block_ids
,
...
...
@@ -59,9 +55,7 @@ class NewRequestData:
return
(
f
"NewRequestData("
f
"req_id=
{
self
.
req_id
}
,"
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
,"
f
"mm_kwargs=
{
self
.
mm_kwargs
}
,"
f
"mm_hashes=
{
self
.
mm_hashes
}
,"
f
"mm_positions=
{
self
.
mm_positions
}
,"
f
"mm_features=
{
self
.
mm_features
}
,"
f
"sampling_params=
{
self
.
sampling_params
}
,"
f
"block_ids=
{
self
.
block_ids
}
,"
f
"num_computed_tokens=
{
self
.
num_computed_tokens
}
,"
...
...
@@ -73,9 +67,7 @@ class NewRequestData:
return
(
f
"NewRequestData("
f
"req_id=
{
self
.
req_id
}
,"
f
"prompt_token_ids_len=
{
len
(
self
.
prompt_token_ids
)
}
,"
f
"mm_kwargs=
{
self
.
mm_kwargs
}
,"
f
"mm_hashes=
{
self
.
mm_hashes
}
,"
f
"mm_positions=
{
self
.
mm_positions
}
,"
f
"mm_features=
{
self
.
mm_features
}
,"
f
"sampling_params=
{
self
.
sampling_params
}
,"
f
"block_ids=
{
self
.
block_ids
}
,"
f
"num_computed_tokens=
{
self
.
num_computed_tokens
}
,"
...
...
vllm/v1/core/sched/scheduler.py
View file @
0377802c
...
...
@@ -736,18 +736,18 @@ class Scheduler(SchedulerInterface):
if
num_new_tokens
==
0
or
not
request
.
has_encoder_inputs
:
return
[],
num_new_tokens
,
encoder_compute_budget
encoder_inputs_to_schedule
:
list
[
int
]
=
[]
mm_
position
s
=
request
.
mm_
position
s
assert
mm_
position
s
is
not
None
assert
len
(
mm_
position
s
)
>
0
mm_
feature
s
=
request
.
mm_
feature
s
assert
mm_
feature
s
is
not
None
assert
len
(
mm_
feature
s
)
>
0
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule
=
set
()
num_tokens_to_schedule
=
0
for
i
,
pos_info
in
enumerate
(
mm_
position
s
):
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
for
i
,
mm_feature
in
enumerate
(
mm_
feature
s
):
start_pos
=
mm_feature
.
mm_position
.
offset
num_encoder_tokens
=
mm_feature
.
mm_position
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...
...
@@ -778,7 +778,7 @@ class Scheduler(SchedulerInterface):
if
not
self
.
is_encoder_decoder
:
# We are not using the encoder cache for encoder-decoder models,
# yet.
if
request
.
mm_
hashes
[
i
]
in
mm_hashes_to_schedule
:
if
request
.
mm_
features
[
i
].
identifier
in
mm_hashes_to_schedule
:
# The same encoder input has already been scheduled in the
# current step.
continue
...
...
@@ -820,7 +820,7 @@ class Scheduler(SchedulerInterface):
num_tokens_to_schedule
+=
num_encoder_tokens
encoder_compute_budget
-=
num_encoder_tokens
mm_hashes_to_schedule
.
add
(
request
.
mm_
hashes
[
i
]
)
mm_hashes_to_schedule
.
add
(
request
.
mm_
features
[
i
].
identifier
)
encoder_inputs_to_schedule
.
append
(
i
)
return
(
...
...
@@ -1048,9 +1048,9 @@ class Scheduler(SchedulerInterface):
# Here, we use list(set) to avoid modifying the set while iterating
# over it.
for
input_id
in
list
(
cached_encoder_input_ids
):
mm_
positions
=
request
.
mm_
position
s
[
input_id
]
start_pos
=
mm_position
s
.
offset
num_tokens
=
mm_position
s
.
length
mm_
feature
=
request
.
mm_
feature
s
[
input_id
]
start_pos
=
mm_feature
.
mm_position
.
offset
num_tokens
=
mm_feature
.
mm_position
.
length
if
self
.
is_encoder_decoder
and
request
.
num_computed_tokens
>
0
:
# With Whisper, as soon as we've generated a single token,
# we know we're done with the encoder input. Cross Attention
...
...
vllm/v1/request.py
View file @
0377802c
...
...
@@ -91,11 +91,6 @@ class Request:
self
.
mm_features
=
mm_features
or
[]
self
.
num_encoder_inputs
=
len
(
self
.
mm_features
)
self
.
has_encoder_inputs
=
self
.
num_encoder_inputs
>
0
# TODO(sfeng33): Remove these legacy fields after clearing out all
# references in scheduler and model runner
self
.
mm_positions
=
[
f
.
mm_position
for
f
in
self
.
mm_features
]
self
.
mm_kwargs
=
[
f
.
data
for
f
in
self
.
mm_features
]
self
.
mm_hashes
=
[
f
.
identifier
for
f
in
self
.
mm_features
]
# Read-only views
# Prevent directly appending to these lists since
...
...
@@ -180,8 +175,8 @@ class Request:
return
RequestStatus
.
get_finished_reason
(
self
.
status
)
def get_num_encoder_tokens(self, input_id: int) -> int:
    """Return the placeholder length of the encoder input ``input_id``.

    The length is read from ``self.mm_features[input_id].mm_position``.
    The legacy ``mm_positions`` list is no longer consulted, because the
    same change removes that attribute from the request.

    Args:
        input_id: Index into ``self.mm_features``.

    Returns:
        The ``length`` of the selected feature's ``mm_position``.

    Raises:
        AssertionError: If ``input_id`` is not below the number of
            multimodal features.
    """
    assert input_id < len(self.mm_features)
    return self.mm_features[input_id].mm_position.length
def
record_event
(
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
0377802c
...
...
@@ -10,8 +10,7 @@ import torch
from
typing_extensions
import
deprecated
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
(
MultiModalKwargsItem
,
MultiModalKwargsItems
,
PlaceholderRange
)
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItems
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
swap_dict_values
...
...
@@ -31,9 +30,7 @@ class CachedRequestState:
req_id
:
str
prompt_token_ids
:
list
[
int
]
mm_kwargs
:
list
[
MultiModalKwargsItem
]
mm_positions
:
list
[
PlaceholderRange
]
mm_hashes
:
list
[
str
]
mm_features
:
list
[
MultiModalFeatureSpec
]
sampling_params
:
Optional
[
SamplingParams
]
pooling_params
:
Optional
[
PoolingParams
]
generator
:
Optional
[
torch
.
Generator
]
...
...
@@ -60,7 +57,8 @@ class CachedRequestState:
"removed in v0.13. Please use `mm_kwargs` instead."
)
def mm_inputs(self) -> list[MultiModalKwargsItems]:
    """Return one single-item ``MultiModalKwargsItems`` per feature with data.

    Iterates ``self.mm_features`` and wraps each feature's ``data`` via
    ``MultiModalKwargsItems.from_seq``. Features whose ``data`` is ``None``
    are skipped. The legacy ``mm_kwargs`` list is no longer read.
    """
    return [
        MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
        if f.data is not None
    ]
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
0377802c
...
...
@@ -555,9 +555,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state
=
CachedRequestState
(
req_id
=
req_id
,
prompt_token_ids
=
new_req_data
.
prompt_token_ids
,
mm_kwargs
=
new_req_data
.
mm_kwargs
,
mm_positions
=
new_req_data
.
mm_positions
,
mm_hashes
=
new_req_data
.
mm_hashes
,
mm_features
=
new_req_data
.
mm_features
,
sampling_params
=
sampling_params
,
pooling_params
=
pooling_params
,
generator
=
generator
,
...
...
@@ -698,7 +696,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts
=
[]
audio_feature_lengths
=
[]
use_audio_in_video
=
False
for
mm_item
in
req_state
.
mm_kwargs
:
for
mm_feature
in
req_state
.
mm_features
:
mm_item
=
mm_feature
.
data
if
mm_item
is
None
:
continue
mm_input
=
mm_item
.
get_data
()
if
(
t
:
=
mm_input
.
get
(
"image_grid_thw"
))
is
not
None
:
image_grid_thw
.
append
(
t
.
tolist
())
...
...
@@ -731,7 +732,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_kwargs
=
list
[
MultiModalKwargsItem
]()
for
req
in
scheduler_output
.
scheduled_new_reqs
:
mm_kwargs
.
extend
(
req
.
mm_kwargs
)
for
feature
in
req
.
mm_features
:
if
feature
.
data
is
not
None
:
mm_kwargs
.
append
(
feature
.
data
)
# Input all modalities at once
mm_kwargs_combined
:
BatchedTensorInputs
=
{}
...
...
@@ -1361,10 +1364,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
for
mm_input_id
in
encoder_input_ids
:
mm_
hash
=
req_state
.
mm_
hash
es
[
mm_input_id
]
mm_
kwargs
.
append
(
req_state
.
mm_kwargs
[
mm_input_id
])
mm_
hashes_pos
.
append
(
(
mm_hash
,
req_stat
e
.
mm_position
s
[
mm_input_id
]
))
mm_
feature
=
req_state
.
mm_
featur
es
[
mm_input_id
]
mm_
hash
=
mm_feature
.
identifier
mm_
kwargs
.
append
(
mm_feature
.
data
)
mm_hashes_pos
.
append
(
(
mm_hash
,
mm_featur
e
.
mm_position
))
return
mm_kwargs
,
mm_hashes_pos
...
...
@@ -1426,9 +1429,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
\
req_state
.
num_computed_tokens
+
shift_computed_tokens
mm_positions
=
req_state
.
mm_positions
mm_hashes
=
req_state
.
mm_hashes
for
i
,
pos_info
in
enumerate
(
mm_positions
):
for
mm_feature
in
req_state
.
mm_features
:
pos_info
=
mm_feature
.
mm_position
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
...
...
@@ -1451,7 +1453,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
assert
start_idx
<
end_idx
mm_hash
=
mm_
hashes
[
i
]
mm_hash
=
mm_
feature
.
identifier
encoder_output
=
self
.
encoder_cache
.
get
(
mm_hash
,
None
)
assert
encoder_output
is
not
None
,
\
f
"Encoder cache miss for
{
mm_hash
}
."
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
0377802c
...
...
@@ -387,9 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
prompt_token_ids
=
new_req_data
.
prompt_token_ids
,
mm_kwargs
=
new_req_data
.
mm_kwargs
,
mm_positions
=
new_req_data
.
mm_positions
,
mm_hashes
=
new_req_data
.
mm_hashes
,
mm_features
=
new_req_data
.
mm_features
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
generator
=
None
,
...
...
@@ -822,10 +820,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state
=
self
.
requests
[
req_id
]
for
mm_input_id
in
encoder_input_ids
:
mm_
hash
=
req_state
.
mm_
hash
es
[
mm_input_id
]
mm_
kwargs
.
append
(
req_state
.
mm_kwargs
[
mm_input_id
])
mm_
hashes_pos
.
append
(
(
mm_hash
,
req_stat
e
.
mm_position
s
[
mm_input_id
]
))
mm_
feature
=
req_state
.
mm_
featur
es
[
mm_input_id
]
mm_
hash
=
mm_feature
.
identifier
mm_
kwargs
.
append
(
mm_feature
.
data
)
mm_hashes_pos
.
append
(
(
mm_hash
,
mm_featur
e
.
mm_position
))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
...
...
@@ -883,13 +881,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
mm_hashes
=
req_state
.
mm_hashes
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
# the intrinsic dynamism that `gather_mm_placeholders` introduces.
for
i
,
pos_info
in
enumerate
(
mm_positions
):
for
mm_feature
in
req_state
.
mm_features
:
pos_info
=
mm_feature
.
mm_position
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
...
...
@@ -904,8 +901,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
mm_hash
=
mm_hashes
[
i
]
mm_hash
=
mm_feature
.
identifier
encoder_output
=
self
.
encoder_cache
.
get
(
mm_hash
,
None
)
assert
encoder_output
is
not
None
,
\
f
"Encoder cache miss for
{
mm_hash
}
."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment