Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fbfc3ee3
Unverified
Commit
fbfc3ee3
authored
Mar 04, 2025
by
Michael Goin
Committed by
GitHub
Mar 04, 2025
Browse files
[V1][TPU] TPU multimodal model support for ragged attention (#14158)
Signed-off-by:
Michael Goin
<
mgoin64@gmail.com
>
parent
3e1d2236
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
194 additions
and
30 deletions
+194
-30
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+193
-29
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+1
-1
No files found.
vllm/v1/worker/tpu_model_runner.py
View file @
fbfc3ee3
...
...
@@ -15,14 +15,18 @@ from vllm.attention.backends.abstract import AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
NUM_KV_PAGES_PER_BLOCK
,
NUM_QUERIES_PER_BLOCK
,
PallasAttentionBackend
,
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
...
...
@@ -72,8 +76,10 @@ class TPUModelRunner:
self
.
block_size
=
cache_config
.
block_size
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_blocks_per_req
=
cdiv
(
self
.
max_model_len
,
self
.
block_size
)
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
_get_padded_number
(
scheduler_config
.
max_num_batched_tokens
,
NUM_QUERIES_PER_BLOCK
)
self
.
max_num_reqs
=
_get_padded_number
(
scheduler_config
.
max_num_seqs
,
NUM_QUERIES_PER_BLOCK
)
# Model-related.
self
.
num_attn_layers
=
model_config
.
get_num_layers_by_block_type
(
...
...
@@ -84,6 +90,28 @@ class TPUModelRunner:
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
# TODO: Support M-RoPE (e.g, Qwen2-VL)
assert
not
self
.
uses_mrope
,
"TPU does not support M-RoPE yet."
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
model_config
=
model_config
,
scheduler_config
=
scheduler_config
,
)
self
.
max_num_encoder_input_tokens
=
encoder_compute_budget
self
.
encoder_cache_size
=
encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
...
...
@@ -91,18 +119,9 @@ class TPUModelRunner:
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
vocab_size
=
model_config
.
get_vocab_size
(),
)
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
# KV caches for forward pass
self
.
kv_caches
:
list
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# Sometimes the numpy op is faster so we create both.
...
...
@@ -164,6 +183,7 @@ class TPUModelRunner:
# Remove finished requests from the cached states.
for
req_id
in
scheduler_output
.
finished_req_ids
:
self
.
requests
.
pop
(
req_id
,
None
)
self
.
encoder_cache
.
pop
(
req_id
,
None
)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
...
...
@@ -177,6 +197,14 @@ class TPUModelRunner:
if
req_index
is
not
None
:
removed_req_indices
.
append
(
req_index
)
# Free the cached encoder outputs.
for
req_id
,
input_id
in
scheduler_output
.
free_encoder_input_ids
:
encoder_outputs
=
self
.
encoder_cache
.
get
(
req_id
)
if
encoder_outputs
is
not
None
:
encoder_outputs
.
pop
(
input_id
,
None
)
if
not
encoder_outputs
:
self
.
encoder_cache
.
pop
(
req_id
,
None
)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
...
...
@@ -426,6 +454,92 @@ class TPUModelRunner:
logits_indices
=
query_start_loc
[
1
:]
-
1
return
attn_metadata
,
logits_indices
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
return
# Batch the multi-modal inputs.
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_input_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_input_ids
.
append
((
req_id
,
input_id
))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list
=
group_mm_inputs_by_modality
(
mm_inputs
)
encoder_outputs
=
[]
for
grouped_mm_inputs
in
grouped_mm_inputs_list
:
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_mm_inputs
,
device
=
self
.
device
)
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs
=
self
.
model
.
get_multimodal_embeddings
(
**
batched_mm_inputs
)
for
output
in
curr_group_outputs
:
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
def
_gather_encoder_outputs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
encoder_outputs
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"offset"
]
num_encoder_tokens
=
pos_info
[
"length"
]
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if
start_pos
>=
num_computed_tokens
+
num_scheduled_tokens
:
# The encoder output is not needed in this step.
break
if
start_pos
+
num_encoder_tokens
<=
num_computed_tokens
:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx
=
max
(
num_computed_tokens
-
start_pos
,
0
)
end_idx
=
min
(
num_computed_tokens
-
start_pos
+
num_scheduled_tokens
,
num_encoder_tokens
)
assert
start_idx
<
end_idx
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
return
encoder_outputs
@
torch
.
no_grad
()
def
execute_model
(
self
,
...
...
@@ -434,16 +548,42 @@ class TPUModelRunner:
# Update cached state
self
.
_update_states
(
scheduler_output
)
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
self
.
_execute_encoder
(
scheduler_output
)
encoder_outputs
=
self
.
_gather_encoder_outputs
(
scheduler_output
)
else
:
encoder_outputs
=
[]
# Prepare inputs
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
is_multimodal_model
:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if
encoder_outputs
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
,
encoder_outputs
)
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
self
.
input_ids
)
input_ids
=
None
else
:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
inputs_embeds
=
None
# Run the decoder
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
token_ids
=
self
.
input_ids
,
position
_id
s
=
self
.
position_ids
,
input_ids
=
input_ids
,
positions
=
self
.
position_ids
,
kv_caches
=
self
.
kv_caches
,
inputs_embeds
=
inputs_embeds
,
)
hidden_states
=
hidden_states
[:
total_num_scheduled_tokens
]
num_reqs
=
self
.
input_batch
.
num_reqs
...
...
@@ -538,14 +678,21 @@ class TPUModelRunner:
fullgraph
=
True
,
dynamic
=
False
)
def
dummy_run
(
def
_
dummy_run
(
self
,
kv_caches
,
num_tokens
:
int
,
)
->
None
:
input_ids
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
if
self
.
is_multimodal_model
:
input_ids
=
None
inputs_embeds
=
torch
.
zeros
((
num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
self
.
device
)
else
:
input_ids
=
torch
.
zeros
((
num_tokens
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
inputs_embeds
=
None
position_ids
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
...
...
@@ -571,7 +718,10 @@ class TPUModelRunner:
num_seqs
=
num_tokens
,
)
torch
.
_dynamo
.
mark_dynamic
(
input_ids
,
0
)
if
self
.
is_multimodal_model
:
torch
.
_dynamo
.
mark_dynamic
(
inputs_embeds
,
0
)
else
:
torch
.
_dynamo
.
mark_dynamic
(
input_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
block_tables
,
0
)
...
...
@@ -580,7 +730,12 @@ class TPUModelRunner:
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
0
):
assert
self
.
model
is
not
None
self
.
model
(
input_ids
,
position_ids
,
kv_caches
)
self
.
model
(
input_ids
=
input_ids
,
positions
=
position_ids
,
kv_caches
=
kv_caches
,
inputs_embeds
=
inputs_embeds
,
)
def
capture_model
(
self
)
->
None
:
"""Compile the model."""
...
...
@@ -590,11 +745,11 @@ class TPUModelRunner:
start
=
time
.
perf_counter
()
num_tokens
=
16
while
True
:
self
.
dummy_run
(
self
.
kv_caches
,
num_tokens
)
self
.
_
dummy_run
(
self
.
kv_caches
,
num_tokens
)
logger
.
info
(
" -- num_tokens: %d"
,
num_tokens
)
xm
.
mark_step
()
xm
.
wait_device_ops
()
if
num_tokens
>=
self
.
scheduler_config
.
max_num_batched
_tokens
:
if
num_tokens
>=
self
.
max_num
_tokens
:
break
num_tokens
*=
2
end
=
time
.
perf_counter
()
...
...
@@ -647,17 +802,20 @@ class ModelWrapperV1(nn.Module):
def
forward
(
self
,
token
_ids
:
torch
.
Tensor
,
position
_id
s
:
torch
.
Tensor
,
input
_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
Args:
token
_ids: The input token IDs of shape [num_tokens].
position
_id
s: The input position IDs of shape [num_tokens].
input
_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
# Skip this in memory profiling at initialization.
if
kv_caches
[
0
][
0
].
numel
()
>
0
:
...
...
@@ -684,9 +842,9 @@ class ModelWrapperV1(nn.Module):
assert
self
.
model
is
not
None
hidden_states
=
self
.
model
(
token
_ids
,
position
_id
s
,
kv_cache
s
,
input_ids
=
input
_ids
,
position
s
=
position
s
,
inputs_embeds
=
inputs_embed
s
,
)
return
hidden_states
...
...
@@ -699,6 +857,12 @@ class ModelWrapperV1(nn.Module):
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
logits
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
def
get_input_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
def
_get_padded_number
(
n
:
int
,
multiple
:
int
)
->
int
:
return
((
n
+
multiple
-
1
)
//
multiple
)
*
multiple
vllm/v1/worker/tpu_worker.py
View file @
fbfc3ee3
...
...
@@ -134,7 +134,7 @@ class TPUWorker:
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
runner_kv_caches
)
self
.
model_runner
.
dummy_run
(
self
.
model_runner
.
_
dummy_run
(
runner_kv_caches
,
num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment