Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d00df624
Unverified
Commit
d00df624
authored
Feb 16, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 16, 2026
Browse files
[Model Runner V2] Minor refactoring for penalties (#34662)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
9752da9d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
74 deletions
+93
-74
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+2
-8
vllm/v1/worker/gpu/sample/bad_words.py
vllm/v1/worker/gpu/sample/bad_words.py
+8
-15
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+77
-37
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+6
-14
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
d00df624
...
...
@@ -155,9 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs
=
self
.
max_num_reqs
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
all_token_ids
=
self
.
req_states
.
all_token_ids
.
gpu
,
prompt_len
=
self
.
req_states
.
prompt_len
.
gpu
,
total_len
=
self
.
req_states
.
total_len
.
gpu
,
req_states
=
self
.
req_states
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
num_speculative_tokens
=
self
.
num_speculative_steps
+
1
,
)
...
...
@@ -528,11 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
apply_staged_writes
()
self
.
sampler
.
apply_staged_writes
(
self
.
req_states
.
all_token_ids
.
gpu
,
self
.
req_states
.
prefill_len
.
np
,
self
.
req_states
.
prompt_len
.
np
,
)
self
.
sampler
.
apply_staged_writes
()
if
self
.
uses_mrope
:
self
.
mrope_states
.
apply_staged_writes
()
...
...
vllm/v1/worker/gpu/sample/bad_words.py
View file @
d00df624
...
...
@@ -6,24 +6,17 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
from
vllm.v1.worker.gpu.states
import
RequestState
MAX_BAD_WORDS_TOTAL_TOKENS
=
1024
# Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS
=
128
# Max number of bad words per request
class
BadWordsState
:
def
__init__
(
self
,
all_token_ids
:
torch
.
Tensor
,
prompt_len
:
torch
.
Tensor
,
total_len
:
torch
.
Tensor
,
):
self
.
all_token_ids
=
all_token_ids
self
.
prompt_len
=
prompt_len
self
.
total_len
=
total_len
self
.
max_num_reqs
=
prompt_len
.
shape
[
0
]
self
.
device
=
prompt_len
.
device
def
__init__
(
self
,
req_states
:
RequestState
):
self
.
req_states
=
req_states
self
.
max_num_reqs
=
req_states
.
max_num_reqs
self
.
device
=
req_states
.
device
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
self
.
bad_word_token_ids
=
StagedWriteTensor
(
...
...
@@ -95,9 +88,9 @@ class BadWordsState:
self
.
bad_word_token_ids
.
gpu
,
self
.
bad_word_offsets
.
gpu
,
self
.
num_bad_words
.
gpu
,
self
.
all_token_ids
,
self
.
prompt_len
,
self
.
total_len
,
self
.
req_states
.
all_token_ids
.
gpu
,
self
.
req_states
.
prompt_len
.
gpu
,
self
.
req_states
.
total_len
.
gpu
,
input_ids
,
expanded_local_pos
,
max_num_bad_words
,
...
...
vllm/v1/worker/gpu/sample/penalties.py
View file @
d00df624
...
...
@@ -6,14 +6,18 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
async_tensor_h2d
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBackedTensor
from
vllm.v1.worker.gpu.states
import
RequestState
class
PenaltiesState
:
def
__init__
(
self
,
max_num_reqs
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
):
self
.
max_num_reqs
=
max_num_reqs
self
.
vocab_size
=
vocab_size
self
.
device
=
device
def
__init__
(
self
,
req_states
:
RequestState
):
self
.
req_states
=
req_states
max_num_reqs
=
req_states
.
max_num_reqs
self
.
vocab_size
=
req_states
.
vocab_size
self
.
device
=
req_states
.
device
self
.
repetition_penalty
=
UvaBackedTensor
(
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
frequency_penalty
=
UvaBackedTensor
(
max_num_reqs
,
dtype
=
torch
.
float32
)
...
...
@@ -26,7 +30,7 @@ class PenaltiesState:
# Statistics for penalties.
self
.
prompt_bin_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
...
...
@@ -34,10 +38,10 @@ class PenaltiesState:
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
# GBs of GPU memory. Optimize the memory usage.
self
.
output_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
_penalties_reqs
:
list
[
int
]
=
[]
self
.
_
new_
penalties_reqs
:
list
[
int
]
=
[]
def
add_request
(
self
,
req_idx
:
int
,
sampling_params
:
SamplingParams
)
->
None
:
self
.
repetition_penalty
.
np
[
req_idx
]
=
sampling_params
.
repetition_penalty
...
...
@@ -47,24 +51,29 @@ class PenaltiesState:
do_penalty
=
use_penalty
(
sampling_params
)
self
.
use_penalty
[
req_idx
]
=
do_penalty
if
do_penalty
:
self
.
_penalties_reqs
.
append
(
req_idx
)
self
.
_new_penalties_reqs
.
append
(
req_idx
)
def
apply_staged_writes
(
self
)
->
None
:
if
self
.
_new_penalties_reqs
:
idx_mapping
=
async_tensor_h2d
(
self
.
_new_penalties_reqs
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
,
)
def
apply_staged_writes
(
self
,
all_token_ids
:
torch
.
Tensor
,
prefill_lens
:
np
.
ndarray
,
prompt_lens
:
np
.
ndarray
,
)
->
None
:
# TODO(woosuk): Optimize this.
for
req_idx
in
self
.
_penalties_reqs
:
prefill_lens
=
self
.
req_states
.
prefill_len
.
np
[
self
.
_new_penalties_reqs
]
max_prefill_len
=
int
(
prefill_lens
.
max
())
bincount
(
all_token_ids
[
req_idx
],
int
(
prefill_lens
[
req_idx
]),
int
(
prompt_lens
[
req_idx
]),
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
idx_mapping
,
self
.
req_states
.
all_token_ids
.
gpu
,
self
.
req_states
.
prompt_len
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
prompt_bin_mask
,
self
.
output_bin_counts
,
max_prefill_len
,
)
self
.
_penalties_reqs
.
clear
()
self
.
_
new_
penalties_reqs
.
clear
()
self
.
repetition_penalty
.
copy_to_uva
()
self
.
frequency_penalty
.
copy_to_uva
()
...
...
@@ -214,51 +223,82 @@ def apply_penalties(
)
@
triton
.
jit
(
do_not_specialize
=
[
"prefill_len"
,
"prompt_len"
])
@
triton
.
jit
def
_bincount_kernel
(
idx_mapping_ptr
,
all_token_ids_ptr
,
prefill_len
,
prompt_len
,
all_token_ids_stride
,
prompt_len_ptr
,
prefill_len_ptr
,
prompt_bin_mask_ptr
,
prompt_bin_mask_stride
,
output_bin_counts_ptr
,
output_bin_counts_stride
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
block_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
prefill_len
=
tl
.
load
(
prefill_len_ptr
+
req_state_idx
)
if
block_idx
*
BLOCK_SIZE
>=
prefill_len
:
return
prompt_len
=
tl
.
load
(
prompt_len_ptr
+
req_state_idx
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
if
block_idx
*
BLOCK_SIZE
<
prompt_len
:
mask
=
block
<
prompt_len
prompt_tokens
=
tl
.
load
(
all_token_ids_ptr
+
block
,
mask
=
mask
)
prompt_tokens
=
tl
.
load
(
all_token_ids_ptr
+
req_state_idx
*
all_token_ids_stride
+
block
,
mask
=
mask
)
idx
=
prompt_tokens
//
32
bit_idx
=
prompt_tokens
%
32
bit
=
tl
.
full
((
BLOCK_SIZE
,),
1
,
tl
.
int32
)
<<
bit_idx
tl
.
atomic_or
(
prompt_bin_mask_ptr
+
idx
,
bit
,
mask
=
mask
)
tl
.
atomic_or
(
prompt_bin_mask_ptr
+
req_state_idx
*
prompt_bin_mask_stride
+
idx
,
bit
,
mask
=
mask
,
)
if
(
block_idx
+
1
)
*
BLOCK_SIZE
>=
prompt_len
:
mask
=
block
<
prefill_len
mask
&=
block
>=
prompt_len
output_tokens
=
tl
.
load
(
all_token_ids_ptr
+
block
,
mask
=
mask
)
tl
.
atomic_add
(
output_bin_counts_ptr
+
output_tokens
,
1
,
mask
=
mask
)
output_tokens
=
tl
.
load
(
all_token_ids_ptr
+
req_state_idx
*
all_token_ids_stride
+
block
,
mask
=
mask
)
tl
.
atomic_add
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
output_tokens
,
1
,
mask
=
mask
,
)
def
bincount
(
idx_mapping
:
torch
.
Tensor
,
all_token_ids
:
torch
.
Tensor
,
pr
efill
_len
:
int
,
pr
ompt
_len
:
int
,
pr
ompt
_len
:
torch
.
Tensor
,
pr
efill
_len
:
torch
.
Tensor
,
prompt_bin_mask
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
max_prefill_len
:
int
,
)
->
None
:
prompt_bin_mask
.
zero_
()
output_bin_counts
.
zero_
()
prompt_bin_mask
[
idx_mapping
]
=
0
output_bin_counts
[
idx_mapping
]
=
0
num_reqs
=
idx_mapping
.
shape
[
0
]
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
prefill_len
,
BLOCK_SIZE
)
_bincount_kernel
[(
num_blocks
,)](
num_blocks
=
triton
.
cdiv
(
max_prefill_len
,
BLOCK_SIZE
)
_bincount_kernel
[(
num_reqs
,
num_blocks
)](
idx_mapping
,
all_token_ids
,
prefill_len
,
all_token_ids
.
stride
(
0
)
,
prompt_len
,
prefill_len
,
prompt_bin_mask
,
prompt_bin_mask
.
stride
(
0
),
output_bin_counts
,
output_bin_counts
.
stride
(
0
),
BLOCK_SIZE
=
BLOCK_SIZE
,
)
...
...
vllm/v1/worker/gpu/sample/sampler.py
View file @
d00df624
...
...
@@ -15,6 +15,7 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.penalties
import
PenaltiesState
from
vllm.v1.worker.gpu.sample.states
import
NO_LOGPROBS
,
SamplingStates
from
vllm.v1.worker.gpu.states
import
RequestState
class
Sampler
:
...
...
@@ -23,9 +24,7 @@ class Sampler:
max_num_reqs
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
all_token_ids
:
torch
.
Tensor
,
prompt_len
:
torch
.
Tensor
,
total_len
:
torch
.
Tensor
,
req_states
:
RequestState
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
num_speculative_tokens
:
int
=
1
,
):
...
...
@@ -35,9 +34,9 @@ class Sampler:
self
.
compute_nans
=
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
# False by default.
self
.
sampling_states
=
SamplingStates
(
max_num_reqs
,
vocab_size
)
self
.
penalties_state
=
PenaltiesState
(
max_num_reqs
,
vocab_size
,
device
)
self
.
penalties_state
=
PenaltiesState
(
req_states
)
self
.
logit_bias_state
=
LogitBiasState
(
max_num_reqs
,
device
)
self
.
bad_words_state
=
BadWordsState
(
all_token_ids
,
prompt_len
,
total_len
)
self
.
bad_words_state
=
BadWordsState
(
req_states
)
self
.
num_speculative_tokens
=
num_speculative_tokens
def
add_request
(
...
...
@@ -48,16 +47,9 @@ class Sampler:
self
.
logit_bias_state
.
add_request
(
req_idx
,
prompt_len
,
sampling_params
)
self
.
bad_words_state
.
add_request
(
req_idx
,
sampling_params
)
def
apply_staged_writes
(
self
,
all_token_ids
:
torch
.
Tensor
,
prefill_lens
:
np
.
ndarray
,
prompt_lens
:
np
.
ndarray
,
)
->
None
:
def
apply_staged_writes
(
self
)
->
None
:
self
.
sampling_states
.
apply_staged_writes
()
self
.
penalties_state
.
apply_staged_writes
(
all_token_ids
,
prefill_lens
,
prompt_lens
)
self
.
penalties_state
.
apply_staged_writes
()
self
.
logit_bias_state
.
apply_staged_writes
()
self
.
bad_words_state
.
apply_staged_writes
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment