Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5629f26d
Unverified
Commit
5629f26d
authored
Feb 25, 2025
by
Lily Liu
Committed by
GitHub
Feb 25, 2025
Browse files
[V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729)
parent
9ba28043
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
104 additions
and
111 deletions
+104
-111
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+8
-9
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+0
-1
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+0
-1
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+0
-3
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+68
-66
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+8
-11
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+0
-11
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-9
No files found.
tests/v1/sample/test_rejection_sampler.py
View file @
5629f26d
...
...
@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
temperature
=
torch
.
tensor
([]),
all_greedy
=
True
,
all_random
=
False
,
spec_token_ids
=
spec_tokens
,
top_p
=
None
,
top_k
=
None
,
min_p
=
torch
.
empty
(
batch_size
,
),
...
...
@@ -55,7 +54,7 @@ def test_perfect_match(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
3
,
4
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
...
...
@@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
5
,
INVALID_TOKEN_ID
,
INVALID_TOKEN_ID
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
...
...
@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
5
],
[
3
,
4
,
INVALID_TOKEN_ID
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
...
...
@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
...
...
@@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
5
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
...
...
@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
7
,
INVALID_TOKEN_ID
],
[
4
,
8
,
INVALID_TOKEN_ID
,
INVALID_TOKEN_ID
]],
dtype
=
torch
.
int
,
...
...
@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected_tensor
=
torch
.
tensor
(
expected
,
dtype
=
torch
.
int
,
device
=
logits
.
device
)
...
...
@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
metadata
=
create_sampling_metadata
(
spec_tokens
)
logits
=
create_logits_tensor
(
output_tokens
,
vocab_size
)
output
=
sampler
(
logits
,
metadata
)
output
=
sampler
(
spec_tokens
,
logits
,
metadata
)
expected
=
torch
.
tensor
([[
1
,
2
,
3
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
)
assert
torch
.
equal
(
output
.
sampled_token_ids
,
expected
)
assert
logits
.
shape
[
-
1
]
==
vocab_size
tests/v1/sample/test_sampler.py
View file @
5629f26d
...
...
@@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
None
,
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
5629f26d
...
...
@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
dtype
=
torch
.
float
,
device
=
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
None
,
min_tokens
=
min_tokens
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
all
(
x
==
0
for
x
in
frequency_penalties
)
...
...
vllm/v1/sample/metadata.py
View file @
5629f26d
...
...
@@ -13,9 +13,6 @@ class SamplingMetadata:
all_greedy
:
bool
all_random
:
bool
# None when there are no speculated tokens.
spec_token_ids
:
Optional
[
List
[
List
[
int
]]]
top_p
:
Optional
[
torch
.
Tensor
]
top_k
:
Optional
[
torch
.
Tensor
]
min_p
:
Optional
[
torch
.
Tensor
]
...
...
vllm/v1/sample/rejection_sampler.py
View file @
5629f26d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
torch
import
torch.nn
as
nn
from
torch.nn.utils.rnn
import
pad_sequence
...
...
@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
else
:
self
.
forward_method
=
self
.
forward_native
def
forward
(
self
,
logits
:
torch
.
Tensor
,
def
forward
(
self
,
draft_token_ids
:
List
[
List
[
int
]],
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
SamplerOutput
:
if
not
sampling_metadata
.
all_greedy
:
raise
NotImplementedError
(
"Currently, only greedy sampling is supported by "
"rejection sampler."
)
return
self
.
forward_method
(
logits
,
sampling_metadata
)
return
self
.
forward_method
(
draft_token_ids
,
target_probs
,
sampling_metadata
)
def
flashinfer_sample
(
self
,
logits
:
torch
.
Tensor
,
draft_token_ids
:
List
[
List
[
int
]],
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# performance.
assert
sampling_metadata
.
spec_token_ids
is
not
None
spec_token_ids
=
sampling_metadata
.
spec_token_ids
max_spec_len
=
max
(
len
(
s
)
for
s
in
spec_token_ids
)
batch_size
=
len
(
spec_token_ids
)
draft_token_ids
=
torch
.
full
((
batch_size
,
max_spec_len
),
INVALID_TOKEN_ID
,
device
=
"cpu"
,
dtype
=
torch
.
long
)
target_token_ids
=
torch
.
full
((
batch_size
,
max_spec_len
+
1
),
fill_value
=
INVALID_TOKEN_ID
,
device
=
logits
.
device
,
dtype
=
torch
.
long
)
# TODO: Vectorize the following loop for better performance.
start_loc
=
0
for
i
in
range
(
batch_size
):
num_spec_tokens
=
len
(
spec_token_ids
[
i
])
draft_token_ids
[
i
,
:
num_spec_tokens
]
=
torch
.
tensor
(
spec_token_ids
[
i
],
device
=
"cpu"
,
dtype
=
torch
.
long
)
end_loc
=
start_loc
+
num_spec_tokens
+
1
# Assume greedy sampling.
target_token_ids
[
i
,
:
num_spec_tokens
+
1
]
=
torch
.
argmax
(
logits
[
start_loc
:
end_loc
],
dim
=-
1
)
start_loc
=
end_loc
vocab_size
=
logits
.
size
(
-
1
)
sample_lens
=
[
len
(
x
)
+
1
for
x
in
draft_token_ids
]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids
=
[
torch
.
tensor
(
x
,
dtype
=
int
,
device
=
'cpu'
)
for
x
in
draft_token_ids
]
draft_token_ids_tensor
=
pad_sequence
(
draft_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
if
sampling_metadata
.
all_greedy
:
target_token_ids
=
target_probs
.
argmax
(
dim
=-
1
).
view
(
-
1
)
target_token_ids
=
target_token_ids
.
split
(
sample_lens
)
target_token_ids
=
pad_sequence
(
target_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
vocab_size
=
target_probs
.
size
(
-
1
)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids
=
draft_token_ids
.
to
(
logits
.
device
)
draft_probs
=
_create_greedy_token_probs
(
draft_token_ids
,
vocab_size
,
logits
.
device
)
target_probs
=
_create_greedy_token_probs
(
target_token_ids
,
vocab_size
,
logits
.
device
)
uniform_samples
=
torch
.
zeros
(
batch_size
,
max_spec_len
+
1
,
device
=
logits
.
device
)
draft_token_ids_tensor
=
draft_token_ids_tensor
.
to
(
target_probs
.
device
)
draft_probs
=
_create_greedy_token_probs
(
draft_token_ids_tensor
,
vocab_size
,
target_probs
.
device
)
target_probs
=
_create_greedy_token_probs
(
target_token_ids
,
vocab_size
,
target_probs
.
device
)
uniform_samples
=
torch
.
zeros
(
draft_token_ids_tensor
.
size
(
0
),
draft_token_ids_tensor
.
size
(
1
)
+
1
,
device
=
target_probs
.
device
)
else
:
raise
NotImplementedError
(
"Currently, only greedy sampling is supported by "
"rejection sampler."
)
sampled_token_ids
,
_
,
_
=
fs
.
chain_speculative_sampling
(
draft_probs
,
draft_token_ids
,
draft_token_ids
_tensor
,
uniform_samples
,
target_probs
,
)
...
...
@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
# TODO: The following method can be optimized for better performance.
def
forward_native
(
self
,
logits
:
torch
.
Tensor
,
draft_token_ids
:
List
[
List
[
int
]],
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
assert
sampling_metadata
.
spec_token_ids
is
not
None
spec_lens
=
[
len
(
x
)
for
x
in
sampling_metadata
.
spec_token_ids
]
sample_lens
=
[
len
(
x
)
+
1
for
x
in
draft_token_ids
]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids
=
[
torch
.
tensor
(
x
,
dtype
=
int
,
device
=
'cpu'
)
for
x
in
draft_token_ids
]
draft_token_ids_tensor
=
pad_sequence
(
draft_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
draft_token_ids_tensor
=
draft_token_ids_tensor
.
to
(
target_probs
.
device
)
# Add 1 to include the 'bonus' token.
sample_lens
=
[
x
+
1
for
x
in
spec_lens
]
output_token_ids
=
logits
.
argmax
(
dim
=-
1
).
view
(
-
1
)
if
sampling_metadata
.
all_greedy
:
output_token_ids
=
target_probs
.
argmax
(
dim
=-
1
).
view
(
-
1
)
output_token_ids
=
output_token_ids
.
split
(
sample_lens
)
output_token_ids
=
pad_sequence
(
output_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
spec_token_ids
=
[
torch
.
tensor
(
x
,
dtype
=
output_token_ids
.
dtype
,
device
=
output_token_ids
.
device
)
for
x
in
sampling_metadata
.
spec_token_ids
]
spec_token_ids
=
pad_sequence
(
spec_token_ids
,
batch_first
=
True
,
padding_value
=
INVALID_TOKEN_ID
)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask
=
(
output_token_ids
[:,
:
-
1
]
==
spec_token_ids
).
cumprod
(
accept_mask
=
(
output_token_ids
[:,
:
-
1
]
==
draft_token_ids_tensor
).
cumprod
(
dim
=
1
)
else
:
raise
NotImplementedError
(
"Currently, only greedy sampling is supported by "
"rejection sampler."
)
# Identify valid positions (non-padding).
valid_mask
=
output_token_ids
!=
INVALID_TOKEN_ID
# Generate mask with bonus token.
...
...
vllm/v1/sample/sampler.py
View file @
5629f26d
...
...
@@ -9,7 +9,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
from
vllm.v1.sample.ops.penalties
import
(
apply_all_penalties
,
apply_min_token_penalties
)
from
vllm.v1.sample.ops.topk_topp_sampler
import
TopKTopPSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
_SAMPLING_EPS
=
1e-5
...
...
@@ -19,22 +18,12 @@ class Sampler(nn.Module):
def
__init__
(
self
):
super
().
__init__
()
self
.
topk_topp_sampler
=
TopKTopPSampler
()
self
.
rejection_sampler
=
RejectionSampler
()
def
forward
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
SamplerOutput
:
if
sampling_metadata
.
spec_token_ids
:
if
sampling_metadata
.
max_num_logprobs
:
raise
NotImplementedError
(
"Rejection sampling does not support logprobs."
)
return
self
.
rejection_sampler
(
logits
,
sampling_metadata
,
)
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
...
...
@@ -127,6 +116,14 @@ class Sampler(nn.Module):
)
return
sampled
def
compute_probs
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
if
sampling_metadata
.
all_greedy
:
return
logits
# Apply temperature. This is an in-place op changing logits.
logits
=
self
.
apply_temperature
(
logits
,
sampling_metadata
.
temperature
)
return
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
def
compute_logprobs
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
logits
.
log_softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
5629f26d
...
...
@@ -490,23 +490,12 @@ class InputBatch:
presence_penalties
=
self
.
presence_penalties
[:
num_reqs
],
repetition_penalties
=
self
.
repetition_penalties
[:
num_reqs
],
output_token_ids
=
cast
(
List
[
List
[
int
]],
self
.
req_output_token_ids
),
spec_token_ids
=
None
,
min_tokens
=
self
.
min_tokens
,
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
num_reqs
],
allowed_token_ids_mask
=
allowed_token_ids_mask
,
)
def
get_sampling_metadata
(
self
,
req_id_to_spec_token_ids
:
Dict
[
str
,
List
[
int
]],
)
->
SamplingMetadata
:
# Set the new spec token ids in the cached sampling metadata.
self
.
sampling_metadata
.
spec_token_ids
=
[
req_id_to_spec_token_ids
.
get
(
req_id
,
[])
for
req_id
in
self
.
req_ids
]
if
req_id_to_spec_token_ids
else
None
return
self
.
sampling_metadata
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
max_prompt_len
=
self
.
num_prompt_tokens
[:
self
.
num_reqs
].
max
()
prompt_token_ids_cpu_tensor
=
torch
.
empty
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
5629f26d
...
...
@@ -32,7 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
,
RejectionSampler
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
...
@@ -122,7 +122,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
self
.
rejection_sampler
=
RejectionSampler
()
# TODO: find a better way to check if we are using ngram.
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
"Currently, only ngram spec decode is supported in V1."
...
...
@@ -951,12 +951,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
input_batch
.
get_
sampling_metadata
(
scheduler_output
.
scheduled
_spec_decode
_tokens
)
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
if
not
self
.
use
_spec_decode
:
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
else
:
target_probs
=
self
.
model
.
sampler
.
compute_probs
(
logits
,
sampling_metadata
)
scheduled_request_ids
=
scheduler_output
.
num_scheduled_tokens
.
keys
(
)
draft_token_ids
=
[
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
for
req_id
in
scheduled_request_ids
]
sampler_output
=
self
.
rejection_sampler
(
draft_token_ids
,
target_probs
,
sampling_metadata
)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
...
...
@@ -1293,7 +1305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
temperature
=
dummy_tensors
(
0.5
),
all_greedy
=
False
,
all_random
=
False
,
spec_token_ids
=
None
,
top_p
=
dummy_tensors
(
0.9
),
top_k
=
dummy_tensors
(
logits
.
size
(
1
)
-
1
),
min_p
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment