Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16786da7
Unverified
Commit
16786da7
authored
Feb 07, 2026
by
zhrrr
Committed by
GitHub
Feb 06, 2026
Browse files
[Model Runner V2] support apply penalty for spec decode (#33251)
Signed-off-by:
zhuhaoran
<
zhuhaoran.zhr@alibaba-inc.com
>
parent
aaa2efbe
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
14 deletions
+91
-14
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+12
-2
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+22
-3
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+37
-7
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+20
-2
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
16786da7
...
@@ -40,6 +40,8 @@ class InputBatch:
...
@@ -40,6 +40,8 @@ class InputBatch:
idx_mapping_np
:
np
.
ndarray
idx_mapping_np
:
np
.
ndarray
# Identical to idx_mapping except for spec decoding.
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping
:
torch
.
Tensor
expanded_idx_mapping
:
torch
.
Tensor
# [total_num_logits] position within request for each logit
expanded_local_pos
:
torch
.
Tensor
# [num_reqs]
# [num_reqs]
# batch_idx -> num_scheduled_tokens
# batch_idx -> num_scheduled_tokens
...
@@ -91,6 +93,7 @@ class InputBatch:
...
@@ -91,6 +93,7 @@ class InputBatch:
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
expanded_idx_mapping
=
idx_mapping
expanded_idx_mapping
=
idx_mapping
expanded_local_pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
num_scheduled_tokens
=
np
.
full
(
num_reqs
,
num_tokens
//
num_reqs
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
full
(
num_reqs
,
num_tokens
//
num_reqs
,
dtype
=
np
.
int32
)
num_scheduled_tokens
[
-
1
]
+=
num_tokens
%
num_reqs
num_scheduled_tokens
[
-
1
]
+=
num_tokens
%
num_reqs
assert
int
(
num_scheduled_tokens
.
sum
())
==
num_tokens
assert
int
(
num_scheduled_tokens
.
sum
())
==
num_tokens
...
@@ -126,6 +129,7 @@ class InputBatch:
...
@@ -126,6 +129,7 @@ class InputBatch:
idx_mapping
=
idx_mapping
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
expanded_idx_mapping
=
expanded_idx_mapping
,
expanded_local_pos
=
expanded_local_pos
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
...
@@ -487,6 +491,7 @@ def post_update(
...
@@ -487,6 +491,7 @@ def post_update(
def
_expand_idx_mapping_kernel
(
def
_expand_idx_mapping_kernel
(
idx_mapping_ptr
,
idx_mapping_ptr
,
expanded_idx_mapping_ptr
,
expanded_idx_mapping_ptr
,
expanded_local_pos_ptr
,
cu_num_logits_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
...
@@ -499,6 +504,7 @@ def _expand_idx_mapping_kernel(
...
@@ -499,6 +504,7 @@ def _expand_idx_mapping_kernel(
mask
=
block
<
num_tokens
mask
=
block
<
num_tokens
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
tl
.
store
(
expanded_idx_mapping_ptr
+
start_idx
+
block
,
req_state_idx
,
mask
=
mask
)
tl
.
store
(
expanded_idx_mapping_ptr
+
start_idx
+
block
,
req_state_idx
,
mask
=
mask
)
tl
.
store
(
expanded_local_pos_ptr
+
start_idx
+
block
,
block
,
mask
=
mask
)
def
expand_idx_mapping
(
def
expand_idx_mapping
(
...
@@ -506,13 +512,17 @@ def expand_idx_mapping(
...
@@ -506,13 +512,17 @@ def expand_idx_mapping(
total_num_logits
:
int
,
total_num_logits
:
int
,
cu_num_logits
:
torch
.
Tensor
,
cu_num_logits
:
torch
.
Tensor
,
max_expand_len
:
int
,
max_expand_len
:
int
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
num_reqs
=
idx_mapping
.
shape
[
0
]
num_reqs
=
idx_mapping
.
shape
[
0
]
expanded_idx_mapping
=
idx_mapping
.
new_empty
(
total_num_logits
)
expanded_idx_mapping
=
idx_mapping
.
new_empty
(
total_num_logits
)
expanded_local_pos
=
torch
.
empty
(
total_num_logits
,
dtype
=
torch
.
int32
,
device
=
idx_mapping
.
device
)
_expand_idx_mapping_kernel
[(
num_reqs
,)](
_expand_idx_mapping_kernel
[(
num_reqs
,)](
idx_mapping
,
idx_mapping
,
expanded_idx_mapping
,
expanded_idx_mapping
,
expanded_local_pos
,
cu_num_logits
,
cu_num_logits
,
BLOCK_SIZE
=
triton
.
next_power_of_2
(
max_expand_len
),
BLOCK_SIZE
=
triton
.
next_power_of_2
(
max_expand_len
),
)
)
return
expanded_idx_mapping
return
expanded_idx_mapping
,
expanded_local_pos
vllm/v1/worker/gpu/model_runner.py
View file @
16786da7
...
@@ -152,6 +152,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -152,6 +152,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
device
=
self
.
device
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
logprobs_mode
=
self
.
model_config
.
logprobs_mode
,
num_speculative_tokens
=
self
.
num_speculative_steps
+
1
,
)
)
self
.
prompt_logprobs_worker
=
PromptLogprobsWorker
(
self
.
max_num_reqs
)
self
.
prompt_logprobs_worker
=
PromptLogprobsWorker
(
self
.
max_num_reqs
)
...
@@ -318,10 +319,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -318,10 +319,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
dummy_input_ids
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
expanded_local_pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
# during actual execution.
self
.
sampler
(
logits
,
idx_mapping
,
idx_mapping_np
,
idx_mapping_np
,
pos
)
self
.
sampler
(
logits
,
idx_mapping
,
idx_mapping_np
,
idx_mapping_np
,
pos
,
dummy_input_ids
,
expanded_local_pos
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
...
@@ -511,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -511,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs
+
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
num_reqs
+
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
)
expanded_idx_mapping
=
idx_mapping
expanded_idx_mapping
=
idx_mapping
expanded_local_pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
else
:
else
:
num_draft_tokens
=
np
.
array
(
num_draft_tokens
=
np
.
array
(
[
len
(
draft_tokens
.
get
(
req_id
,
()))
for
req_id
in
req_ids
],
[
len
(
draft_tokens
.
get
(
req_id
,
()))
for
req_id
in
req_ids
],
...
@@ -526,7 +542,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -526,7 +542,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_logits
=
async_copy_to_gpu
(
cu_num_logits_np
,
device
=
self
.
device
)
cu_num_logits
=
async_copy_to_gpu
(
cu_num_logits_np
,
device
=
self
.
device
)
max_expand_len
=
self
.
num_speculative_steps
+
1
max_expand_len
=
self
.
num_speculative_steps
+
1
expanded_idx_mapping
=
expand_idx_mapping
(
expanded_idx_mapping
,
expanded_local_pos
=
expand_idx_mapping
(
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
)
)
...
@@ -627,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -627,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping
=
idx_mapping
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
expanded_idx_mapping
=
expanded_idx_mapping
,
expanded_local_pos
=
expanded_local_pos
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens_after_padding
,
num_tokens_after_padding
=
num_tokens_after_padding
,
...
@@ -674,6 +691,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -674,6 +691,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
SamplerOutput
,
torch
.
Tensor
,
torch
.
Tensor
]:
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
sample_pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
# Apply grammar bitmask to the logits in-place.
...
@@ -691,6 +709,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -691,6 +709,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch
.
idx_mapping_np
,
input_batch
.
idx_mapping_np
,
input_batch
.
cu_num_logits_np
,
input_batch
.
cu_num_logits_np
,
sample_pos
,
sample_pos
,
input_ids
,
input_batch
.
expanded_local_pos
,
)
)
if
input_batch
.
num_draft_tokens
==
0
:
if
input_batch
.
num_draft_tokens
==
0
:
...
@@ -700,7 +720,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -700,7 +720,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
else
:
else
:
# Rejection sampling for spec decoding.
# Rejection sampling for spec decoding.
input_ids
=
input_batch
.
input_ids
[
input_batch
.
logits_indices
]
sampled_tokens
,
num_sampled
=
rejection_sample
(
sampled_tokens
,
num_sampled
=
rejection_sample
(
sampler_output
.
sampled_token_ids
,
sampler_output
.
sampled_token_ids
,
input_ids
,
input_ids
,
...
...
vllm/v1/worker/gpu/sample/penalties.py
View file @
16786da7
...
@@ -75,6 +75,9 @@ class PenaltiesState:
...
@@ -75,6 +75,9 @@ class PenaltiesState:
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
input_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
num_speculative_tokens
:
int
,
)
->
None
:
)
->
None
:
if
not
np
.
any
(
self
.
use_penalty
[
idx_mapping_np
]):
if
not
np
.
any
(
self
.
use_penalty
[
idx_mapping_np
]):
# No request uses penalties. Skip the kernel launch.
# No request uses penalties. Skip the kernel launch.
...
@@ -83,11 +86,14 @@ class PenaltiesState:
...
@@ -83,11 +86,14 @@ class PenaltiesState:
apply_penalties
(
apply_penalties
(
logits
,
logits
,
idx_mapping
,
idx_mapping
,
input_ids
,
expanded_local_pos
,
self
.
repetition_penalty
.
gpu
,
self
.
repetition_penalty
.
gpu
,
self
.
frequency_penalty
.
gpu
,
self
.
frequency_penalty
.
gpu
,
self
.
presence_penalty
.
gpu
,
self
.
presence_penalty
.
gpu
,
self
.
prompt_bin_mask
,
self
.
prompt_bin_mask
,
self
.
output_bin_counts
,
self
.
output_bin_counts
,
num_speculative_tokens
,
)
)
...
@@ -96,6 +102,8 @@ def _penalties_kernel(
...
@@ -96,6 +102,8 @@ def _penalties_kernel(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
idx_mapping_ptr
,
token_ids_ptr
,
expanded_local_pos_ptr
,
repetition_penalty_ptr
,
repetition_penalty_ptr
,
frequency_penalty_ptr
,
frequency_penalty_ptr
,
presence_penalty_ptr
,
presence_penalty_ptr
,
...
@@ -105,9 +113,10 @@ def _penalties_kernel(
...
@@ -105,9 +113,10 @@ def _penalties_kernel(
output_bin_counts_stride
,
output_bin_counts_stride
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
MAX_SPEC_LEN
:
tl
.
constexpr
,
):
):
batch
_idx
=
tl
.
program_id
(
0
)
token
_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch
_idx
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
token
_idx
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
...
@@ -123,13 +132,27 @@ def _penalties_kernel(
...
@@ -123,13 +132,27 @@ def _penalties_kernel(
block_idx
=
tl
.
program_id
(
1
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
mask
=
mask
)
logits
=
tl
.
load
(
logits_ptr
+
token
_idx
*
logits_stride
+
block
,
mask
=
mask
)
logits
=
logits
.
to
(
tl
.
float32
)
logits
=
logits
.
to
(
tl
.
float32
)
output_
bin_
counts
=
tl
.
load
(
base_
output_counts
=
tl
.
load
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
mask
=
mask
,
mask
=
mask
,
other
=
0
,
)
)
# Compute cumulative draft_counts from previous positions in this request
pos
=
tl
.
load
(
expanded_local_pos_ptr
+
token_idx
)
start_idx
=
token_idx
-
pos
draft_counts
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
tl
.
int32
)
for
prev_pos
in
tl
.
static_range
(
MAX_SPEC_LEN
):
if
prev_pos
<
pos
:
prev_token
=
tl
.
load
(
token_ids_ptr
+
start_idx
+
prev_pos
+
1
)
token_match
=
block
==
prev_token
draft_counts
=
draft_counts
+
token_match
.
to
(
tl
.
int32
)
# Total counts = base output counts + cumulative draft counts
output_bin_counts
=
base_output_counts
+
draft_counts
output_bin_mask
=
output_bin_counts
>
0
output_bin_mask
=
output_bin_counts
>
0
# Apply repetition penalties.
# Apply repetition penalties.
...
@@ -138,6 +161,7 @@ def _penalties_kernel(
...
@@ -138,6 +161,7 @@ def _penalties_kernel(
packed_mask
=
tl
.
load
(
packed_mask
=
tl
.
load
(
prompt_bin_mask_ptr
+
req_state_idx
*
prompt_bin_mask_stride
+
packed_block
,
prompt_bin_mask_ptr
+
req_state_idx
*
prompt_bin_mask_stride
+
packed_block
,
mask
=
packed_block
<
tl
.
cdiv
(
vocab_size
,
32
),
mask
=
packed_block
<
tl
.
cdiv
(
vocab_size
,
32
),
other
=
0
,
)
)
prompt_bin_mask
=
(
packed_mask
[:,
None
]
>>
(
tl
.
arange
(
0
,
32
)[
None
,
:]))
&
1
prompt_bin_mask
=
(
packed_mask
[:,
None
]
>>
(
tl
.
arange
(
0
,
32
)[
None
,
:]))
&
1
prompt_bin_mask
=
prompt_bin_mask
.
to
(
tl
.
int1
)
prompt_bin_mask
=
prompt_bin_mask
.
to
(
tl
.
int1
)
...
@@ -153,25 +177,30 @@ def _penalties_kernel(
...
@@ -153,25 +177,30 @@ def _penalties_kernel(
# Apply presence penalties.
# Apply presence penalties.
logits
-=
pres_penalty
*
output_bin_mask
logits
-=
pres_penalty
*
output_bin_mask
# Store back to logits.
# Store back to logits.
tl
.
store
(
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
tl
.
store
(
logits_ptr
+
token
_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
def
apply_penalties
(
def
apply_penalties
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
repetition_penalty
:
torch
.
Tensor
,
frequency_penalty
:
torch
.
Tensor
,
frequency_penalty
:
torch
.
Tensor
,
presence_penalty
:
torch
.
Tensor
,
presence_penalty
:
torch
.
Tensor
,
prompt_bin_mask
:
torch
.
Tensor
,
prompt_bin_mask
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
,
num_speculative_tokens
:
int
,
)
->
None
:
)
->
None
:
num_
req
s
,
vocab_size
=
logits
.
shape
num_
token
s
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
8192
BLOCK_SIZE
=
8192
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
num_blocks
=
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
)
_penalties_kernel
[(
num_
req
s
,
num_blocks
)](
_penalties_kernel
[(
num_
token
s
,
num_blocks
)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
idx_mapping
,
token_ids
,
expanded_local_pos
,
repetition_penalty
,
repetition_penalty
,
frequency_penalty
,
frequency_penalty
,
presence_penalty
,
presence_penalty
,
...
@@ -181,6 +210,7 @@ def apply_penalties(
...
@@ -181,6 +210,7 @@ def apply_penalties(
output_bin_counts
.
stride
(
0
),
output_bin_counts
.
stride
(
0
),
vocab_size
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
MAX_SPEC_LEN
=
num_speculative_tokens
,
)
)
...
...
vllm/v1/worker/gpu/sample/sampler.py
View file @
16786da7
...
@@ -25,6 +25,7 @@ class Sampler:
...
@@ -25,6 +25,7 @@ class Sampler:
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
logprobs_mode
:
LogprobsMode
=
"raw_logprobs"
,
num_speculative_tokens
:
int
=
1
,
):
):
if
logprobs_mode
not
in
(
"processed_logprobs"
,
"raw_logprobs"
):
if
logprobs_mode
not
in
(
"processed_logprobs"
,
"raw_logprobs"
):
raise
NotImplementedError
(
f
"Unsupported logprobs_mode:
{
logprobs_mode
}
"
)
raise
NotImplementedError
(
f
"Unsupported logprobs_mode:
{
logprobs_mode
}
"
)
...
@@ -34,6 +35,7 @@ class Sampler:
...
@@ -34,6 +35,7 @@ class Sampler:
self
.
sampling_states
=
SamplingStates
(
max_num_reqs
,
vocab_size
)
self
.
sampling_states
=
SamplingStates
(
max_num_reqs
,
vocab_size
)
self
.
penalties_state
=
PenaltiesState
(
max_num_reqs
,
vocab_size
,
device
)
self
.
penalties_state
=
PenaltiesState
(
max_num_reqs
,
vocab_size
,
device
)
self
.
logit_bias_state
=
LogitBiasState
(
max_num_reqs
,
device
)
self
.
logit_bias_state
=
LogitBiasState
(
max_num_reqs
,
device
)
self
.
num_speculative_tokens
=
num_speculative_tokens
def
add_request
(
def
add_request
(
self
,
req_idx
:
int
,
prompt_len
:
int
,
sampling_params
:
SamplingParams
self
,
req_idx
:
int
,
prompt_len
:
int
,
sampling_params
:
SamplingParams
...
@@ -61,12 +63,19 @@ class Sampler:
...
@@ -61,12 +63,19 @@ class Sampler:
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
cu_num_logits_np
:
np
.
ndarray
,
cu_num_logits_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
# that num_nans is computed before applying penalties and temperature.
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
num_nans
=
get_num_nans
(
logits
)
if
self
.
compute_nans
else
None
sampled
,
processed_logits
=
self
.
sample
(
sampled
,
processed_logits
=
self
.
sample
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
logits
,
idx_mapping
,
idx_mapping_np
,
pos
,
input_ids
,
expanded_local_pos
,
)
)
max_num_logprobs
=
self
.
sampling_states
.
max_num_logprobs
(
idx_mapping_np
)
max_num_logprobs
=
self
.
sampling_states
.
max_num_logprobs
(
idx_mapping_np
)
...
@@ -98,6 +107,8 @@ class Sampler:
...
@@ -98,6 +107,8 @@ class Sampler:
idx_mapping
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
expanded_local_pos
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Copy logits to a new FP32 tensor.
# Copy logits to a new FP32 tensor.
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
logits
=
torch
.
empty_like
(
logits
,
dtype
=
torch
.
float32
).
copy_
(
logits
)
...
@@ -106,7 +117,14 @@ class Sampler:
...
@@ -106,7 +117,14 @@ class Sampler:
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
self
.
logit_bias_state
.
apply_logit_bias
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
# Apply penalties in place.
# Apply penalties in place.
self
.
penalties_state
.
apply_penalties
(
logits
,
idx_mapping
,
idx_mapping_np
)
self
.
penalties_state
.
apply_penalties
(
logits
,
idx_mapping
,
idx_mapping_np
,
input_ids
,
expanded_local_pos
,
self
.
num_speculative_tokens
,
)
# Apply temperature in place.
# Apply temperature in place.
apply_temperature
(
logits
,
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
)
apply_temperature
(
logits
,
idx_mapping
,
self
.
sampling_states
.
temperature
.
gpu
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment