Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
105d3d62
Unverified
Commit
105d3d62
authored
Sep 07, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 07, 2025
Browse files
[TPU] Remove TopKTopPSampler dependency for TPU sampler (#24391)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
62f66be1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
57 deletions
+79
-57
tests/v1/tpu/test_topk_topp_sampler.py
tests/v1/tpu/test_topk_topp_sampler.py
+6
-2
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+2
-51
vllm/v1/sample/tpu/sampler.py
vllm/v1/sample/tpu/sampler.py
+71
-4
No files found.
tests/v1/tpu/test_topk_topp_sampler.py
View file @
105d3d62
...
...
@@ -6,8 +6,12 @@ import pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
# isort: off
from
vllm.v1.sample.tpu.sampler
import
(
apply_top_k_top_p
as
apply_top_k_top_p_tpu
)
# isort: on
if
not
current_platform
.
is_tpu
():
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
105d3d62
...
...
@@ -73,9 +73,7 @@ class TopKTopPSampler(nn.Module):
self
.
forward
=
self
.
forward_native
else
:
self
.
forward
=
self
.
forward_native
if
current_platform
.
is_tpu
():
self
.
apply_top_k_top_p
=
apply_top_k_top_p_tpu
else
:
self
.
apply_top_k_top_p
=
apply_top_k_top_p
def
forward_native
(
...
...
@@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module):
return
flashinfer_sample
(
logits
.
contiguous
(),
k
,
p
,
generators
),
None
def
apply_top_k_top_p_tpu
(
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of tie (i.e. multipple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
probs
=
logits
.
softmax
(
dim
=-
1
)
probs_sort
,
_
=
probs
.
sort
(
dim
=-
1
,
descending
=
False
)
if
k
is
not
None
:
top_k_count
=
probs_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# shape: (batch, )
top_k_count
=
top_k_count
.
unsqueeze
(
dim
=
1
)
top_k_cutoff
=
probs_sort
.
gather
(
-
1
,
top_k_count
)
# Make sure the no top-k rows are no-op.
no_top_k_mask
=
(
k
==
logits
.
shape
[
1
]).
unsqueeze
(
dim
=
1
)
top_k_cutoff
.
masked_fill_
(
no_top_k_mask
,
-
float
(
"inf"
))
elements_to_discard
=
probs
<
top_k_cutoff
logits
.
masked_fill_
(
elements_to_discard
,
-
float
(
"inf"
))
if
p
is
not
None
:
cumprob
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
)
top_p_mask
=
cumprob
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
top_p_mask
[:,
-
1
]
=
False
# at least one
top_p_count
=
top_p_mask
.
sum
(
dim
=-
1
).
unsqueeze
(
1
)
top_p_cutoff
=
probs_sort
.
gather
(
-
1
,
top_p_count
)
elements_to_discard
=
probs
<
top_p_cutoff
logits
.
masked_fill_
(
elements_to_discard
,
-
float
(
"inf"
))
return
logits
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
...
...
vllm/v1/sample/tpu/sampler.py
View file @
105d3d62
...
...
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampler layer implementing TPU supported operations."""
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
from
vllm.v1.sample.tpu.metadata
import
TPUSupportedSamplingMetadata
_SAMPLING_EPS
=
1e-5
...
...
@@ -17,7 +18,6 @@ class Sampler(nn.Module):
def
__init__
(
self
):
# TODO(houseroad): Add support for logprobs_mode.
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
def
forward
(
self
,
...
...
@@ -65,13 +65,17 @@ class Sampler(nn.Module):
logits
=
self
.
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p.
random_sampled
,
_
=
self
.
topk_top
p
_
sampler
(
logits
=
apply_
top
_
k_top_
p
(
logits
,
sampling_metadata
.
generators
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
,
)
# Random sample.
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
random_sampled
=
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
)
sampled
=
torch
.
where
(
sampling_metadata
.
temperature
<
_SAMPLING_EPS
,
greedy_sampled
,
random_sampled
)
return
sampled
...
...
@@ -144,3 +148,66 @@ class Sampler(nn.Module):
# Apply mask using boolean indexing (xla friendly)
logits
.
masked_fill_
(
~
valid_token_mask
,
-
float
(
"inf"
))
return
logits
def
random_sample
(
self
,
probs
:
torch
.
Tensor
,
generators
:
dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
q
=
torch
.
empty_like
(
probs
)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
q
.
exponential_
()
if
generators
:
for
i
,
generator
in
generators
.
items
():
q
[
i
].
exponential_
(
generator
=
generator
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
apply_top_k_top_p
(
logits
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Apply top-k and top-p optimized for TPU.
This algorithm avoids using torch.scatter which is extremely slow on TPU.
This is achieved by finding a "cut-off" element in the original logit, and
after thresholding the logit using this cut-off, the remaining elements
shall constitute the top-p set.
Note: in the case of tie (i.e. multipple cut-off elements present in the
logit), all tie elements are included in the top-p set. In other words,
this function does not break ties. Instead, these tie tokens have equal
chance of being chosen during final sampling, so we can consider the tie
being broken then.
"""
probs
=
logits
.
softmax
(
dim
=-
1
)
probs_sort
,
_
=
probs
.
sort
(
dim
=-
1
,
descending
=
False
)
if
k
is
not
None
:
top_k_count
=
probs_sort
.
size
(
1
)
-
k
.
to
(
torch
.
long
)
# shape: (batch, )
top_k_count
=
top_k_count
.
unsqueeze
(
dim
=
1
)
top_k_cutoff
=
probs_sort
.
gather
(
-
1
,
top_k_count
)
# Make sure the no top-k rows are no-op.
no_top_k_mask
=
(
k
==
logits
.
shape
[
1
]).
unsqueeze
(
dim
=
1
)
top_k_cutoff
.
masked_fill_
(
no_top_k_mask
,
-
float
(
"inf"
))
elements_to_discard
=
probs
<
top_k_cutoff
logits
.
masked_fill_
(
elements_to_discard
,
-
float
(
"inf"
))
if
p
is
not
None
:
cumprob
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
)
top_p_mask
=
cumprob
<=
1
-
p
.
unsqueeze
(
dim
=
1
)
top_p_mask
[:,
-
1
]
=
False
# at least one
top_p_count
=
top_p_mask
.
sum
(
dim
=-
1
).
unsqueeze
(
1
)
top_p_cutoff
=
probs_sort
.
gather
(
-
1
,
top_p_count
)
elements_to_discard
=
probs
<
top_p_cutoff
logits
.
masked_fill_
(
elements_to_discard
,
-
float
(
"inf"
))
return
logits
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment