Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
17f59521
Commit
17f59521
authored
Jan 08, 2026
by
laibao
Browse files
V1 采样器:新增 reduced top-k/top-p 采样路径
parent
2b0c9835
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
1 deletion
+111
-1
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+7
-0
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+85
-1
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+2
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+11
-0
No files found.
vllm/envs.py
View file @
17f59521
...
@@ -246,6 +246,7 @@ if TYPE_CHECKING:
...
@@ -246,6 +246,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_V1_FAST_TOKEN_ID_COPY
:
bool
=
False
VLLM_V1_FAST_TOKEN_ID_COPY
:
bool
=
False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1694,6 +1695,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1694,6 +1695,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY"
:
"VLLM_V1_FAST_TOKEN_ID_COPY"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_FAST_TOKEN_ID_COPY"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_V1_FAST_TOKEN_ID_COPY"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# If set to 1/True, enable the reduced top-k/top-p sampling path in the
# V1 PyTorch-native sampler.
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER"
:
lambda
:
(
os
.
getenv
(
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER"
,
"0"
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/sample/metadata.py
View file @
17f59521
...
@@ -41,3 +41,10 @@ class SamplingMetadata:
...
@@ -41,3 +41,10 @@ class SamplingMetadata:
# Loaded logits processors
# Loaded logits processors
logitsprocs
:
LogitsProcessors
logitsprocs
:
LogitsProcessors
# Optional host-side summaries to avoid device sync in fast paths.
# When `top_k` is provided, `max_top_k` is the maximum top-k value across
# the batch on the host (Python int).
max_top_k
:
Optional
[
int
]
=
None
# True if any request in the batch has top_k == vocab_size (i.e. no top-k).
has_any_no_top_k
:
bool
=
False
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
17f59521
...
@@ -81,12 +81,35 @@ class TopKTopPSampler(nn.Module):
...
@@ -81,12 +81,35 @@ class TopKTopPSampler(nn.Module):
generators
:
dict
[
int
,
torch
.
Generator
],
generators
:
dict
[
int
,
torch
.
Generator
],
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
*
,
max_top_k
:
Optional
[
int
]
=
None
,
has_any_no_top_k
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
"""
PyTorch-native implementation of top-k and top-p sampling.
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
The logits tensor may be updated in-place.
"""
"""
# Fast path: when top-k is enabled, avoid full-vocab sort/softmax by
# sampling only from the top-k candidates (and applying top-p within
# that set). This is especially important on ROCm where the PyTorch
# native sort path can be very expensive.
#
# NOTE: Do not branch on device tensors here; doing so triggers
# `aten::is_nonzero` and synchronizes the CPU with GPU.
if
(
self
.
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
)
and
envs
.
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
and
k
is
not
None
and
max_top_k
is
not
None
and
not
has_any_no_top_k
and
max_top_k
<=
4096
):
try
:
return
(
sample_top_k_top_p_reduced
(
logits
,
generators
,
k
,
p
,
max_top_k
=
max_top_k
),
None
)
except
Exception
:
# Fall back to the reference implementation for safety.
pass
logits
=
self
.
apply_top_k_top_p
(
logits
,
k
,
p
)
logits
=
self
.
apply_top_k_top_p
(
logits
,
k
,
p
)
logits_to_return
=
None
logits_to_return
=
None
if
self
.
logprobs_mode
==
"processed_logits"
:
if
self
.
logprobs_mode
==
"processed_logits"
:
...
@@ -102,6 +125,9 @@ class TopKTopPSampler(nn.Module):
...
@@ -102,6 +125,9 @@ class TopKTopPSampler(nn.Module):
generators
:
dict
[
int
,
torch
.
Generator
],
generators
:
dict
[
int
,
torch
.
Generator
],
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
*
,
max_top_k
:
Optional
[
int
]
=
None
,
has_any_no_top_k
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""More optimized implementation for top-k and top-p sampling."""
"""More optimized implementation for top-k and top-p sampling."""
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# We prefer `random_sample` over `flashinfer_sample` when sorting is
...
@@ -112,7 +138,12 @@ class TopKTopPSampler(nn.Module):
...
@@ -112,7 +138,12 @@ class TopKTopPSampler(nn.Module):
logger
.
debug_once
(
"FlashInfer 0.2.3+ does not support "
logger
.
debug_once
(
"FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"per-request generators. Falling back to "
"PyTorch-native implementation."
)
"PyTorch-native implementation."
)
return
self
.
forward_native
(
logits
,
generators
,
k
,
p
)
return
self
.
forward_native
(
logits
,
generators
,
k
,
p
,
max_top_k
=
max_top_k
,
has_any_no_top_k
=
has_any_no_top_k
)
assert
self
.
logprobs_mode
not
in
(
assert
self
.
logprobs_mode
not
in
(
"processed_logits"
,
"processed_logprobs"
"processed_logits"
,
"processed_logprobs"
),
"FlashInfer does not support returning logits/logprobs"
),
"FlashInfer does not support returning logits/logprobs"
...
@@ -127,6 +158,9 @@ class TopKTopPSampler(nn.Module):
...
@@ -127,6 +158,9 @@ class TopKTopPSampler(nn.Module):
generators
:
dict
[
int
,
torch
.
Generator
],
generators
:
dict
[
int
,
torch
.
Generator
],
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
p
:
Optional
[
torch
.
Tensor
],
*
,
max_top_k
:
Optional
[
int
]
=
None
,
has_any_no_top_k
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
"""
PyTorch-native implementation of top-k and top-p sampling for CPU.
PyTorch-native implementation of top-k and top-p sampling for CPU.
...
@@ -253,6 +287,56 @@ def random_sample(
...
@@ -253,6 +287,56 @@ def random_sample(
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample_top_k_top_p_reduced
(
logits
:
torch
.
Tensor
,
generators
:
dict
[
int
,
torch
.
Generator
],
k
:
torch
.
Tensor
,
p
:
Optional
[
torch
.
Tensor
],
*
,
max_top_k
:
int
,
)
->
torch
.
Tensor
:
"""Sample from logits using only the top-k candidates.
This avoids full-vocab sorting and full-vocab softmax/exponential kernels.
Semantics match applying top-k then top-p (if provided) and sampling from
the resulting distribution.
"""
vocab_size
=
logits
.
shape
[
-
1
]
# Cap for safety; very large top-k values may be expensive or defeat the
# purpose of the reduced path.
if
max_top_k
<=
0
or
max_top_k
>=
vocab_size
:
masked_logits
=
apply_top_k_top_p
(
logits
,
k
,
p
)
probs
=
masked_logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
random_sample
(
probs
,
generators
)
topk
=
logits
.
topk
(
max_top_k
,
dim
=-
1
)
topk_logits
=
topk
.
values
topk_indices
=
topk
.
indices
# Apply per-row top-k (some rows may have smaller k).
# topk_logits is sorted descending by default.
k
=
k
.
to
(
torch
.
long
)
arange_k
=
torch
.
arange
(
max_top_k
,
device
=
logits
.
device
).
unsqueeze
(
0
)
keep_k
=
arange_k
<
k
.
unsqueeze
(
1
)
topk_logits
=
topk_logits
.
masked_fill
(
~
keep_k
,
-
float
(
"inf"
))
# Convert to probabilities over the reduced candidate set.
probs
=
topk_logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
p
is
not
None
:
# Apply top-p within the reduced set. Since candidates are already
# sorted by descending logit, we can do cumulative top-p on this order.
# Keep tokens until cumprob exceeds p, inclusive of the boundary token.
cumprob
=
torch
.
cumsum
(
probs
,
dim
=-
1
)
cumprob_prev
=
cumprob
-
probs
keep_p
=
cumprob_prev
<=
p
.
unsqueeze
(
1
)
probs
=
probs
*
keep_p
# Sample a position within the reduced set and map it back to vocab ids.
pos
=
random_sample
(
probs
,
generators
)
return
topk_indices
.
gather
(
1
,
pos
.
unsqueeze
(
1
)).
squeeze
(
1
)
def
flashinfer_sample
(
def
flashinfer_sample
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
k
:
Optional
[
torch
.
Tensor
],
k
:
Optional
[
torch
.
Tensor
],
...
...
vllm/v1/sample/sampler.py
View file @
17f59521
...
@@ -182,6 +182,8 @@ class Sampler(nn.Module):
...
@@ -182,6 +182,8 @@ class Sampler(nn.Module):
sampling_metadata
.
generators
,
sampling_metadata
.
generators
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
,
max_top_k
=
sampling_metadata
.
max_top_k
,
has_any_no_top_k
=
sampling_metadata
.
has_any_no_top_k
,
)
)
if
greedy_sampled
is
None
:
if
greedy_sampled
is
None
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
17f59521
...
@@ -802,6 +802,15 @@ class InputBatch:
...
@@ -802,6 +802,15 @@ class InputBatch:
self
.
allowed_token_ids_mask
,
num_reqs
)
self
.
allowed_token_ids_mask
,
num_reqs
)
allowed_token_ids_mask
=
self
.
allowed_token_ids_mask
[:
num_reqs
]
allowed_token_ids_mask
=
self
.
allowed_token_ids_mask
[:
num_reqs
]
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
max_top_k
=
None
has_any_no_top_k
=
False
if
not
self
.
no_top_k
and
num_reqs
>
0
:
top_k_cpu
=
self
.
top_k_cpu
[:
num_reqs
]
max_top_k
=
int
(
top_k_cpu
.
max
())
has_any_no_top_k
=
bool
((
top_k_cpu
==
self
.
vocab_size
).
any
())
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
temperature
,
temperature
=
temperature
,
all_greedy
=
self
.
all_greedy
,
all_greedy
=
self
.
all_greedy
,
...
@@ -819,6 +828,8 @@ class InputBatch:
...
@@ -819,6 +828,8 @@ class InputBatch:
allowed_token_ids_mask
=
allowed_token_ids_mask
,
allowed_token_ids_mask
=
allowed_token_ids_mask
,
bad_words_token_ids
=
self
.
bad_words_token_ids
,
bad_words_token_ids
=
self
.
bad_words_token_ids
,
logitsprocs
=
self
.
logitsprocs
,
logitsprocs
=
self
.
logitsprocs
,
max_top_k
=
max_top_k
,
has_any_no_top_k
=
has_any_no_top_k
,
)
)
def
get_pooling_params
(
self
)
->
list
[
PoolingParams
]:
def
get_pooling_params
(
self
)
->
list
[
PoolingParams
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment