Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a238cbd8
Unverified
Commit
a238cbd8
authored
Dec 05, 2025
by
Woosuk Kwon
Committed by
GitHub
Dec 05, 2025
Browse files
[Model Runner V2] Support min-p sampling (#30171)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
4026ae31
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
77 additions
and
0 deletions
+77
-0
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+13
-0
vllm/v1/worker/gpu/sample/min_p.py
vllm/v1/worker/gpu/sample/min_p.py
+53
-0
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+4
-0
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+7
-0
No files found.
vllm/v1/worker/gpu/sample/metadata.py
View file @
a238cbd8
...
@@ -13,6 +13,7 @@ class SamplingMetadata:
...
@@ -13,6 +13,7 @@ class SamplingMetadata:
top_p
:
torch
.
Tensor
|
None
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
min_p
:
torch
.
Tensor
|
None
repetition_penalty
:
torch
.
Tensor
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
...
@@ -44,6 +45,7 @@ class SamplingMetadata:
...
@@ -44,6 +45,7 @@ class SamplingMetadata:
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p
=
None
top_p
=
None
top_k
=
None
top_k
=
None
min_p
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
# NOTE(woosuk): We must set penalties to their default values to make sure
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
...
@@ -64,6 +66,7 @@ class SamplingMetadata:
...
@@ -64,6 +66,7 @@ class SamplingMetadata:
temperature
=
temperature
,
temperature
=
temperature
,
top_p
=
top_p
,
top_p
=
top_p
,
top_k
=
top_k
,
top_k
=
top_k
,
min_p
=
min_p
,
repetition_penalty
=
repetition_penalty
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
presence_penalty
=
presence_penalty
,
...
@@ -85,6 +88,8 @@ def _expand_sampling_metadata_kernel(
...
@@ -85,6 +88,8 @@ def _expand_sampling_metadata_kernel(
expanded_top_p_ptr
,
expanded_top_p_ptr
,
top_k_ptr
,
top_k_ptr
,
expanded_top_k_ptr
,
expanded_top_k_ptr
,
min_p_ptr
,
expanded_min_p_ptr
,
rep_penalty_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
freq_penalty_ptr
,
...
@@ -115,6 +120,10 @@ def _expand_sampling_metadata_kernel(
...
@@ -115,6 +120,10 @@ def _expand_sampling_metadata_kernel(
top_k
=
tl
.
load
(
top_k_ptr
+
req_idx
)
top_k
=
tl
.
load
(
top_k_ptr
+
req_idx
)
tl
.
store
(
expanded_top_k_ptr
+
start_idx
+
block
,
top_k
,
mask
=
mask
)
tl
.
store
(
expanded_top_k_ptr
+
start_idx
+
block
,
top_k
,
mask
=
mask
)
if
min_p_ptr
is
not
None
:
min_p
=
tl
.
load
(
min_p_ptr
+
req_idx
)
tl
.
store
(
expanded_min_p_ptr
+
start_idx
+
block
,
min_p
,
mask
=
mask
)
rep_penalty
=
tl
.
load
(
rep_penalty_ptr
+
req_idx
)
rep_penalty
=
tl
.
load
(
rep_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_rep_penalty_ptr
+
start_idx
+
block
,
rep_penalty
,
mask
=
mask
)
tl
.
store
(
expanded_rep_penalty_ptr
+
start_idx
+
block
,
rep_penalty
,
mask
=
mask
)
...
@@ -138,6 +147,7 @@ def expand_sampling_metadata(
...
@@ -138,6 +147,7 @@ def expand_sampling_metadata(
expanded_temp
=
create_empty
(
sampling_metadata
.
temperature
)
expanded_temp
=
create_empty
(
sampling_metadata
.
temperature
)
expanded_top_p
=
create_empty
(
sampling_metadata
.
top_p
)
expanded_top_p
=
create_empty
(
sampling_metadata
.
top_p
)
expanded_top_k
=
create_empty
(
sampling_metadata
.
top_k
)
expanded_top_k
=
create_empty
(
sampling_metadata
.
top_k
)
expanded_min_p
=
create_empty
(
sampling_metadata
.
min_p
)
expanded_repetition_penalty
=
create_empty
(
sampling_metadata
.
repetition_penalty
)
expanded_repetition_penalty
=
create_empty
(
sampling_metadata
.
repetition_penalty
)
expanded_frequency_penalty
=
create_empty
(
sampling_metadata
.
frequency_penalty
)
expanded_frequency_penalty
=
create_empty
(
sampling_metadata
.
frequency_penalty
)
expanded_presence_penalty
=
create_empty
(
sampling_metadata
.
presence_penalty
)
expanded_presence_penalty
=
create_empty
(
sampling_metadata
.
presence_penalty
)
...
@@ -151,6 +161,8 @@ def expand_sampling_metadata(
...
@@ -151,6 +161,8 @@ def expand_sampling_metadata(
expanded_top_p
,
expanded_top_p
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
,
expanded_top_k
,
expanded_top_k
,
sampling_metadata
.
min_p
,
expanded_min_p
,
sampling_metadata
.
repetition_penalty
,
sampling_metadata
.
repetition_penalty
,
expanded_repetition_penalty
,
expanded_repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
sampling_metadata
.
frequency_penalty
,
...
@@ -166,6 +178,7 @@ def expand_sampling_metadata(
...
@@ -166,6 +178,7 @@ def expand_sampling_metadata(
temperature
=
expanded_temp
,
temperature
=
expanded_temp
,
top_p
=
expanded_top_p
,
top_p
=
expanded_top_p
,
top_k
=
expanded_top_k
,
top_k
=
expanded_top_k
,
min_p
=
expanded_min_p
,
seeds
=
expanded_seeds
,
seeds
=
expanded_seeds
,
repetition_penalty
=
expanded_repetition_penalty
,
repetition_penalty
=
expanded_repetition_penalty
,
frequency_penalty
=
expanded_frequency_penalty
,
frequency_penalty
=
expanded_frequency_penalty
,
...
...
vllm/v1/worker/gpu/sample/min_p.py
0 → 100644
View file @
a238cbd8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
# Triton kernel that applies min-p filtering to logits in place.
#
# Each program instance handles one row, selected by program_id(0); the row
# starts at logits_ptr + req_idx * logits_stride and its elements are
# addressed with unit stride. The row is scanned twice in chunks of
# BLOCK_SIZE: once to find the row maximum, and once to overwrite every logit
# below `max + log(min_p)` with -inf.
#
# Arguments:
#   logits_ptr:    base pointer of the logits buffer (read and written).
#   logits_stride: element offset between consecutive rows.
#   min_p_ptr:     pointer to per-row min_p values, indexed by req_idx.
#   vocab_size:    number of valid elements in each row.
#   BLOCK_SIZE:    compile-time chunk width used for both scans.
@triton.jit
def _min_p_kernel(
    logits_ptr,
    logits_stride,
    min_p_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    min_p = tl.load(min_p_ptr + req_idx).to(tl.float32)
    # A min_p of exactly 0 leaves the row untouched; returning early also
    # avoids evaluating log(0) below.
    if min_p == 0.0:
        return

    # Pass 1: running maximum over the row. Out-of-range lanes load -inf so
    # they cannot influence the maximum.
    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
        )
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    # logit < max + log(min_p) is equivalent to
    # exp(logit) < min_p * exp(max), so the comparison is done directly on
    # logits without computing a softmax.
    threshold = max_val + tl.log(min_p)
    # Pass 2: rewrite the row, replacing logits below the threshold with -inf.
    # The store is masked so nothing past vocab_size is written.
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
        )
        logits = tl.where(logits < threshold, float("-inf"), logits)
        tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor | None) -> None:
    """Apply min-p filtering to ``logits`` in place.

    Does nothing when ``min_p`` is ``None``. Otherwise launches
    ``_min_p_kernel`` with one program per row of ``logits``, which must be
    two-dimensional (its shape is unpacked into rows and vocabulary size).

    Args:
        logits: Tensor modified in place by the kernel.
        min_p: Per-row min-p values, or ``None`` to skip filtering.
    """
    if min_p is not None:
        rows, vocab = logits.shape
        row_stride = logits.stride(0)
        # One kernel program per row; each scans its row in chunks of 1024.
        launch = _min_p_kernel[(rows,)]
        launch(logits, row_stride, min_p, vocab, BLOCK_SIZE=1024)
vllm/v1/worker/gpu/sample/sampler.py
View file @
a238cbd8
...
@@ -9,6 +9,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
...
@@ -9,6 +9,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.min_p
import
apply_min_p
from
vllm.v1.worker.gpu.sample.penalties
import
apply_penalties_and_temperature
from
vllm.v1.worker.gpu.sample.penalties
import
apply_penalties_and_temperature
...
@@ -61,6 +62,9 @@ class Sampler:
...
@@ -61,6 +62,9 @@ class Sampler:
# Apply penalties and temperature in place.
# Apply penalties and temperature in place.
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
# Apply min_p in place.
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p. This might return a new tensor.
logits
=
apply_top_k_top_p
(
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
)
)
...
...
vllm/v1/worker/gpu/states.py
View file @
a238cbd8
...
@@ -87,6 +87,7 @@ class RequestState:
...
@@ -87,6 +87,7 @@ class RequestState:
self
.
temperature
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
temperature
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_k
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
top_k
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
min_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
repetition_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
repetition_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
frequency_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
frequency_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
presence_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
presence_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
...
@@ -162,6 +163,7 @@ class RequestState:
...
@@ -162,6 +163,7 @@ class RequestState:
else
:
else
:
top_k
=
self
.
vocab_size
top_k
=
self
.
vocab_size
self
.
top_k
.
np
[
req_idx
]
=
top_k
self
.
top_k
.
np
[
req_idx
]
=
top_k
self
.
min_p
.
np
[
req_idx
]
=
sampling_params
.
min_p
self
.
repetition_penalty
.
np
[
req_idx
]
=
sampling_params
.
repetition_penalty
self
.
repetition_penalty
.
np
[
req_idx
]
=
sampling_params
.
repetition_penalty
self
.
frequency_penalty
.
np
[
req_idx
]
=
sampling_params
.
frequency_penalty
self
.
frequency_penalty
.
np
[
req_idx
]
=
sampling_params
.
frequency_penalty
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
...
@@ -217,6 +219,10 @@ class RequestState:
...
@@ -217,6 +219,10 @@ class RequestState:
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
top_k
=
self
.
top_k
.
copy_np_to_gpu
(
top_k
)
if
not
no_top_k
else
None
top_k
=
self
.
top_k
.
copy_np_to_gpu
(
top_k
)
if
not
no_top_k
else
None
min_p
=
self
.
min_p
.
np
[
idx_mapping_np
]
no_min_p
=
np
.
all
(
min_p
==
0.0
)
min_p
=
self
.
min_p
.
copy_np_to_gpu
(
min_p
)
if
not
no_min_p
else
None
rep_penalty
=
self
.
repetition_penalty
.
np
[
idx_mapping_np
]
rep_penalty
=
self
.
repetition_penalty
.
np
[
idx_mapping_np
]
rep_penalty
=
self
.
repetition_penalty
.
copy_np_to_gpu
(
rep_penalty
)
rep_penalty
=
self
.
repetition_penalty
.
copy_np_to_gpu
(
rep_penalty
)
freq_penalty
=
self
.
frequency_penalty
.
np
[
idx_mapping_np
]
freq_penalty
=
self
.
frequency_penalty
.
np
[
idx_mapping_np
]
...
@@ -236,6 +242,7 @@ class RequestState:
...
@@ -236,6 +242,7 @@ class RequestState:
temperature
=
temperature
,
temperature
=
temperature
,
top_p
=
top_p
,
top_p
=
top_p
,
top_k
=
top_k
,
top_k
=
top_k
,
min_p
=
min_p
,
repetition_penalty
=
rep_penalty
,
repetition_penalty
=
rep_penalty
,
frequency_penalty
=
freq_penalty
,
frequency_penalty
=
freq_penalty
,
presence_penalty
=
pres_penalty
,
presence_penalty
=
pres_penalty
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment