Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3f687ac
Unverified
Commit
c3f687ac
authored Mar 28, 2025 by Alexander Matveev
Committed by GitHub, Mar 28, 2025
Browse files
[V1] TPU - Fix the chunked prompt bug (#15713)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent
04437e31
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing 2 changed files with 17 additions and 1 deletion
+17
-1
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+4
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+13
-0
No files found.
tests/v1/tpu/test_basic.py
View file @
c3f687ac
...
...
@@ -48,7 +48,10 @@ def test_models(
with
vllm_runner
(
model
,
max_model_len
=
8192
,
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens
=
1024
,
max_model_len
=
8196
,
gpu_memory_utilization
=
0.7
,
max_num_seqs
=
16
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
c3f687ac
...
...
@@ -618,6 +618,7 @@ class TPUModelRunner:
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens
:
list
[
tuple
[
int
,
CachedRequestState
,
int
]]
=
[]
discard_sampled_tokens_req_indices
=
[]
for
i
,
req_id
in
zip
(
range
(
num_reqs
),
self
.
input_batch
.
req_ids
):
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
...
...
@@ -633,6 +634,10 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices
.
append
(
i
)
assert
all
(
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
...
...
@@ -646,11 +651,19 @@ class TPUModelRunner:
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
selected_token_ids
.
tolist
()
# Mask out the sampled tokens that should not be sampled.
# TODO: Keep in sync with gpu_model_runner.py, in particular
# the "else" case here
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
# Append sampled tokens
for
i
,
req_state
,
seq_len
in
request_seq_lens
:
token_id
=
valid_sampled_token_ids
[
i
][
0
]
self
.
input_batch
.
token_ids_cpu
[
i
,
seq_len
]
=
token_id
req_state
.
output_token_ids
.
append
(
token_id
)
self
.
input_batch
.
num_tokens
[
i
]
+=
1
else
:
valid_mask
=
selected_token_ids
!=
INVALID_TOKEN_ID
gen_lens
=
valid_mask
.
sum
(
dim
=
1
).
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment