Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83d93371
Unverified
Commit
83d93371
authored
Apr 22, 2025
by
Chenyaaang
Committed by
GitHub
Apr 22, 2025
Browse files
[Core][V1][TPU] Enable structured decoding on TPU V1 (#16499)
Signed-off-by:
Chenyaaang
<
chenyangli@google.com
>
parent
5175b884
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
158 additions
and
31 deletions
+158
-31
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+3
-1
benchmarks/benchmark_serving_structured_output.py
benchmarks/benchmark_serving_structured_output.py
+1
-1
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+5
-2
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+2
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+147
-25
No files found.
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
View file @
83d93371
...
...
@@ -44,7 +44,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_9
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py
\
&& echo TEST_10
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py"
\
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py
\
&& echo TEST_11
\
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
\
# TODO: This test fails because it uses RANDOM_SEED sampling
...
...
benchmarks/benchmark_serving_structured_output.py
View file @
83d93371
...
...
@@ -51,7 +51,7 @@ try:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
vllm.v1.structured_output.
utils
import
(
from
vllm.v1.structured_output.
backend_xgrammar
import
(
has_xgrammar_unsupported_json_features
)
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
...
...
tests/v1/tpu/test_sampler.py
View file @
83d93371
...
...
@@ -23,7 +23,7 @@ def test_sampler_different(model_name: str):
different results.
"""
llm
=
LLM
(
model_name
,
enforce_eager
=
Fals
e
,
enforce_eager
=
Tru
e
,
max_num_seqs
=
1
,
max_model_len
=
512
,
max_num_batched_tokens
=
512
)
...
...
@@ -57,4 +57,7 @@ def test_sampler_different(model_name: str):
# Make sure first two reqs have the same K/P
sampling_params
[
0
]
=
sampling_params
[
1
]
output
=
llm
.
generate
(
p
,
sampling_params
)
assert
output
[
0
].
outputs
[
0
].
text
==
output
[
1
].
outputs
[
0
].
text
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert
output
[
0
].
outputs
[
0
].
text
[:
20
]
==
output
[
1
].
outputs
[
0
].
text
[:
20
]
vllm/platforms/tpu.py
View file @
83d93371
...
...
@@ -168,9 +168,9 @@ class TpuPlatform(Platform):
)
->
None
:
"""Raises if this request is unsupported on this platform"""
if
isinstance
(
params
,
SamplingParams
):
if
params
.
guided_decoding
is
not
None
:
if
params
.
guided_decoding
is
not
None
and
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Structured output is not supported on "
f
"
{
cls
.
device_name
}
."
)
f
"
{
cls
.
device_name
}
V0
."
)
if
params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
raise
ValueError
(
"Torch XLA does not support per-request seed."
)
vllm/v1/worker/tpu_model_runner.py
View file @
83d93371
...
...
@@ -30,8 +30,9 @@ from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
...
...
@@ -148,6 +149,7 @@ class TPUModelRunner:
self
.
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
vocab_size
=
model_config
.
get_vocab_size
()
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
...
...
@@ -178,7 +180,7 @@ class TPUModelRunner:
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
model_config
.
get_
vocab_size
()
,
vocab_size
=
self
.
vocab_size
,
)
# Cached torch/numpy tensor
...
...
@@ -221,6 +223,20 @@ class TPUModelRunner:
self
.
num_reqs_paddings
=
_get_req_paddings
(
min_req_size
=
MIN_NUM_SEQS
,
max_req_size
=
self
.
max_num_reqs
)
# tensors for structured decoding
self
.
grammar_bitmask_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
)),
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
require_structured_out_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
1
),
dtype
=
torch
.
bool
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
structured_decode_arange
=
torch
.
arange
(
0
,
32
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
# Get maximum number of mm items per modality (batch size).
self
.
max_num_mm_items_by_modality
=
dict
()
if
(
self
.
is_multimodal_model
and
self
.
max_num_encoder_input_tokens
>
0
...
...
@@ -762,9 +778,16 @@ class TPUModelRunner:
)
hidden_states
=
self
.
select_hidden_states
(
hidden_states
,
logits_indices
)
logits
=
self
.
compute_logits
(
hidden_states
)
tpu_sampling_metadata
=
TPUSupportedSamplingMetadata
.
\
from_input_batch
(
self
.
input_batch
,
padded_num_reqs
,
self
.
device
)
selected_token_ids
=
self
.
sample_from_hidden
(
hidden_states
,
if
scheduler_output
.
grammar_bitmask
is
not
None
:
require_struct_decoding
,
grammar_bitmask_padded
,
arange
=
\
self
.
prepare_structured_decoding_input
(
logits
,
scheduler_output
)
logits
=
self
.
structured_decode
(
require_struct_decoding
,
grammar_bitmask_padded
,
logits
,
arange
)
selected_token_ids
=
self
.
sample_from_logits
(
logits
,
tpu_sampling_metadata
)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
...
...
@@ -997,7 +1020,7 @@ class TPUModelRunner:
self
.
_dummy_run
(
num_tokens
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in
in
%.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"model backbone"
)
def
_precompile_select_hidden_states
(
self
)
->
None
:
...
...
@@ -1026,19 +1049,59 @@ class TPUModelRunner:
break
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in
in
%.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"select_hidden_states"
)
def
_precompile_
sample_from_hidden
(
self
)
->
None
:
logger
.
info
(
"Compiling
sampling
with different
num_req
s."
)
def
_precompile_
compute_logits
(
self
)
->
None
:
logger
.
info
(
"Compiling
compute_logits
with different
input shape
s."
)
start
=
time
.
perf_counter
()
hsize
=
self
.
model_config
.
get_hidden_size
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_hidden
=
torch
.
zeros
((
num_reqs
,
hsize
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# The first dimension of dummy_hidden cannot be mark_dynamic because
# some operations in the sampler require it to be static.
torch
.
_dynamo
.
mark_dynamic
(
dummy_hidden
,
0
)
self
.
compute_logits
(
dummy_hidden
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"compute_logits"
)
def
_precompile_structured_decoding
(
self
)
->
None
:
logger
.
info
(
"Compiling structured_decoding with different input shapes."
)
start
=
time
.
perf_counter
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_logits
=
torch
.
zeros
((
num_reqs
,
self
.
vocab_size
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
dummy_require_struct_decoding
=
\
self
.
require_structured_out_cpu
[:
num_reqs
].
to
(
self
.
device
)
dummy_grammar_bitmask
=
\
self
.
grammar_bitmask_cpu
[:
num_reqs
].
to
(
self
.
device
)
# The first dimension of the above 3 dummy tensors cannot be
# mark_dynamic because some operations in structured_decode require
# them to be static.
arange
=
self
.
structured_decode_arange
.
to
(
self
.
device
)
self
.
structured_decode
(
dummy_require_struct_decoding
,
dummy_grammar_bitmask
,
dummy_logits
,
arange
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"structured_decoding"
)
def
_precompile_sample_from_logits
(
self
)
->
None
:
logger
.
info
(
"Compiling sample_from_logits with different input shapes."
)
start
=
time
.
perf_counter
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_logits
=
torch
.
zeros
((
num_reqs
,
self
.
vocab_size
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
# The first dimension of dummy_logits cannot be mark_dynamic
# because some operations in the sampler require it to be static.
for
all_greedy
in
[
False
,
True
]:
generate_params_if_all_greedy
=
not
all_greedy
sampling_metadata
=
(
...
...
@@ -1049,12 +1112,12 @@ class TPUModelRunner:
generate_params_if_all_greedy
,
))
sampling_metadata
.
all_greedy
=
all_greedy
self
.
sample_from_
hidden
(
dummy_
hidden
,
sampling_metadata
)
self
.
sample_from_
logits
(
dummy_
logits
,
sampling_metadata
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in
in
%.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"sampl
ing
"
)
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"sampl
e_from_logits
"
)
def
capture_model
(
self
)
->
None
:
"""
...
...
@@ -1063,7 +1126,9 @@ class TPUModelRunner:
self
.
_precompile_mm_encoder
()
self
.
_precompile_backbone
()
self
.
_precompile_select_hidden_states
()
self
.
_precompile_sample_from_hidden
()
self
.
_precompile_compute_logits
()
self
.
_precompile_structured_decoding
()
self
.
_precompile_sample_from_logits
()
def
profile_run
(
self
,
...
...
@@ -1144,7 +1209,7 @@ class TPUModelRunner:
tensor_config
=
kv_cache_config
.
tensors
[
layer_name
]
assert
tensor_config
.
size
%
kv_cache_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
kv_cache_spec
.
page_size_bytes
if
isinstance
(
kv_cache_spec
,
Full
AttentionSpec
):
if
isinstance
(
kv_cache_spec
,
AttentionSpec
):
kv_cache_shape
=
PallasAttentionBackend
.
get_kv_cache_shape
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
...
...
@@ -1179,16 +1244,14 @@ class TPUModelRunner:
return
hidden_states
[
indices_do_sample
]
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
sample_from_hidden
(
self
,
sample_hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
)
->
torch
.
Tensor
:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
def
compute_logits
(
self
,
sample_hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
sample_from_logits
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
)
->
torch
.
Tensor
:
if
sampling_metadata
.
all_greedy
:
out_tokens
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
else
:
...
...
@@ -1196,12 +1259,71 @@ class TPUModelRunner:
sampling_metadata
).
sampled_token_ids
return
out_tokens
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
structured_decode
(
self
,
require_struct_decoding
:
torch
.
Tensor
,
grammar_bitmask
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
arange
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
where
(
require_struct_decoding
,
self
.
apply_grammar_bitmask
(
logits
,
grammar_bitmask
,
arange
),
logits
)
def
apply_grammar_bitmask
(
self
,
logits
:
torch
.
Tensor
,
grammar_bitmask
:
torch
.
Tensor
,
arange
:
torch
.
Tensor
):
assert
(
logits
.
shape
[
0
]
==
grammar_bitmask
.
shape
[
0
])
logits_cloned
=
logits
.
clone
()
for
i
in
range
(
logits
.
shape
[
0
]):
unpacked_bitmask
=
(
torch
.
bitwise_right_shift
(
grammar_bitmask
[
i
][:,
None
],
arange
[
None
,
:])
&
1
)
==
0
unpacked_bitmask
=
unpacked_bitmask
.
reshape
(
-
1
)[:
self
.
vocab_size
]
logits_cloned
[
i
]
=
logits_cloned
[
i
].
masked_fill
(
unpacked_bitmask
,
-
float
(
"inf"
))
return
logits_cloned
def
get_multimodal_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_multimodal_embeddings
(
*
args
,
**
kwargs
)
def
get_input_embeddings
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
get_input_embeddings
(
*
args
,
**
kwargs
)
def
prepare_structured_decoding_input
(
self
,
logits
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
grammar_bitmask
=
scheduler_output
.
grammar_bitmask
assert
grammar_bitmask
is
not
None
num_reqs
,
_
=
logits
.
shape
# Reset pre-allocated tensors
self
.
grammar_bitmask_cpu
.
zero_
()
self
.
require_structured_out_cpu
.
zero_
()
# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the tpu runner is
# ordering the requests in the batch. We need to match the order of
# bitmask with the order of requests
struct_out_indices
:
list
[
int
]
=
[]
mask_indices
:
list
[
int
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
mask_index
=
scheduler_output
.
structured_output_request_ids
.
get
(
req_id
)
if
mask_index
is
None
:
continue
batch_index
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
struct_out_indices
.
append
(
batch_index
)
mask_indices
.
append
(
mask_index
)
self
.
grammar_bitmask_cpu
[
struct_out_indices
]
=
torch
.
from_numpy
(
grammar_bitmask
[
mask_indices
])
# It's not guaranteed that all requests in this batch require
# structured output, so create a bool tensor to represent
# the requests that need structured output.
struct_out_indices
=
torch
.
tensor
(
struct_out_indices
,
dtype
=
torch
.
long
)
self
.
require_structured_out_cpu
[
struct_out_indices
]
=
True
return
self
.
require_structured_out_cpu
[:
num_reqs
].
to
(
logits
.
device
),
\
self
.
grammar_bitmask_cpu
[:
num_reqs
].
to
(
logits
.
device
),
\
self
.
structured_decode_arange
.
to
(
logits
.
device
)
def
_get_mm_dummy_batch
(
self
,
modality
:
str
,
batch_size
:
int
)
->
BatchedTensorInputs
:
# Dummy data for pre-compiling multimodal models.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment