Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec38a736
Unverified
Commit
ec38a736
authored
Nov 30, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 30, 2025
Browse files
[Model Runner V2] Use packed mask for prompt bin counts (#29756)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
21c26279
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
25 deletions
+35
-25
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+4
-4
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+21
-15
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+10
-6
No files found.
vllm/v1/worker/gpu/sample/metadata.py
View file @
ec38a736
...
...
@@ -26,7 +26,7 @@ class SamplingMetadata:
# For penalties
idx_mapping
:
torch
.
Tensor
prompt_bin_
counts
:
torch
.
Tensor
prompt_bin_
mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
@
classmethod
...
...
@@ -57,7 +57,7 @@ class SamplingMetadata:
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_
counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
prompt_bin_
mask
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
return
cls
(
...
...
@@ -71,7 +71,7 @@ class SamplingMetadata:
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_
counts
=
prompt_bin_
counts
,
prompt_bin_
mask
=
prompt_bin_
mask
,
output_bin_counts
=
output_bin_counts
,
)
...
...
@@ -174,6 +174,6 @@ def expand_sampling_metadata(
max_num_logprobs
=
sampling_metadata
.
max_num_logprobs
,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping
=
sampling_metadata
.
idx_mapping
,
prompt_bin_
counts
=
sampling_metadata
.
prompt_bin_
counts
,
prompt_bin_
mask
=
sampling_metadata
.
prompt_bin_
mask
,
output_bin_counts
=
sampling_metadata
.
output_bin_counts
,
)
vllm/v1/worker/gpu/sample/penalties.py
View file @
ec38a736
...
...
@@ -15,8 +15,8 @@ def _penalties_and_temperature_kernel(
presence_penalty_ptr
,
temperature_ptr
,
idx_mapping_ptr
,
prompt_bin_
counts
_ptr
,
prompt_bin_
counts
_stride
,
prompt_bin_
mask
_ptr
,
prompt_bin_
mask
_stride
,
output_bin_counts_ptr
,
output_bin_counts_stride
,
vocab_size
,
...
...
@@ -54,13 +54,16 @@ def _penalties_and_temperature_kernel(
# Apply repetition penalties.
if
use_rep_penalty
:
prompt_bin_counts
=
tl
.
load
(
prompt_bin_counts_ptr
+
req_state_idx
*
prompt_bin_counts_stride
+
block
,
mask
=
mask
,
packed_block
=
block_idx
*
BLOCK_SIZE
//
32
+
tl
.
arange
(
0
,
BLOCK_SIZE
//
32
)
packed_mask
=
tl
.
load
(
prompt_bin_mask_ptr
+
req_state_idx
*
prompt_bin_mask_stride
+
packed_block
,
mask
=
packed_block
<
tl
.
cdiv
(
vocab_size
,
32
),
)
prompt_bin_mask
=
prompt_bin_counts
>
0
prompt_bin_mask
=
(
packed_mask
[:,
None
]
>>
(
tl
.
arange
(
0
,
32
)[
None
,
:]))
&
1
prompt_bin_mask
=
prompt_bin_mask
.
reshape
(
BLOCK_SIZE
)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale
=
tl
.
where
(
prompt_bin_mask
|
output_bin_mask
,
rep_penalty
,
1.0
)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
...
...
@@ -93,8 +96,8 @@ def apply_penalties_and_temperature(
sampling_metadata
.
presence_penalty
,
sampling_metadata
.
temperature
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
prompt_bin_
counts
,
sampling_metadata
.
prompt_bin_
counts
.
stride
(
0
),
sampling_metadata
.
prompt_bin_
mask
,
sampling_metadata
.
prompt_bin_
mask
.
stride
(
0
),
sampling_metadata
.
output_bin_counts
,
sampling_metadata
.
output_bin_counts
.
stride
(
0
),
vocab_size
,
...
...
@@ -107,7 +110,7 @@ def _bincount_kernel(
prefill_token_ids_ptr
,
prefill_len
,
prompt_len
,
prompt_bin_
counts
_ptr
,
prompt_bin_
mask
_ptr
,
output_bin_counts_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
...
...
@@ -119,7 +122,10 @@ def _bincount_kernel(
if
block_idx
*
BLOCK_SIZE
<
prompt_len
:
mask
=
block
<
prompt_len
prefill_tokens
=
tl
.
load
(
prefill_token_ids_ptr
+
block
,
mask
=
mask
)
tl
.
atomic_add
(
prompt_bin_counts_ptr
+
prefill_tokens
,
1
,
mask
=
mask
)
idx
=
prefill_tokens
//
32
bit_idx
=
prefill_tokens
%
32
bit
=
tl
.
full
((
BLOCK_SIZE
,),
1
,
tl
.
int32
)
<<
bit_idx
tl
.
atomic_or
(
prompt_bin_mask_ptr
+
idx
,
bit
,
mask
=
mask
)
if
(
block_idx
+
1
)
*
BLOCK_SIZE
>=
prompt_len
:
mask
=
block
<
prefill_len
mask
&=
block
>=
prompt_len
...
...
@@ -131,10 +137,10 @@ def bincount(
prefill_token_ids
:
torch
.
Tensor
,
prefill_len
:
int
,
prompt_len
:
int
,
prompt_bin_
counts
:
torch
.
Tensor
,
prompt_bin_
mask
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
)
->
None
:
prompt_bin_
counts
.
zero_
()
prompt_bin_
mask
.
zero_
()
output_bin_counts
.
zero_
()
BLOCK_SIZE
=
1024
num_blocks
=
triton
.
cdiv
(
prefill_len
,
BLOCK_SIZE
)
...
...
@@ -142,7 +148,7 @@ def bincount(
prefill_token_ids
,
prefill_len
,
prompt_len
,
prompt_bin_
counts
,
prompt_bin_
mask
,
output_bin_counts
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
vllm/v1/worker/gpu/states.py
View file @
ec38a736
...
...
@@ -7,6 +7,7 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.outputs
import
LogprobsTensors
...
...
@@ -97,11 +98,14 @@ class RequestState:
self
.
needs_prompt_logprobs
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
bool
)
# Statistics for penalties.
# TODO(woosuk): These tensors are rarely used but can be extremely large.
# Optimize the memory usage.
self
.
prompt_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
self
.
prompt_bin_mask
=
torch
.
zeros
(
self
.
max_num_reqs
,
cdiv
(
self
.
vocab_size
,
32
),
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self
.
output_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
...
...
@@ -167,7 +171,7 @@ class RequestState:
self
.
prefill_token_ids
.
gpu
[
req_idx
],
prefill_len
,
prompt_len
,
self
.
prompt_bin_
counts
[
req_idx
],
self
.
prompt_bin_
mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
...
...
@@ -239,7 +243,7 @@ class RequestState:
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_
counts
=
self
.
prompt_bin_
counts
,
prompt_bin_
mask
=
self
.
prompt_bin_
mask
,
output_bin_counts
=
self
.
output_bin_counts
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment