Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41fb013d
Unverified
Commit
41fb013d
authored
Apr 23, 2025
by
Woosuk Kwon
Committed by
GitHub
Apr 23, 2025
Browse files
[V1][Spec Decode] Always use argmax for sampling draft tokens (#16899)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
32d4b669
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
23 deletions
+18
-23
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+7
-7
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+10
-12
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-4
No files found.
vllm/v1/sample/rejection_sampler.py
View file @
41fb013d
...
...
@@ -226,7 +226,7 @@ def rejection_sample(
is_greedy
,
max_spec_len
,
vocab_size
,
IS_NGRAM
=
draft_probs
is
None
,
NO_DRAFT_PROBS
=
draft_probs
is
None
,
num_warps
=
1
,
)
return
output_token_ids
...
...
@@ -423,7 +423,7 @@ def sample_recovered_tokens(
q
,
vocab_size
,
triton
.
next_power_of_2
(
vocab_size
),
IS_NGRAM
=
draft_probs
is
None
,
NO_DRAFT_PROBS
=
draft_probs
is
None
,
)
return
recovered_token_ids
...
...
@@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
is_greedy_ptr
,
# [batch_size]
max_spec_len
,
vocab_size
,
IS_NGRAM
:
tl
.
constexpr
,
NO_DRAFT_PROBS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
is_greedy
=
tl
.
load
(
is_greedy_ptr
+
req_idx
)
...
...
@@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
for
pos
in
range
(
num_draft_tokens
):
if
not
rejected
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
if
IS_NGRAM
:
if
NO_DRAFT_PROBS
:
draft_prob
=
1
else
:
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
...
...
@@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
q_ptr
,
# [batch_size, vocab_size]
vocab_size
,
PADDED_VOCAB_SIZE
:
tl
.
constexpr
,
IS_NGRAM
:
tl
.
constexpr
,
NO_DRAFT_PROBS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
if
req_idx
==
0
:
...
...
@@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
return
vocab_offset
=
tl
.
arange
(
0
,
PADDED_VOCAB_SIZE
)
if
IS_NGRAM
:
if
NO_DRAFT_PROBS
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
orig_prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
)
...
...
@@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
recovered_id
=
tl
.
argmax
(
prob
/
q
,
axis
=-
1
)
tl
.
store
(
output_token_ids_ptr
+
start_idx
+
pos
,
recovered_id
)
if
IS_NGRAM
:
if
NO_DRAFT_PROBS
:
# Restore the original probability.
tl
.
store
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
,
...
...
vllm/v1/spec_decode/eagle.py
View file @
41fb013d
...
...
@@ -51,7 +51,7 @@ class EagleProposer:
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
...
...
@@ -94,17 +94,15 @@ class EagleProposer:
)
sample_hidden_states
=
hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
,
draft_probs
=
compute_probs_and_sample_next_token
(
logits
,
sampling_metadata
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
and [batch_size, 1, vocab_size]
return
draft_token_ids
.
view
(
-
1
,
1
)
,
draft_probs
.
unsqueeze
(
dim
=
1
)
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
draft_probs_list
=
[
draft_probs
]
positions
=
target_positions
[
last_token_indices
]
hidden_states
=
sample_hidden_states
...
...
@@ -159,16 +157,12 @@ class EagleProposer:
positions
=
clamped_positions
,
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
None
)
draft_token_ids
,
probs
=
compute_probs_and_sample_next_token
(
logits
,
sampling_metadata
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
draft_probs_list
.
append
(
probs
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs
=
torch
.
stack
(
draft_probs_list
,
dim
=
1
)
return
draft_token_ids
,
draft_probs
return
draft_token_ids
@
staticmethod
def
prepare_inputs
(
...
...
@@ -238,6 +232,10 @@ class EagleProposer:
self
.
model
.
lm_head
=
target_model
.
lm_head
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def
compute_probs_and_sample_next_token
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
41fb013d
...
...
@@ -1230,7 +1230,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
attn_metadata
.
slot_mapping
[
token_indices
]
draft_token_ids
,
draft_probs
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
...
...
@@ -1241,9 +1241,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata
=
sampling_metadata
,
)
spec_token_ids
=
draft_token_ids
.
tolist
()
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
# in the next step.
del
draft_probs
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment