OpenDAS / vllm_cscc
Commit c3f9773b (unverified)
Authored Sep 09, 2025 by Chenyaaang; committed by GitHub, Sep 09, 2025

[TPU] Fix tpu structured decoding in mixed batches (#24458)

Signed-off-by: Chenyaaang <chenyangli@google.com>

Parent: 3707cb25
Showing 1 changed file with 14 additions and 20 deletions

vllm/v1/worker/tpu_model_runner.py (+14 -20)
@@ -1769,28 +1769,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.grammar_bitmask_cpu.zero_()
         self.require_structured_out_cpu.zero_()
 
-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the tpu runner is
-        # ordering the requests in the batch. We need to match the order of
-        # bitmask with the order of requests
-        struct_out_indices: list[int] = []
-        mask_indices: list[int] = []
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
+        sorted_struct_requests = sorted(
+            scheduler_output.structured_output_request_ids.items(),
+            key=lambda item: item[1])
+        cumulative_mask_idx = 0
+        for req_id, _ in sorted_struct_requests:
+            if req_id not in self.input_batch.req_id_to_index:
                 continue
             batch_index = self.input_batch.req_id_to_index[req_id]
-            struct_out_indices.append(batch_index)
-            mask_indices.append(mask_index)
-        self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
-            grammar_bitmask[mask_indices])
-        # It's not guaranteed that all requests in this batch require
-        # structured output, so create a bool tensor to represent
-        # the requests that need structured output.
-        struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
-        self.require_structured_out_cpu[struct_out_indices] = True
+            self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
+                grammar_bitmask[cumulative_mask_idx])
+            # It's not guaranteed that all requests in this batch require
+            # structured output, so create a bool tensor to represent
+            # the requests that need structured output.
+            self.require_structured_out_cpu[batch_index] = True
+            cumulative_mask_idx += 1
+
         return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
             self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
             self.structured_decode_arange.to(logits.device)
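
The loop added in this commit can be sketched in plain Python to show how the mask rows land in a mixed batch. This is an illustrative sketch, not the vLLM implementation: `scatter_bitmask_rows`, the request IDs, and the sample data are hypothetical, and plain lists stand in for the torch/NumPy tensors. It also assumes, as the new loop appears to, that `grammar_bitmask` holds one row per structured request actually present in the batch, ordered by ascending scheduler mask index.

```python
def scatter_bitmask_rows(structured_output_request_ids, req_id_to_index,
                         grammar_bitmask, num_reqs, row_width):
    """Copy compact bitmask rows to each request's position in the batch.

    Hypothetical helper mirroring the loop in the diff with plain lists.
    """
    bitmask = [[0] * row_width for _ in range(num_reqs)]
    require = [False] * num_reqs

    # Visit structured requests in ascending scheduler mask index.
    ordered = sorted(structured_output_request_ids.items(),
                     key=lambda item: item[1])
    cumulative_mask_idx = 0
    for req_id, _ in ordered:
        if req_id not in req_id_to_index:
            # Not in this batch: assumed to have no row in grammar_bitmask,
            # so the row counter must not advance.
            continue
        batch_index = req_id_to_index[req_id]
        bitmask[batch_index] = list(grammar_bitmask[cumulative_mask_idx])
        require[batch_index] = True
        cumulative_mask_idx += 1
    return require, bitmask


# Hypothetical mixed batch: "b" is known to the scheduler but absent from
# the runner's batch, and "plain" needs no structured output.
structured_output_request_ids = {"a": 0, "b": 1, "c": 2}
req_id_to_index = {"c": 0, "plain": 1, "a": 2}
grammar_bitmask = [[1, 2], [3, 4]]  # assumed rows for "a" and "c" only

require, bitmask = scatter_bitmask_rows(
    structured_output_request_ids, req_id_to_index, grammar_bitmask,
    num_reqs=3, row_width=2)
print(require)  # [True, False, True]
print(bitmask)  # [[3, 4], [0, 0], [1, 2]]
```

Under the same assumption, indexing `grammar_bitmask` with the raw scheduler index (2 for "c") would read past the two available rows, which is the kind of mismatch the cumulative counter avoids.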