Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c01ffb8
Unverified
Commit
6c01ffb8
authored
Jan 19, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 19, 2026
Browse files
[Model Runner V2] Decouple temperature from penalties (#32629)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
7b7cdce9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
82 additions
and
56 deletions
+82
-56
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+45
-1
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+31
-50
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+6
-5
No files found.
vllm/v1/worker/gpu/sample/gumbel.py
View file @
6c01ffb8
...
...
@@ -5,6 +5,50 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
@triton.jit
def _temperature_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    temperature_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Divides one BLOCK_SIZE-wide slice of one logits row by that row's
    # temperature and writes the result back to the same location.
    #
    # Launch grid: axis 0 selects the logits row, axis 1 selects the
    # slice of the vocabulary dimension handled by this program.
    row = tl.program_id(0)

    # The temperature is not indexed by the row directly; idx_mapping
    # translates the row into the index used for the temperature lookup.
    state_idx = tl.load(idx_mapping_ptr + row)
    temperature = tl.load(temperature_ptr + state_idx).to(tl.float32)
    if temperature == 0.0 or temperature == 1.0:
        # Nothing to scale for these two values, so exit before any
        # logits are read or written.
        # NOTE(review): 0.0 presumably marks greedy decoding - confirm
        # against the code that fills the temperature tensor.
        return

    vocab_block = tl.program_id(1)
    offsets = vocab_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last slice may extend past the end of the row; mask it off.
    in_vocab = offsets < vocab_size

    row_ptrs = logits_ptr + row * logits_stride + offsets
    values = tl.load(row_ptrs, mask=in_vocab).to(tl.float32)
    # Keep this a plain `/`: _gumbel_sample_kernel is documented to match
    # the division performed here, so the two must change together.
    scaled = values / temperature
    tl.store(row_ptrs, scaled, mask=in_vocab)
def apply_temperature(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    """Scale ``logits`` in place by the temperature of each row's request.

    Launches ``_temperature_kernel`` with one program per
    (row, vocabulary slice) pair. Rows whose temperature is 0.0 or 1.0
    are left untouched by the kernel.

    Args:
        logits: 2-D tensor unpacked as ``(num_reqs, vocab_size)``;
            modified in place.
        idx_mapping: For each logits row, the index used to look up that
            row's entry in ``temperature``.
        temperature: Temperature values, indexed through ``idx_mapping``.
    """
    num_reqs, vocab_size = logits.shape

    # Each program handles a fixed-width slice of the vocabulary axis.
    block_size = 8192
    grid = (num_reqs, triton.cdiv(vocab_size, block_size))

    row_stride = logits.stride(0)
    _temperature_kernel[grid](
        logits,
        row_stride,
        idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=block_size,
    )
@
triton
.
jit
def
_gumbel_sample_kernel
(
local_argmax_ptr
,
...
...
@@ -48,7 +92,7 @@ def _gumbel_sample_kernel(
# Apply temperature.
if
APPLY_TEMPERATURE
:
# NOTE(woosuk): Match the behavior of
_penalties_and
_temperature_kernel.
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits
=
logits
/
temp
...
...
vllm/v1/worker/gpu/sample/penalties.py
View file @
6c01ffb8
...
...
@@ -66,16 +66,10 @@ class PenaltiesState:
self
.
frequency_penalty
.
copy_to_uva
()
self
.
presence_penalty
.
copy_to_uva
()
def
apply_penalties_and_temperature
(
self
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
)
->
None
:
apply_penalties_and_temperature
(
def
apply_penalties
(
self
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
)
->
None
:
apply_penalties
(
logits
,
idx_mapping
,
temperature
,
self
.
repetition_penalty
.
gpu
,
self
.
frequency_penalty
.
gpu
,
self
.
presence_penalty
.
gpu
,
...
...
@@ -85,14 +79,13 @@ class PenaltiesState:
@
triton
.
jit
def
_penalties_
and_temperature_
kernel
(
def
_penalties_kernel
(
logits_ptr
,
logits_stride
,
idx_mapping_ptr
,
repetition_penalty_ptr
,
frequency_penalty_ptr
,
presence_penalty_ptr
,
temperature_ptr
,
prompt_bin_mask_ptr
,
prompt_bin_mask_stride
,
output_bin_counts_ptr
,
...
...
@@ -105,15 +98,12 @@ def _penalties_and_temperature_kernel(
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
temperature
=
tl
.
load
(
temperature_ptr
+
req_state_idx
)
temperature
=
tl
.
where
(
temperature
==
0.0
,
1.0
,
temperature
)
use_rep_penalty
=
rep_penalty
!=
1.0
use_freq_penalty
=
freq_penalty
!=
0.0
use_pres_penalty
=
pres_penalty
!=
0.0
use_penalty
=
use_rep_penalty
or
use_freq_penalty
or
use_pres_penalty
use_temperature
=
temperature
!=
1.0
if
not
(
use_penalty
or
use_temperature
):
if
not
use_penalty
:
# Early return to avoid loading logits.
return
...
...
@@ -123,47 +113,39 @@ def _penalties_and_temperature_kernel(
logits
=
tl
.
load
(
logits_ptr
+
batch_idx
*
logits_stride
+
block
,
mask
=
mask
)
logits
=
logits
.
to
(
tl
.
float32
)
if
use_penalty
:
output_bin_counts
=
tl
.
load
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
mask
=
mask
,
output_bin_counts
=
tl
.
load
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
mask
=
mask
,
)
output_bin_mask
=
output_bin_counts
>
0
# Apply repetition penalties.
if
use_rep_penalty
:
packed_block
=
block_idx
*
BLOCK_SIZE
//
32
+
tl
.
arange
(
0
,
BLOCK_SIZE
//
32
)
packed_mask
=
tl
.
load
(
prompt_bin_mask_ptr
+
req_state_idx
*
prompt_bin_mask_stride
+
packed_block
,
mask
=
packed_block
<
tl
.
cdiv
(
vocab_size
,
32
),
)
output_bin_mask
=
output_bin_counts
>
0
# Apply repetition penalties.
if
use_rep_penalty
:
packed_block
=
block_idx
*
BLOCK_SIZE
//
32
+
tl
.
arange
(
0
,
BLOCK_SIZE
//
32
)
packed_mask
=
tl
.
load
(
prompt_bin_mask_ptr
+
req_state_idx
*
prompt_bin_mask_stride
+
packed_block
,
mask
=
packed_block
<
tl
.
cdiv
(
vocab_size
,
32
),
)
prompt_bin_mask
=
(
packed_mask
[:,
None
]
>>
(
tl
.
arange
(
0
,
32
)[
None
,
:]))
&
1
prompt_bin_mask
=
prompt_bin_mask
.
to
(
tl
.
int1
)
prompt_bin_mask
=
prompt_bin_mask
.
reshape
(
BLOCK_SIZE
)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale
=
tl
.
where
(
prompt_bin_mask
|
output_bin_mask
,
rep_penalty
,
1.0
)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits
*=
tl
.
where
(
logits
>
0
,
1.0
/
scale
,
scale
)
# Apply frequency penalties.
logits
-=
freq_penalty
*
output_bin_counts
# Apply presence penalties.
logits
-=
pres_penalty
*
output_bin_mask
# Apply temperature.
logits
=
logits
/
temperature
prompt_bin_mask
=
(
packed_mask
[:,
None
]
>>
(
tl
.
arange
(
0
,
32
)[
None
,
:]))
&
1
prompt_bin_mask
=
prompt_bin_mask
.
to
(
tl
.
int1
)
prompt_bin_mask
=
prompt_bin_mask
.
reshape
(
BLOCK_SIZE
)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale
=
tl
.
where
(
prompt_bin_mask
|
output_bin_mask
,
rep_penalty
,
1.0
)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits
*=
tl
.
where
(
logits
>
0
,
1.0
/
scale
,
scale
)
# Apply frequency penalties.
logits
-=
freq_penalty
*
output_bin_counts
# Apply presence penalties.
logits
-=
pres_penalty
*
output_bin_mask
# Store back to logits.
tl
.
store
(
logits_ptr
+
batch_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
def
apply_penalties
_and_temperature
(
def
apply_penalties
(
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
temperature
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
frequency_penalty
:
torch
.
Tensor
,
presence_penalty
:
torch
.
Tensor
,
...
...
@@ -173,14 +155,13 @@ def apply_penalties_and_temperature(
num_reqs
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
8192
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
_penalties_
and_temperature_
kernel
[(
num_reqs
,
num_blocks
)](
_penalties_kernel
[(
num_reqs
,
num_blocks
)](
logits
,
logits
.
stride
(
0
),
idx_mapping
,
repetition_penalty
,
frequency_penalty
,
presence_penalty
,
temperature
,
prompt_bin_mask
,
prompt_bin_mask
.
stride
(
0
),
output_bin_counts
,
...
...
vllm/v1/worker/gpu/sample/sampler.py
View file @
6c01ffb8
...
...
@@ -9,7 +9,7 @@ from vllm.config.model import LogprobsMode
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.metrics.logits
import
get_num_nans
from
vllm.v1.worker.gpu.sample.gumbel
import
gumbel_sample
from
vllm.v1.worker.gpu.sample.gumbel
import
apply_temperature
,
gumbel_sample
from
vllm.v1.worker.gpu.sample.logit_bias
import
LogitBiasState
from
vllm.v1.worker.gpu.sample.logprob
import
compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.min_p
import
apply_min_p
...
...
@@ -106,10 +106,11 @@ class Sampler:
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
idx_mapping
,
pos
)
# Apply penalties and temperature in place.
self
.
penalties_state
.
apply_penalties_and_temperature
(
logits
,
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
)
# Apply penalties in place.
self
.
penalties_state
.
apply_penalties
(
logits
,
idx_mapping
)
# Apply temperature in place.
apply_temperature
(
logits
,
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
)
# Apply min_p in place if any request has a non-zero min_p.
do_min_p
=
self
.
sampling_states
.
do_min_p
(
idx_mapping_np
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment