Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af51d80f
Unverified
Commit
af51d80f
authored
Apr 04, 2025
by
Roger Wang
Committed by
GitHub
Apr 04, 2025
Browse files
Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
parent
f5722a50
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
110 deletions
+20
-110
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+20
-65
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+0
-45
No files found.
vllm/v1/worker/tpu_model_runner.py
View file @
af51d80f
...
...
@@ -19,8 +19,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -37,8 +36,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
from
.utils
import
sanity_check_mm_encoder_outputs
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -509,47 +507,19 @@ class TPUModelRunner:
logits_indices
=
logits_indices
.
to
(
self
.
device
)
return
attn_metadata
,
logits_indices
def _scatter_placeholders(
    self,
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Spread embedding rows over a placeholder-sized tensor.

    Args:
        embeds: The embeddings to place. The last dimension is reused as
            the last dimension of the result.
        is_embed: Boolean mask selecting which placeholder rows receive a
            row of ``embeds``, or ``None`` when no scattering is needed.
            Expected to be 1-D with as many ``True`` entries as ``embeds``
            has rows (otherwise the masked assignment below fails).

    Returns:
        ``embeds`` itself when ``is_embed`` is ``None``; otherwise a new
        tensor of shape ``(is_embed.shape[0], embeds.shape[-1])`` whose
        masked rows hold ``embeds`` and whose remaining rows are NaN.
    """
    if is_embed is None:
        return embeds

    # Rows not selected by the mask stay NaN -- presumably so that an
    # accidental read of a non-embedding position is easy to spot.
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders
def _gather_placeholders(
    self,
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Select the rows of ``placeholders`` that are marked by ``is_embed``.

    Args:
        placeholders: Tensor indexed along its leading dimension(s) by the
            mask.
        is_embed: Boolean mask of the positions to keep, or ``None`` to
            keep everything.

    Returns:
        ``placeholders`` itself when ``is_embed`` is ``None``; otherwise
        the result of boolean-mask indexing ``placeholders[is_embed]``.
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]
def
_execute_mm_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
# Batch the multi-modal inputs.
mm_inputs
=
list
[
MultiModalKwargs
]
()
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
,
pos_info
in
zip
(
encoder_input_ids
,
req_state
.
mm_positions
,
):
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_i
ds_po
s
.
append
((
req_id
,
input_id
,
pos_info
))
req_i
nput_id
s
.
append
((
req_id
,
input_id
))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
...
...
@@ -585,23 +555,16 @@ class TPUModelRunner:
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
req_ids_pos
,
encoder_outputs
,
):
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
def
_gather_encoder_outputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
...
@@ -609,8 +572,8 @@ class TPUModelRunner:
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
...
...
@@ -632,16 +595,8 @@ class TPUModelRunner:
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
@
torch
.
no_grad
()
def
execute_model
(
...
...
@@ -657,10 +612,10 @@ class TPUModelRunner:
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_
mm_
encoder
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
self
.
_execute_encoder
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
else
:
mm_embed
s
=
[]
encoder_output
s
=
[]
# Prepare inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
...
...
@@ -668,9 +623,9 @@ class TPUModelRunner:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if
mm_embed
s
:
if
encoder_output
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
,
mm_embed
s
)
self
.
input_ids
,
encoder_output
s
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
)
input_ids
=
None
...
...
vllm/v1/worker/utils.py
View file @
af51d80f
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
...
...
@@ -29,46 +27,3 @@ def sanity_check_mm_encoder_outputs(
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method."
)
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens.

    See also
    :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
            tokens need to be filled with multimodal embeddings, or `None`
            if every placeholder position is an embedding. The number of
            `True` entries must equal `num_embeds`.
            Shape: `(num_placeholders,)`

    Returns:
        `embeds` unchanged when `is_embed` is `None`; otherwise a new tensor
        of shape `(num_placeholders, embed_dim)` in which the masked rows
        hold `embeds` (in order) and all other rows are NaN.
    """
    if is_embed is None:
        return embeds

    # Unmasked rows are filled with NaN rather than zeros -- presumably so
    # that reading a position that is not an embedding is easy to detect.
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the inverse operation of :func:`scatter_mm_placeholders`.

    Args:
        placeholders: The tensor representing the placeholder tokens.
        is_embed: A boolean mask indicating which positions of
            `placeholders` hold multimodal embeddings, or `None` if all of
            them do.

    Returns:
        `placeholders` unchanged when `is_embed` is `None`; otherwise only
        the positions selected by the mask, i.e. `placeholders[is_embed]`.
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment