Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5941e0b7
Unverified
Commit
5941e0b7
authored
May 05, 2025
by
Nicolò Lucchesi
Committed by
GitHub
May 05, 2025
Browse files
[TPU][V1] Add support for top-logprobs (#17072)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
97659408
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
105 additions
and
17 deletions
+105
-17
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+48
-0
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+9
-4
vllm/v1/sample/tpu/sampler.py
vllm/v1/sample/tpu/sampler.py
+2
-12
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+46
-1
No files found.
tests/v1/tpu/test_sampler.py
View file @
5941e0b7
...
...
@@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert
output
[
0
].
outputs
[
0
].
text
[:
20
]
==
output
[
1
].
outputs
[
0
].
text
[:
20
]
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Qwen/Qwen2.5-1.5B-Instruct"
])
# TODO TPU will appear busy if we fan-out test params here
@
pytest
.
mark
.
parametrize
(
"n_prompts"
,
[
1
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
def
test_logprobs
(
model_name
:
str
,
n_prompts
:
int
):
"""
Request top logprobs with different sampling settings and check
that results contains the requested number, ordered ascendingly.
"""
def
check_num_logprobs
(
logprobs
,
expected_num
:
int
):
for
step
in
logprobs
:
prev_logp
=
1.0
# order by rank
sorted_step
=
dict
(
sorted
(
step
.
items
(),
key
=
lambda
item
:
item
[
1
].
rank
))
# Can contain the sampled token
assert
len
(
step
)
==
expected_num
or
len
(
step
)
==
expected_num
+
1
# Check results are ordered by prob value
for
rankno
,
(
tid
,
logp
)
in
enumerate
(
sorted_step
.
items
()):
assert
logp
.
logprob
<=
prev_logp
prev_logp
=
logp
.
logprob
assert
logp
.
rank
==
rankno
+
1
llm
=
LLM
(
model_name
,
enforce_eager
=
False
,
max_num_seqs
=
1
,
max_model_len
=
128
,
max_num_batched_tokens
=
128
)
prompts
=
[
"Write a short story about a robot that dreams for the first time."
]
*
n_prompts
greedy_sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
64
,
\
logprobs
=
4
)
regular_sampling_params
=
SamplingParams
(
temperature
=
0.4
,
max_tokens
=
64
,
\
logprobs
=
4
)
topkp_sampling_params
=
SamplingParams
(
temperature
=
0.4
,
max_tokens
=
64
,
\
logprobs
=
4
,
top_k
=
12
,
top_p
=
0.5
)
for
sp
in
[
greedy_sampling_params
,
regular_sampling_params
,
\
topkp_sampling_params
]:
output
=
llm
.
generate
(
prompts
,
sp
)
for
o
in
output
:
check_num_logprobs
(
o
.
outputs
[
0
].
logprobs
,
4
)
vllm/v1/sample/tpu/metadata.py
View file @
5941e0b7
...
...
@@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:
all_greedy
:
bool
=
True
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs
=
None
# Whether logprobs are to be gathered in this batch of request. To balance
# out compile time and runtime, a fixed `max_number_logprobs` value is used
# when gathering logprobs, regardless of the values specified in the batch.
logprobs
:
bool
=
False
# TODO No penalties for now
no_penalties
:
bool
=
True
...
...
@@ -84,10 +86,12 @@ class TPUSupportedSamplingMetadata:
we want to pre-compile a graph with sampling parameters, even if
they are not strictly needed for greedy decoding.
"""
needs_logprobs
=
input_batch
.
max_num_logprobs
>
0
if
\
input_batch
.
max_num_logprobs
else
False
# Early return to avoid unnecessary cpu to tpu copy
if
(
input_batch
.
all_greedy
is
True
and
generate_params_if_all_greedy
is
False
):
return
cls
(
all_greedy
=
True
)
return
cls
(
all_greedy
=
True
,
logprobs
=
needs_logprobs
)
num_reqs
=
input_batch
.
num_reqs
...
...
@@ -115,4 +119,5 @@ class TPUSupportedSamplingMetadata:
top_k
=
input_batch
.
top_k_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
),
min_p
=
input_batch
.
min_p_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
))
xla_device
),
logprobs
=
needs_logprobs
)
vllm/v1/sample/tpu/sampler.py
View file @
5941e0b7
...
...
@@ -22,27 +22,18 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
,
)
->
SamplerOutput
:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
# Sample the next token.
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
# Use int32 to reduce the tensor size.
sampled
=
sampled
.
to
(
torch
.
int32
)
# These are GPU tensors.
# These are TPU tensors.
sampler_output
=
SamplerOutput
(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids
=
sampled
.
unsqueeze
(
-
1
),
logprobs_tensors
=
None
,
)
logprobs_tensors
=
None
)
return
sampler_output
def
apply_temperature
(
...
...
@@ -50,7 +41,6 @@ class Sampler(nn.Module):
logits
:
torch
.
Tensor
,
temp
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Use in-place division to avoid creating a new tensor.
return
logits
.
div_
(
temp
.
unsqueeze
(
dim
=
1
))
def
greedy_sample
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
5941e0b7
...
...
@@ -791,8 +791,18 @@ class TPUModelRunner:
arange
)
selected_token_ids
=
self
.
sample_from_logits
(
logits
,
tpu_sampling_metadata
)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
# `sample_from_logits` does not modify the logits in-place.
logprobs
=
self
.
gather_logprobs
(
logits
,
selected_token_ids
)
\
if
tpu_sampling_metadata
.
logprobs
else
None
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids
=
selected_token_ids
.
cpu
()[:
num_reqs
]
logprobs_lists
=
logprobs
.
tolists
()
\
if
tpu_sampling_metadata
.
logprobs
else
None
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
...
...
@@ -862,7 +872,7 @@ class TPUModelRunner:
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
...
...
@@ -1121,6 +1131,22 @@ class TPUModelRunner:
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"sample_from_logits"
)
def
_precompile_gather_logprobs
(
self
)
->
None
:
logger
.
info
(
"Compiling gather_logprobs with different input shapes."
)
start
=
time
.
perf_counter
()
for
num_reqs
in
self
.
num_reqs_paddings
:
dummy_logits
=
torch
.
zeros
((
num_reqs
,
self
.
vocab_size
),
device
=
self
.
device
,
dtype
=
self
.
_hidden_states_dtype
)
dummy_tokens
=
torch
.
zeros
((
num_reqs
,
1
),
dtype
=
torch
.
int64
).
to
(
self
.
device
)
self
.
gather_logprobs
(
dummy_logits
,
dummy_tokens
)
logger
.
info
(
" -- num_seqs: %d"
,
num_reqs
)
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in %.2f [secs]."
,
end
-
start
)
self
.
_update_num_xla_graphs
(
"gather_logprobs"
)
def
capture_model
(
self
)
->
None
:
"""
Precompile all the subgraphs with possible input shapes.
...
...
@@ -1131,6 +1157,7 @@ class TPUModelRunner:
self
.
_precompile_compute_logits
()
self
.
_precompile_structured_decoding
()
self
.
_precompile_sample_from_logits
()
self
.
_precompile_gather_logprobs
()
def
profile_run
(
self
,
...
...
@@ -1254,6 +1281,10 @@ class TPUModelRunner:
def
sample_from_logits
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
TPUSupportedSamplingMetadata
)
->
torch
.
Tensor
:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
if
sampling_metadata
.
all_greedy
:
out_tokens
=
torch
.
argmax
(
logits
,
dim
=-
1
,
keepdim
=
True
)
else
:
...
...
@@ -1261,6 +1292,20 @@ class TPUModelRunner:
sampling_metadata
).
sampled_token_ids
return
out_tokens
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
gather_logprobs
(
self
,
logits
:
torch
.
Tensor
,
sampled_tokens
:
torch
.
Tensor
)
->
LogprobsTensors
:
"""
Gather the top_logprobs with corresponding tokens. Use a fixed number
of logprobs as an alternative to having multiple pre-compiled graphs.
Select the number of logprobs actually demanded by each request on CPU.
"""
logprobs
=
self
.
sampler
.
compute_logprobs
(
logits
)
return
self
.
sampler
.
gather_logprobs
(
logprobs
,
self
.
model_config
.
max_logprobs
,
token_ids
=
sampled_tokens
.
squeeze
(
-
1
))
@
torch
.
compile
(
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
structured_decode
(
self
,
require_struct_decoding
:
torch
.
Tensor
,
grammar_bitmask
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment