Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca6a95ba
Unverified
Commit
ca6a95ba
authored
Dec 24, 2025
by
Cyrus Leung
Committed by
GitHub
Dec 23, 2025
Browse files
[Chore] Simplify logic of `_execute_mm_encoder` (#31222)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
bc0a5a0c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
25 deletions
+21
-25
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+21
-25
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
ca6a95ba
...
...
@@ -61,6 +61,7 @@ from vllm.model_executor.layers.rotary_embedding import (
)
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.interfaces
import
(
MultiModalEmbeddings
,
SupportsMRoPE
,
SupportsMultiModal
,
SupportsXDRoPE
,
...
...
@@ -78,11 +79,7 @@ from vllm.model_executor.models.interfaces_base import (
is_text_generation_model
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
BatchedTensorInputs
,
MultiModalKwargsItem
,
PlaceholderRange
,
)
from
vllm.multimodal.inputs
import
BatchedTensorInputs
,
MultiModalKwargsItem
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingType
...
...
@@ -2097,28 +2094,27 @@ class GPUModelRunner(
]
return
logits_indices_padded
def
_batch_mm_
kwarg
s_from_scheduler
(
def
_batch_mm_
input
s_from_scheduler
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
list
[
MultiModalKwargsItem
]
,
list
[
tuple
[
str
,
PlaceholderRange
]]
]:
"""Batch multimodal
kwarg
s from scheduled encoder inputs.
)
->
tuple
[
list
[
str
],
list
[
MultiModalKwargsItem
]]:
"""Batch multimodal
input
s from scheduled encoder inputs.
Args:
scheduler_output: The scheduler output containing scheduled encoder
inputs.
Returns:
A tuple of (mm_
kwargs, req_ids_po
s) where:
- mm_
kwarg
s: List of multimodal
kwargs items to be batched
- mm_
hashes_po
s: List of
(mm_hash, position_info) tuples
A tuple of (mm_
hashes, mm_kwarg
s) where:
- mm_
hashe
s: List of multimodal
hashes for each item
- mm_
kwarg
s: List of
multimodal kwargs for each item
"""
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
[],
[]
# Batch the multi-modal inputs.
mm_hashes
=
list
[
str
]()
mm_kwargs
=
list
[
MultiModalKwargsItem
]()
# list of tuple (mm_hash, position_info)
mm_hashes_pos
=
list
[
tuple
[
str
,
PlaceholderRange
]]()
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
...
...
@@ -2126,19 +2122,16 @@ class GPUModelRunner(
mm_feature
=
req_state
.
mm_features
[
mm_input_id
]
if
mm_feature
.
data
is
None
:
continue
mm_hash
=
mm_feature
.
identifier
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_kwargs
.
append
(
mm_feature
.
data
)
mm_hashes_pos
.
append
((
mm_hash
,
mm_feature
.
mm_position
))
return
mm_
kwarg
s
,
mm_
hashes_po
s
return
mm_
hashe
s
,
mm_
kwarg
s
def
_execute_mm_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
list
[
torch
.
Tensor
]:
# Batch the multi-modal inputs using the helper method.
mm_kwargs
,
mm_hashes_pos
=
self
.
_batch_mm_kwargs_from_scheduler
(
scheduler_output
)
mm_hashes
,
mm_kwargs
=
self
.
_batch_mm_inputs_from_scheduler
(
scheduler_output
)
if
not
mm_kwargs
:
return
[]
...
...
@@ -2157,7 +2150,7 @@ class GPUModelRunner(
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
):
curr_group_outputs
:
list
[
torch
.
Tensor
]
=
[]
curr_group_outputs
:
MultiModalEmbeddings
# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
...
...
@@ -2173,6 +2166,7 @@ class GPUModelRunner(
and
modality
==
"video"
and
num_items
>
1
):
curr_group_outputs_lst
=
list
[
torch
.
Tensor
]()
for
video_mm_kwargs_item
in
filter
(
lambda
item
:
item
.
modality
==
"video"
,
mm_kwargs
):
...
...
@@ -2188,7 +2182,9 @@ class GPUModelRunner(
**
micro_batch_mm_inputs
)
curr_group_outputs
.
extend
(
micro_batch_outputs
)
curr_group_outputs_lst
.
extend
(
micro_batch_outputs
)
curr_group_outputs
=
curr_group_outputs_lst
else
:
# Run the encoder.
# `curr_group_outputs` is either of the following:
...
...
@@ -2197,7 +2193,7 @@ class GPUModelRunner(
# 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items.
curr_group_outputs
=
model
.
embed_multimodal
(
**
mm_kwargs_group
)
# type: ignore[assignment]
curr_group_outputs
=
model
.
embed_multimodal
(
**
mm_kwargs_group
)
sanity_check_mm_encoder_outputs
(
curr_group_outputs
,
...
...
@@ -2206,7 +2202,7 @@ class GPUModelRunner(
encoder_outputs
.
extend
(
curr_group_outputs
)
# Cache the encoder outputs by mm_hash
for
(
mm_hash
,
pos_info
),
output
in
zip
(
mm_hashes
_pos
,
encoder_outputs
):
for
mm_hash
,
output
in
zip
(
mm_hashes
,
encoder_outputs
):
self
.
encoder_cache
[
mm_hash
]
=
output
logger
.
debug
(
"Finish execute for mm hash %s"
,
mm_hash
)
self
.
maybe_save_ec_to_connector
(
self
.
encoder_cache
,
mm_hash
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment