Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47195057
Unverified
Commit
47195057
authored
Mar 20, 2025
by
Hyesoo Yang
Committed by
GitHub
Mar 20, 2025
Browse files
[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)
Signed-off-by:
Hyesoo Yang
<
hyeygit@gmail.com
>
parent
6edbfa92
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
4 deletions
+29
-4
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+2
-1
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+21
-3
No files found.
tests/v1/tpu/test_sampler.py
View file @
47195057
...
@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
...
@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
temperature
=
0.7
,
# top_p=0.6, # TODO too slow!
# top_p=0.6, # TODO too slow!
#
top_k=10,
top_k
=
10
,
min_p
=
0.2
,
min_p
=
0.2
,
max_tokens
=
16
)
max_tokens
=
16
)
s
=
time
()
s
=
time
()
...
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
...
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
# Second request with different params, but for which we
# Second request with different params, but for which we
# compiled for in previous eager iteration.
# compiled for in previous eager iteration.
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
top_k
=
12
,
min_p
=
0.8
,
min_p
=
0.8
,
max_tokens
=
24
)
max_tokens
=
24
)
s
=
time
()
s
=
time
()
...
...
vllm/envs.py
View file @
47195057
...
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
...
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# an environment with potentially malicious users.
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE"
:
"VLLM_V0_USE_OUTLINES_CACHE"
:
lambda
:
os
.
environ
.
get
(
"VLLM_V0_USE_OUTLINES_CACHE"
,
"0"
)
==
"1"
,
lambda
:
os
.
environ
.
get
(
"VLLM_V0_USE_OUTLINES_CACHE"
,
"0"
)
==
"1"
,
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
:
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
]))
if
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
in
os
.
environ
else
None
,
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
47195057
...
@@ -66,7 +66,14 @@ class TopKTopPSampler(nn.Module):
...
@@ -66,7 +66,14 @@ class TopKTopPSampler(nn.Module):
"best performance, please install FlashInfer."
)
"best performance, please install FlashInfer."
)
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
elif
current_platform
.
is_tpu
():
elif
current_platform
.
is_tpu
():
self
.
forward
=
self
.
forward_tpu
if
envs
.
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
logger
.
warning
(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow."
)
self
.
forward
=
self
.
forward_native
else
:
self
.
forward
=
self
.
forward_tpu
else
:
else
:
self
.
forward
=
self
.
forward_native
self
.
forward
=
self
.
forward_native
...
@@ -105,8 +112,19 @@ class TopKTopPSampler(nn.Module):
...
@@ -105,8 +112,19 @@ class TopKTopPSampler(nn.Module):
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# TODO Placeholder for TPU optimized topk/p kernel
# If only top-k is specified, use pytorch's builtin topk op. This leads
# logits = apply_top_k_top_p(logits, k, p)
# to significant speed up on TPU compared to using apply_top_k_top_p.
if
k
is
not
None
and
p
is
None
:
topk_values
,
topk_indices
=
torch
.
topk
(
logits
,
k
,
dim
=-
1
)
mask
=
torch
.
ones_like
(
logits
,
dtype
=
torch
.
bool
)
mask
.
scatter_
(
-
1
,
topk_indices
,
False
)
logits
.
masked_fill_
(
mask
,
float
(
'-inf'
))
else
:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
return
random_sample
(
probs
,
generators
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment