Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1dcafb3d
Unverified
Commit
1dcafb3d
authored
Nov 28, 2025
by
Woosuk Kwon
Committed by
GitHub
Nov 28, 2025
Browse files
[Model Runner V2] Support penalties using bin counts (#29703)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
ea3370b4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
280 additions
and
14 deletions
+280
-14
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+15
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+5
-4
vllm/v1/worker/gpu/penalties.py
vllm/v1/worker/gpu/penalties.py
+85
-0
vllm/v1/worker/gpu/sampler.py
vllm/v1/worker/gpu/sampler.py
+3
-0
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+172
-10
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
1dcafb3d
...
...
@@ -341,6 +341,8 @@ def _post_update_kernel(
idx_mapping_ptr
,
num_computed_tokens_ptr
,
last_sampled_tokens_ptr
,
output_bin_counts_ptr
,
output_bin_counts_stride
,
sampled_tokens_ptr
,
sampled_tokens_stride
,
num_sampled_ptr
,
...
...
@@ -357,6 +359,15 @@ def _post_update_kernel(
)
tl
.
store
(
last_sampled_tokens_ptr
+
req_state_idx
,
token_id
)
for
i
in
range
(
num_sampled
):
token_id
=
tl
.
load
(
sampled_tokens_ptr
+
req_id
*
sampled_tokens_stride
+
i
)
token_ptr
=
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
token_id
)
count
=
tl
.
load
(
token_ptr
)
count
+=
1
tl
.
store
(
token_ptr
,
count
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
req_id
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
req_id
+
1
)
query_len
=
query_end
-
query_start
...
...
@@ -374,6 +385,8 @@ def post_update(
num_computed_tokens
:
torch
.
Tensor
,
# [max_num_reqs]
last_sampled_tokens
:
torch
.
Tensor
,
# [max_num_reqs, vocab_size]
output_bin_counts
:
torch
.
Tensor
,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens
:
torch
.
Tensor
,
# [num_reqs]
...
...
@@ -388,6 +401,8 @@ def post_update(
idx_mapping
,
num_computed_tokens
,
last_sampled_tokens
,
output_bin_counts
,
output_bin_counts
.
stride
(
0
),
sampled_tokens
,
sampled_tokens
.
stride
(
0
),
num_sampled
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
1dcafb3d
...
...
@@ -512,7 +512,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np
,
num_scheduled_tokens
,
query_start_loc_np
,
self
.
req_states
.
prefill_token_ids
,
self
.
req_states
.
prefill_token_ids
.
np
,
self
.
req_states
.
num_computed_prefill_tokens
,
self
.
input_buffers
.
input_ids
.
np
,
)
...
...
@@ -681,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Handle chunked prompts.
pos_after_step
=
computed_prefill
+
input_batch
.
num_scheduled_tokens
is_prompt_chunked
=
pos_after_step
<
prompt_lens
prefill_token_ids
=
self
.
req_states
.
prefill_token_ids
prefill_token_ids
=
self
.
req_states
.
prefill_token_ids
.
np
query_start_loc
=
self
.
input_buffers
.
query_start_loc
.
np
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
if
not
needs_prompt_logprobs
[
i
]:
...
...
@@ -756,6 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
output_bin_counts
,
sampled_tokens
,
num_sampled
,
num_rejected
,
...
...
@@ -785,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np
=
input_batch
.
idx_mapping_np
with
async_barrier
(
self
.
spec_decode_event
):
self
.
input_buffers
.
next_prefill_tokens
.
np
[:
num_reqs
]
=
(
self
.
req_states
.
prefill_token_ids
[
self
.
req_states
.
prefill_token_ids
.
np
[
idx_mapping_np
,
self
.
req_states
.
num_computed_prefill_tokens
[
idx_mapping_np
],
]
...
...
@@ -896,7 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# barrier to avoid race conditions.
pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
input_batch
.
idx_mapping_np
,
pos
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
input_batch
.
num_draft_tokens
>
0
:
sampling_metadata
=
self
.
req_states
.
expand_sampling_metadata
(
...
...
vllm/v1/worker/gpu/penalties.py
0 → 100644
View file @
1dcafb3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
@triton.jit
def _penalties_kernel(
    logits_ptr,
    logits_stride,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    idx_mapping_ptr,
    prompt_bin_counts_ptr,
    prompt_bin_counts_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply repetition, frequency and presence penalties to logits in place.

    Launch grid: axis 0 selects a row of ``logits``; axis 1 selects a
    ``BLOCK_SIZE``-wide slice of the vocabulary within that row.

    The three ``*_penalty_ptr`` arrays are indexed by the logits row, whereas
    the two bin-count tables are indexed by ``idx_mapping[row]``. A row whose
    penalties all equal their no-op values (1.0 / 0.0 / 0.0) returns before
    touching the logits or either bin-count table.
    """
    batch_idx = tl.program_id(0)
    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)

    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
        # No penalties to apply. Early return.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size

    # Row offsets are computed in 64-bit: `row * stride` can exceed the int32
    # range when the number of rows times the vocabulary size is large.
    logits_row_ptr = logits_ptr + batch_idx.to(tl.int64) * logits_stride
    logits = tl.load(logits_row_ptr + block, mask=mask)
    # Do the arithmetic in fp32 regardless of the stored logits dtype.
    logits = logits.to(tl.float32)

    req_state_idx = tl.load(idx_mapping_ptr + batch_idx).to(tl.int64)
    output_bin_counts = tl.load(
        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
        mask=mask,
    )

    # Apply repetition penalties.
    if use_rep_penalty:
        prompt_bin_counts = tl.load(
            prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
            mask=mask,
        )
        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
        scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
        # If logits are positive, divide by penalty, otherwise multiply by penalty.
        scale = tl.where(logits > 0, 1.0 / scale, scale)
        logits *= scale

    # Apply frequency penalties (proportional to the output count).
    logits -= freq_penalty * output_bin_counts
    # Apply presence penalties (flat, once the token has appeared in the output).
    logits -= pres_penalty * (output_bin_counts > 0)
    # Store back to logits.
    tl.store(logits_row_ptr + block, logits, mask=mask)
def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
    """Launch the penalties kernel over ``logits``, modifying it in place.

    ``logits`` must be 2-D; its shape is unpacked as ``(num_reqs, vocab_size)``.
    One program is launched per (row, vocabulary block) pair. The penalty
    values, the index mapping and both bin-count tables are read from
    ``sampling_metadata``.
    """
    BLOCK_SIZE = 8192
    num_reqs, vocab_size = logits.shape
    prompt_counts = sampling_metadata.prompt_bin_counts
    output_counts = sampling_metadata.output_bin_counts

    # Axis 0: logits rows. Axis 1: vocabulary blocks of BLOCK_SIZE entries.
    grid = (num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))
    _penalties_kernel[grid](
        logits,
        logits.stride(0),
        sampling_metadata.repetition_penalty,
        sampling_metadata.frequency_penalty,
        sampling_metadata.presence_penalty,
        sampling_metadata.idx_mapping,
        prompt_counts,
        prompt_counts.stride(0),
        output_counts,
        output_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
vllm/v1/worker/gpu/sampler.py
View file @
1dcafb3d
...
...
@@ -8,6 +8,7 @@ from vllm.config.model import LogprobsMode
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.outputs
import
LogprobsTensors
,
SamplerOutput
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.worker.gpu.penalties
import
apply_penalties
from
vllm.v1.worker.gpu.states
import
SamplingMetadata
...
...
@@ -65,6 +66,8 @@ class Sampler:
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
)
# Apply penalties in place.
apply_penalties
(
logits
,
sampling_metadata
)
sampled
=
gumbel_sample
(
logits
,
...
...
vllm/v1/worker/gpu/states.py
View file @
1dcafb3d
...
...
@@ -8,6 +8,8 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.utils
import
CpuGpuBuffer
...
...
@@ -23,12 +25,21 @@ class SamplingMetadata:
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
# For penalties
idx_mapping
:
torch
.
Tensor
prompt_bin_counts
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
@
classmethod
def
make_dummy
(
cls
,
...
...
@@ -44,17 +55,35 @@ class SamplingMetadata:
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p
=
None
top_k
=
None
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
return
cls
(
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_counts
=
prompt_bin_counts
,
output_bin_counts
=
output_bin_counts
,
)
...
...
@@ -83,9 +112,10 @@ class RequestState:
self
.
extra_data
:
dict
[
str
,
ExtraData
]
=
{}
self
.
prompt_len
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
prefill_token_ids
=
np
.
zeros
(
(
self
.
max_num_reqs
,
self
.
max_model_len
),
dtype
=
np
.
int32
,
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self
.
prefill_token_ids
=
UvaBuffer
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
)
self
.
prefill_len
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
...
...
@@ -119,6 +149,9 @@ class RequestState:
self
.
temperature
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_k
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
repetition_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
frequency_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
presence_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
seeds
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int64
)
self
.
num_logprobs
=
np
.
empty
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
...
...
@@ -126,6 +159,16 @@ class RequestState:
self
.
num_logprobs
.
fill
(
-
1
)
self
.
needs_prompt_logprobs
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
bool
)
# Statistics for penalties.
# TODO(woosuk): These tensors are rarely used but can be extremely large.
# Optimize the memory usage.
self
.
prompt_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
output_bin_counts
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
def
_make_param
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
)
->
"Param"
:
return
Param
(
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
...
...
@@ -159,7 +202,7 @@ class RequestState:
f
"prefill_len
{
prefill_len
}
< prompt_len
{
prompt_len
}
"
)
self
.
prefill_len
.
np
[
req_idx
]
=
prefill_len
self
.
prefill_token_ids
[
req_idx
,
:
prefill_len
]
=
prefill_token_ids
self
.
prefill_token_ids
.
np
[
req_idx
,
:
prefill_len
]
=
prefill_token_ids
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
...
...
@@ -178,6 +221,18 @@ class RequestState:
else
:
top_k
=
self
.
vocab_size
self
.
top_k
.
np
[
req_idx
]
=
top_k
self
.
repetition_penalty
.
np
[
req_idx
]
=
sampling_params
.
repetition_penalty
self
.
frequency_penalty
.
np
[
req_idx
]
=
sampling_params
.
frequency_penalty
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
if
use_penalty
(
sampling_params
):
bincount
(
self
.
prefill_token_ids
.
gpu
[
req_idx
],
prefill_len
,
prompt_len
,
self
.
prompt_bin_counts
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
if
sampling_params
.
seed
is
not
None
:
seed
=
sampling_params
.
seed
...
...
@@ -206,24 +261,32 @@ class RequestState:
def
make_sampling_metadata
(
self
,
idx_mapping
:
np
.
ndarray
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
SamplingMetadata
:
temperature
=
self
.
temperature
.
np
[
idx_mapping
]
temperature
=
self
.
temperature
.
np
[
idx_mapping
_np
]
temperature
=
self
.
temperature
.
copy_np_to_gpu
(
temperature
)
top_p
=
self
.
top_p
.
np
[
idx_mapping
]
top_p
=
self
.
top_p
.
np
[
idx_mapping
_np
]
no_top_p
=
np
.
all
(
top_p
==
1.0
)
top_p
=
self
.
top_p
.
copy_np_to_gpu
(
top_p
)
if
not
no_top_p
else
None
top_k
=
self
.
top_k
.
np
[
idx_mapping
]
top_k
=
self
.
top_k
.
np
[
idx_mapping
_np
]
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
top_k
=
self
.
top_k
.
copy_np_to_gpu
(
top_k
)
if
not
no_top_k
else
None
seeds
=
self
.
seeds
.
np
[
idx_mapping
]
rep_penalty
=
self
.
repetition_penalty
.
np
[
idx_mapping_np
]
rep_penalty
=
self
.
repetition_penalty
.
copy_np_to_gpu
(
rep_penalty
)
freq_penalty
=
self
.
frequency_penalty
.
np
[
idx_mapping_np
]
freq_penalty
=
self
.
frequency_penalty
.
copy_np_to_gpu
(
freq_penalty
)
pres_penalty
=
self
.
presence_penalty
.
np
[
idx_mapping_np
]
pres_penalty
=
self
.
presence_penalty
.
copy_np_to_gpu
(
pres_penalty
)
seeds
=
self
.
seeds
.
np
[
idx_mapping_np
]
seeds
=
self
.
seeds
.
copy_np_to_gpu
(
seeds
)
num_logprobs
=
self
.
num_logprobs
[
idx_mapping
]
num_logprobs
=
self
.
num_logprobs
[
idx_mapping
_np
]
max_num_logprobs
:
int
|
None
=
int
(
np
.
max
(
num_logprobs
))
if
max_num_logprobs
==
-
1
:
max_num_logprobs
=
None
...
...
@@ -232,9 +295,15 @@ class RequestState:
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
repetition_penalty
=
rep_penalty
,
frequency_penalty
=
freq_penalty
,
presence_penalty
=
pres_penalty
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_counts
=
self
.
prompt_bin_counts
,
output_bin_counts
=
self
.
output_bin_counts
,
)
def
expand_sampling_metadata
(
...
...
@@ -294,6 +363,14 @@ class ExtraData:
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
class UvaBuffer:
    """A pinned host tensor exposed through three handles.

    Attributes (all created in ``__init__``):
        cpu: zero-initialized, pinned CPU tensor of the requested size/dtype.
        np: NumPy array obtained from ``cpu.numpy()``, so it shares storage
            with ``cpu``.
        gpu: result of ``get_cuda_view_from_cpu_tensor(cpu)``; presumably a
            device-side view of the same memory (confirm against that helper).
    """

    def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would let construction continue unchecked.
        if not is_uva_available():
            raise RuntimeError("UvaBuffer requires UVA, which is not available.")
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
        self.np = self.cpu.numpy()
        self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@
triton
.
jit
def
_expand_sampling_metadata_kernel
(
...
...
@@ -304,6 +381,12 @@ def _expand_sampling_metadata_kernel(
top_k_ptr
,
expanded_top_k_ptr
,
seeds_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
expanded_seeds_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
...
...
@@ -327,6 +410,15 @@ def _expand_sampling_metadata_kernel(
top_k
=
tl
.
load
(
top_k_ptr
+
req_idx
)
tl
.
store
(
expanded_top_k_ptr
+
start_idx
+
block
,
top_k
,
mask
=
mask
)
rep_penalty
=
tl
.
load
(
rep_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_rep_penalty_ptr
+
start_idx
+
block
,
rep_penalty
,
mask
=
mask
)
freq_penalty
=
tl
.
load
(
freq_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_freq_penalty_ptr
+
start_idx
+
block
,
freq_penalty
,
mask
=
mask
)
pres_penalty
=
tl
.
load
(
pres_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_pres_penalty_ptr
+
start_idx
+
block
,
pres_penalty
,
mask
=
mask
)
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
tl
.
store
(
expanded_seeds_ptr
+
start_idx
+
block
,
seed
,
mask
=
mask
)
...
...
@@ -341,6 +433,9 @@ def expand_sampling_metadata(
expanded_temp
=
create_empty
(
sampling_metadata
.
temperature
)
expanded_top_p
=
create_empty
(
sampling_metadata
.
top_p
)
expanded_top_k
=
create_empty
(
sampling_metadata
.
top_k
)
expanded_repetition_penalty
=
create_empty
(
sampling_metadata
.
repetition_penalty
)
expanded_frequency_penalty
=
create_empty
(
sampling_metadata
.
frequency_penalty
)
expanded_presence_penalty
=
create_empty
(
sampling_metadata
.
presence_penalty
)
expanded_seeds
=
create_empty
(
sampling_metadata
.
seeds
)
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
...
...
@@ -351,6 +446,12 @@ def expand_sampling_metadata(
expanded_top_p
,
sampling_metadata
.
top_k
,
expanded_top_k
,
sampling_metadata
.
repetition_penalty
,
expanded_repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
expanded_frequency_penalty
,
sampling_metadata
.
presence_penalty
,
expanded_presence_penalty
,
sampling_metadata
.
seeds
,
expanded_seeds
,
cu_num_logits
,
...
...
@@ -361,6 +462,67 @@ def expand_sampling_metadata(
top_p
=
expanded_top_p
,
top_k
=
expanded_top_k
,
seeds
=
expanded_seeds
,
repetition_penalty
=
expanded_repetition_penalty
,
frequency_penalty
=
expanded_frequency_penalty
,
presence_penalty
=
expanded_presence_penalty
,
pos
=
sampling_metadata
.
pos
,
max_num_logprobs
=
sampling_metadata
.
max_num_logprobs
,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping
=
sampling_metadata
.
idx_mapping
,
prompt_bin_counts
=
sampling_metadata
.
prompt_bin_counts
,
output_bin_counts
=
sampling_metadata
.
output_bin_counts
,
)
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return True if any penalty in ``sampling_params`` differs from its no-op value.

    The no-op values are 1.0 for ``repetition_penalty`` and 0.0 for both
    ``frequency_penalty`` and ``presence_penalty``.
    """
    if sampling_params.repetition_penalty != 1.0:
        return True
    if sampling_params.frequency_penalty != 0.0:
        return True
    return sampling_params.presence_penalty != 0.0
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
    prefill_token_ids_ptr,
    prefill_len,
    prompt_len,
    prompt_bin_counts_ptr,
    output_bin_counts_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Histogram one request's prefill tokens into prompt and output bins.

    Each program handles ``BLOCK_SIZE`` consecutive token positions. Tokens at
    positions ``[0, prompt_len)`` are atomically added to
    ``prompt_bin_counts_ptr``; tokens at ``[prompt_len, prefill_len)`` are
    atomically added to ``output_bin_counts_ptr``. Token ids are used directly
    as offsets into the count arrays.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    if block_start >= prefill_len:
        # This block lies entirely past the end of the token sequence.
        return

    block_end = block_start + BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Prompt part: only blocks that start before prompt_len can contain it.
    if block_start < prompt_len:
        in_prompt = offsets < prompt_len
        prompt_tokens = tl.load(prefill_token_ids_ptr + offsets, mask=in_prompt)
        tl.atomic_add(prompt_bin_counts_ptr + prompt_tokens, 1, mask=in_prompt)

    # Output part: only blocks that reach prompt_len or beyond can contain it.
    if block_end >= prompt_len:
        in_output = (offsets < prefill_len) & (offsets >= prompt_len)
        output_tokens = tl.load(prefill_token_ids_ptr + offsets, mask=in_output)
        tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=in_output)
def bincount(
    prefill_token_ids: torch.Tensor,
    prefill_len: int,
    prompt_len: int,
    prompt_bin_counts: torch.Tensor,
    output_bin_counts: torch.Tensor,
) -> None:
    """Recompute the prompt/output token histograms for a single request.

    Both count tensors are zeroed in place and then refilled by the bincount
    kernel from the first ``prefill_len`` entries of ``prefill_token_ids``;
    the first ``prompt_len`` of those go to ``prompt_bin_counts`` and the
    remainder to ``output_bin_counts``.
    """
    BLOCK_SIZE = 1024

    # Clear any stale counts before accumulating.
    for counts in (prompt_bin_counts, output_bin_counts):
        counts.zero_()

    # One program per BLOCK_SIZE token positions.
    grid = (triton.cdiv(prefill_len, BLOCK_SIZE),)
    _bincount_kernel[grid](
        prefill_token_ids,
        prefill_len,
        prompt_len,
        prompt_bin_counts,
        output_bin_counts,
        BLOCK_SIZE=BLOCK_SIZE,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment