Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
586f0eba
Commit
586f0eba
authored
Mar 04, 2026
by
王敏
Browse files
[perf]合入lightop topp_topk 融合算子
parent
2036eb73
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
66 additions
and
8 deletions
+66
-8
vllm/v1/sample/ops/topk_topp_sampler.py
vllm/v1/sample/ops/topk_topp_sampler.py
+63
-0
vllm/v1/sample/rejection_sampler_opt.py
vllm/v1/sample/rejection_sampler_opt.py
+0
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-5
No files found.
vllm/v1/sample/ops/topk_topp_sampler.py
View file @
586f0eba
...
...
@@ -12,6 +12,14 @@ from vllm.config.model import LogprobsMode
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
HAS_LIGHTOP_OPT_KERNEL
=
True
try
:
from
lightop.sampling
import
top_k_top_p_sampling_from_probs
as
top_k_top_p_sampling_from_probs_lightop
from
lightop.sampling
import
top_k_sampling_from_probs
as
top_k_sampling_from_probs_lightop
from
lightop.sampling
import
top_p_sampling_from_probs
as
top_p_sampling_from_probs_lightop
except
ImportError
:
HAS_LIGHTOP_OPT_KERNEL
=
False
logger
=
init_logger
(
__name__
)
...
...
@@ -86,6 +94,8 @@ class TopKTopPSampler(nn.Module):
self
.
forward
=
self
.
forward_native
else
:
self
.
forward
=
self
.
forward_native
if
HAS_LIGHTOP_OPT_KERNEL
:
self
.
forward
=
self
.
forward_lightop_opt
self
.
apply_top_k_top_p
=
apply_top_k_top_p
...
...
@@ -169,6 +179,19 @@ class TopKTopPSampler(nn.Module):
# because of slicing operation in logits_processor.
return
flashinfer_sample
(
logits
.
contiguous
(),
k
,
p
,
generators
),
None
def
forward_lightop_opt
(
self
,
logits
:
torch
.
Tensor
,
generators
:
dict
[
int
,
torch
.
Generator
],
k
:
torch
.
Tensor
|
None
,
p
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""Top-k and top-p sampling optimized by lightop."""
if
(
k
is
None
and
p
is
None
)
or
generators
:
return
self
.
forward_native
(
logits
,
generators
,
k
,
p
)
return
lightop_sample
(
logits
.
contiguous
(),
k
,
p
,
generators
),
None
def
forward_cpu
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -453,6 +476,46 @@ def flashinfer_sample(
return
next_token_ids
.
view
(
-
1
)
def
lightop_sample
(
logits
:
torch
.
Tensor
,
k
:
torch
.
Tensor
|
None
,
p
:
torch
.
Tensor
|
None
,
generators
:
dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
"""Sample from the logits using lightop.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert
not
(
k
is
None
and
p
is
None
)
probs
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
if
k
is
None
:
# Top-p only.
next_token_ids
=
top_p_sampling_from_probs_lightop
(
probs
,
p
,
deterministic
=
True
)
elif
p
is
None
:
# Top-k only.
next_token_ids
=
top_k_sampling_from_probs_lightop
(
probs
,
k
,
deterministic
=
True
)
else
:
# Both top-k and top-p.
next_token_ids
=
top_k_top_p_sampling_from_probs_lightop
(
probs
,
k
,
p
,
deterministic
=
True
)
return
next_token_ids
.
view
(
-
1
)
def
_to_tensor_scalar_tuple
(
x
):
if
isinstance
(
x
,
torch
.
Tensor
):
...
...
vllm/v1/sample/rejection_sampler_opt.py
View file @
586f0eba
...
...
@@ -98,8 +98,6 @@ class OptRejectionSampler(nn.Module):
# won't affect the original logits tensor.
assert
logits
is
not
None
sampling_metadata
.
all_greedy
=
True
sampling_metadata
.
all_random
=
False
sampler_output
=
self
.
sampler
(
logits
=
logits
,
sampling_metadata
=
replace
(
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
586f0eba
...
...
@@ -807,7 +807,7 @@ class InputBatch:
batch_update
=
self
.
batch_update_builder
.
get_and_reset
(
self
.
num_reqs
)
for
logit_proc
in
self
.
logitsprocs
.
all
:
logit_proc
.
update_state
(
batch_update
)
if
batch_update
:
if
batch_update
or
repeat_counts
is
not
None
:
self
.
sampling_metadata
=
self
.
_make_sampling_metadata
(
repeat_counts
)
def
_make_sampling_metadata
(
self
,
repeat_counts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplingMetadata
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
586f0eba
...
...
@@ -4916,6 +4916,7 @@ class GPUModelRunner(
draft_probs
=
torch
.
randn
(
num_reqs
,
self
.
speculative_config
.
num_speculative_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
dummy_metadata
.
all_greedy
=
True
logits
=
torch
.
randn
(
num_tokens
+
num_reqs
,
...
...
@@ -5537,10 +5538,6 @@ class GPUModelRunner(
ValueError: If no valid block size found
"""
#exclude indexer backend
def
_participates_in_block_size_selection
(
backend
:
type
[
AttentionBackend
])
->
bool
:
return
not
getattr
(
backend
,
"exclude_from_block_size_selection"
,
False
)
def
block_size_is_supported
(
backends
:
list
[
type
[
AttentionBackend
]],
block_size
:
int
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment