Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b7433ca1
Unverified
Commit
b7433ca1
authored
Sep 18, 2025
by
Benjamin Chislett
Committed by
GitHub
Sep 18, 2025
Browse files
[Spec Decode] Efficient padded speculation (#24539)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
5c65a72b
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
507 additions
and
104 deletions
+507
-104
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+174
-5
vllm/config/speculative.py
vllm/config/speculative.py
+5
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+223
-35
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+4
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+101
-63
No files found.
tests/v1/spec_decode/test_eagle.py
View file @
b7433ca1
...
@@ -19,6 +19,8 @@ from vllm.config.load import LoadConfig
...
@@ -19,6 +19,8 @@ from vllm.config.load import LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
...
@@ -64,6 +66,86 @@ def _create_proposer(
...
@@ -64,6 +66,86 @@ def _create_proposer(
device
=
current_platform
.
device_type
)
device
=
current_platform
.
device_type
)
def
test_prepare_next_token_ids
():
"""
Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
Each will produce a device tensor of next_token_ids, taking as input
either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
or the CPU python list[list[int]] with the rejected tokens removed.
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
num_requests
=
4
num_speculative_tokens
=
4
batch_spec
=
BatchSpec
(
seq_lens
=
[
num_speculative_tokens
+
1
]
*
num_requests
,
query_lens
=
[
num_speculative_tokens
+
1
]
*
num_requests
,
)
req_ids
=
[
f
"req_
{
i
+
1
}
"
for
i
in
range
(
num_requests
)]
mock_input_batch
=
mock
.
MagicMock
(
spec
=
InputBatch
)
mock_input_batch
.
req_ids
=
req_ids
mock_input_batch
.
num_reqs
=
num_requests
mock_input_batch
.
vocab_size
=
100
mock_num_scheduled_tokens
=
{
req_id
:
0
for
req_id
in
req_ids
}
mock_requests
=
{}
for
req_id
in
req_ids
:
mock_request
=
mock
.
MagicMock
(
spec
=
CachedRequestState
)
# Each request will have a backup next token id of 10, 20, 30, 40
mock_request
.
get_token_id
.
return_value
=
int
(
req_id
.
split
(
"_"
)[
1
])
*
10
mock_request
.
num_computed_tokens
=
0
mock_requests
[
req_id
]
=
mock_request
sampled_token_ids
=
[
[
0
,
1
,
-
1
,
-
1
,
-
1
],
# 1 accepted, 3 rejected, "1" sampled
[
0
,
1
,
2
,
3
,
4
],
# all accepted, "4" sampled
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
# sampling skipped, use backup token "30"
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]
# this request will be discarded
]
sampled_token_ids_tensor
=
torch
.
tensor
(
sampled_token_ids
,
dtype
=
torch
.
int32
,
device
=
device
)
sampled_token_ids_cpu
=
[[
i
for
i
in
seq
if
i
!=
-
1
]
for
seq
in
sampled_token_ids
]
expected_next_token_ids_cpu
=
[
1
,
4
,
30
,
40
]
expected_next_token_ids_tensor
=
torch
.
tensor
(
expected_next_token_ids_cpu
,
dtype
=
torch
.
int32
,
device
=
device
)
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
next_token_ids_from_cpu
=
proposer
.
prepare_next_token_ids_cpu
(
sampled_token_ids_cpu
,
mock_requests
,
mock_input_batch
,
mock_num_scheduled_tokens
)
assert
torch
.
equal
(
next_token_ids_from_cpu
,
expected_next_token_ids_tensor
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
discarded_req_indices
=
torch
.
tensor
([
3
],
dtype
=
torch
.
int64
,
device
=
device
)
num_discarded_reqs
=
1
expected_valid_sampled_tokens_count
=
torch
.
tensor
([
2
,
5
,
0
,
0
],
dtype
=
torch
.
int32
,
device
=
device
)
next_token_ids_from_padded
,
valid_sampled_tokens_count
=
\
proposer
.
prepare_next_token_ids_padded
(
common_attn_metadata
,
sampled_token_ids_tensor
,
mock_requests
,
mock_input_batch
,
discarded_req_indices
,
num_discarded_reqs
)
assert
torch
.
equal
(
next_token_ids_from_padded
,
expected_next_token_ids_tensor
)
assert
torch
.
equal
(
valid_sampled_tokens_count
,
expected_valid_sampled_tokens_count
)
def
test_prepare_inputs
():
def
test_prepare_inputs
():
"""
"""
cu_target_query_lens: [0, a, a + b, a + b + c]
cu_target_query_lens: [0, a, a + b, a + b + c]
...
@@ -90,10 +172,24 @@ def test_prepare_inputs():
...
@@ -90,10 +172,24 @@ def test_prepare_inputs():
device
=
device
,
device
=
device
,
)
)
# Rejected tokens per request: [1, 3, 2]
# If there are `k` sampled tokens, then `k-1` tokens are draft tokens
num_rejected_tokens
=
torch
.
tensor
([
1
,
3
,
2
],
# from the previous iteration, and the last token is the bonus token sampled
dtype
=
torch
.
int32
,
# from the base model.
device
=
device
)
num_draft_tokens
=
[
3
,
6
,
4
]
# one less than query_lens
# num rejected tokens is [1, 3, 2]
ACCEPT_TOKEN
=
0
BONUS_TOKEN
=
1
REJECT_TOKEN
=
-
1
sampled_token_ids
=
[
[
ACCEPT_TOKEN
,
ACCEPT_TOKEN
,
REJECT_TOKEN
,
BONUS_TOKEN
],
[
ACCEPT_TOKEN
,
ACCEPT_TOKEN
,
ACCEPT_TOKEN
,
REJECT_TOKEN
,
REJECT_TOKEN
,
REJECT_TOKEN
,
BONUS_TOKEN
],
[
ACCEPT_TOKEN
,
ACCEPT_TOKEN
,
REJECT_TOKEN
,
REJECT_TOKEN
,
BONUS_TOKEN
]
]
sampled_token_ids
=
[[
i
for
i
in
seq
if
i
!=
REJECT_TOKEN
]
for
seq
in
sampled_token_ids
]
# Expected calculations:
# Expected calculations:
# query_len_per_req = [4, 7, 5]
# query_len_per_req = [4, 7, 5]
...
@@ -125,7 +221,7 @@ def test_prepare_inputs():
...
@@ -125,7 +221,7 @@ def test_prepare_inputs():
proposer
=
_create_proposer
(
"eagle"
,
1
)
proposer
=
_create_proposer
(
"eagle"
,
1
)
updated_metadata
,
token_indices
=
proposer
.
prepare_inputs
(
updated_metadata
,
token_indices
=
proposer
.
prepare_inputs
(
common_attn_metadata
,
num_rejected
_tokens
.
cpu
()
)
common_attn_metadata
,
sampled_token_ids
,
num_draft
_tokens
)
assert
torch
.
equal
(
updated_metadata
.
query_start_loc
,
assert
torch
.
equal
(
updated_metadata
.
query_start_loc
,
expected_cu_num_tokens
)
expected_cu_num_tokens
)
...
@@ -133,6 +229,77 @@ def test_prepare_inputs():
...
@@ -133,6 +229,77 @@ def test_prepare_inputs():
assert
torch
.
equal
(
token_indices
,
expected_token_indices
)
assert
torch
.
equal
(
token_indices
,
expected_token_indices
)
def
test_prepare_inputs_padded
():
"""
Input scenario is 3 requests with num_speculative_tokens == 2 and:
- Request 1: query_len = 3, rejected = 1
- Request 2: query_len = 3, rejected = 0
- Request 3: query_len = 3, rejected = 2
Expected outputs:
token_indices: [0, 1, 2,
3, 4, 5,
6, 7, 8]
Reason: Deferred computation should not disturb the original indices.
token_indices_to_sample: [1, 5, 6]
Reason: After accounting for rejections, these are the valid token positions
from the original indices to sample from.
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
expected_token_indices
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
],
dtype
=
torch
.
int32
,
device
=
device
)
expected_token_indices_to_sample
=
torch
.
tensor
([
1
,
5
,
6
],
dtype
=
torch
.
int32
,
device
=
device
)
num_speculative_tokens
=
2
batch_spec
=
BatchSpec
(
seq_lens
=
[
3
,
3
,
3
],
query_lens
=
[
3
,
3
,
3
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
# Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
expected_query_start_loc
=
torch
.
tensor
([
0
,
3
,
6
,
9
],
dtype
=
torch
.
int32
,
device
=
device
)
spec_decode_metadata
=
SpecDecodeMetadata
.
make_dummy
(
draft_token_ids
=
[[
0
]
*
num_speculative_tokens
]
*
3
,
device
=
device
,
)
# num_rejected_tokens = [1, 0, 2]
# num_draft_tokens = [2, 2, 2]
# valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
valid_sampled_tokens_count
=
torch
.
tensor
([
2
,
3
,
1
],
dtype
=
torch
.
int32
,
device
=
device
)
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
output_metadata
,
token_indices
,
token_indices_to_sample
=
\
proposer
.
prepare_inputs_padded
(
common_attn_metadata
,
spec_decode_metadata
,
valid_sampled_tokens_count
)
assert
output_metadata
.
max_query_len
==
3
assert
torch
.
equal
(
output_metadata
.
query_start_loc
,
expected_query_start_loc
)
assert
torch
.
equal
(
token_indices
,
expected_token_indices
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
@
pytest
.
mark
.
parametrize
(
"method"
,
[
"eagle"
,
"eagle3"
])
@
pytest
.
mark
.
parametrize
(
"method"
,
[
"eagle"
,
"eagle3"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
get_attn_backend_list_based_on_platform
())
...
@@ -373,6 +540,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -373,6 +540,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_token_indices
=
None
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
...
@@ -526,6 +694,7 @@ def test_propose_tree(spec_token_tree):
...
@@ -526,6 +694,7 @@ def test_propose_tree(spec_token_tree):
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_token_indices
=
None
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
assert
result
.
shape
==
(
batch_size
,
num_speculative_tokens
)
assert
result
.
shape
==
(
batch_size
,
num_speculative_tokens
)
...
...
vllm/config/speculative.py
View file @
b7433ca1
...
@@ -83,6 +83,11 @@ class SpeculativeConfig:
...
@@ -83,6 +83,11 @@ class SpeculativeConfig:
disable_by_batch_size
:
Optional
[
int
]
=
None
disable_by_batch_size
:
Optional
[
int
]
=
None
"""Disable speculative decoding for new incoming requests when the number
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch
:
bool
=
False
"""Disable input padding for speculative decoding. If set to True,
speculative input batches can contain sequences of different lengths,
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""
# Ngram proposer configuration
# Ngram proposer configuration
prompt_lookup_max
:
Optional
[
int
]
=
None
prompt_lookup_max
:
Optional
[
int
]
=
None
...
...
vllm/v1/spec_decode/eagle.py
View file @
b7433ca1
...
@@ -27,6 +27,9 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
...
@@ -27,6 +27,9 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -94,20 +97,26 @@ class EagleProposer:
...
@@ -94,20 +97,26 @@ class EagleProposer:
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
device
)
device
=
device
)
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
arange
=
torch
.
arange
(
# We need +1 here because the arange is used to set query_start_loc,
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
# which has one more element than batch_size.
max_batch_size
+
1
,
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
self
.
arange
=
torch
.
arange
(
max_num_slots_for_arange
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
)
)
self
.
inputs_embeds
=
torch
.
zeros
(
self
.
inputs_embeds
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
hidden_size
),
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
device
)
device
=
device
)
self
.
backup_next_token_ids
=
CpuGpuBuffer
(
max_batch_size
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
)
# Determine allowed attention backends once during initialization.
# Determine allowed attention backends once during initialization.
self
.
allowed_attn_types
:
tuple
[
type
[
EagleAttentionMetadata
],
...]
self
.
allowed_attn_types
:
tuple
[
type
[
EagleAttentionMetadata
],
...]
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
@@ -156,12 +165,15 @@ class EagleProposer:
...
@@ -156,12 +165,15 @@ class EagleProposer:
target_hidden_states
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
last_token_indices
:
Optional
[
torch
.
Tensor
],
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
mm_embeds
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
mm_embeds
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
if
last_token_indices
is
None
:
last_token_indices
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
last_token_indices
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
if
self
.
method
==
"eagle3"
:
if
self
.
method
==
"eagle3"
:
...
@@ -228,6 +240,12 @@ class EagleProposer:
...
@@ -228,6 +240,12 @@ class EagleProposer:
last_hidden_states
,
hidden_states
=
ret_hidden_states
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
positions
=
target_positions
[
last_token_indices
]
positions
=
target_positions
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
hidden_states
=
hidden_states
[
last_token_indices
]
...
@@ -245,15 +263,12 @@ class EagleProposer:
...
@@ -245,15 +263,12 @@ class EagleProposer:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
if
not
isinstance
(
attn_metadata
,
self
.
allowed_attn_types
):
if
self
.
num_speculative_tokens
==
1
:
raise
ValueError
(
# [batch_size, 1]
f
"Unsupported attention metadata type for speculative "
return
draft_token_ids
.
view
(
-
1
,
1
)
"decoding with num_speculative_tokens > 1: "
f
"
{
type
(
attn_metadata
)
}
. Supported types are: "
# TODO: Currently, MTP module released by deepseek only has
f
"
{
self
.
allowed_attn_types
}
"
)
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
assert
isinstance
(
attn_metadata
,
self
.
allowed_attn_types
)
# Generate the remaining draft tokens.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
draft_token_ids_list
=
[
draft_token_ids
]
...
@@ -263,10 +278,13 @@ class EagleProposer:
...
@@ -263,10 +278,13 @@ class EagleProposer:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
else
:
input_batch_size
=
batch_size
input_batch_size
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
common_attn_metadata
.
max_query_len
=
1
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
common_attn_metadata
.
query_start_loc_cpu
=
torch
.
from_numpy
(
self
.
token_arange_np
[:
batch_size
+
1
]).
clone
()
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
# tensor.argmax() returns int64 by default.
...
@@ -286,27 +304,38 @@ class EagleProposer:
...
@@ -286,27 +304,38 @@ class EagleProposer:
positions
)
positions
)
# Increment the sequence lengths.
# Increment the sequence lengths.
attn_metadata
.
max_seq_len
+=
1
common_attn_metadata
.
seq_lens
+=
1
attn_metadata
.
seq_lens
+=
1
common_attn_metadata
.
seq_lens_cpu
+=
1
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
common_attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
common_attn_metadata
.
num_computed_tokens_cpu
=
\
common_attn_metadata
.
seq_lens_cpu
-
1
# Compute the slot mapping.
# Compute the slot mapping.
block_numbers
=
clamped_positions
//
self
.
block_size
block_numbers
=
clamped_positions
//
self
.
block_size
block_ids
=
attn_metadata
.
block_table
.
gather
(
block_ids
=
common_
attn_metadata
.
block_table
_tensor
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
common_attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
clamped_positions
%
self
.
block_size
)
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
# padding tokens.
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
common_attn_metadata
.
slot_mapping
.
masked_fill_
(
PADDING_SLOT_ID
)
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# Rebuild attention metadata
attn_metadata_builder
=
\
self
.
runner
.
attn_groups
[
0
][
0
].
metadata_builders
[
ubatch_id
]
attn_metadata
=
attn_metadata_builder
\
.
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
)
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
input_ids
[:
batch_size
]
=
input_ids
...
@@ -347,6 +376,158 @@ class EagleProposer:
...
@@ -347,6 +376,158 @@ class EagleProposer:
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
return
draft_token_ids
def
prepare_next_token_ids_cpu
(
self
,
sampled_token_ids
:
list
[
list
[
int
]],
requests
:
dict
[
str
,
CachedRequestState
],
gpu_input_batch
:
InputBatch
,
num_scheduled_tokens
:
dict
[
str
,
int
])
->
torch
.
Tensor
:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids for each request based on the sampled
token ids from the CPU. If a request has no sampled token ids (e.g.,
during the initial decoding steps), it falls back to using the request
state to get the next token id.
"""
req_ids
=
gpu_input_batch
.
req_ids
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
sampled_token_ids
):
if
token_ids
:
# Common case.
next_token_id
=
token_ids
[
-
1
]
else
:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id
=
req_ids
[
i
]
req_state
=
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
input_ids
.
device
)
return
next_token_ids
def
prepare_next_token_ids_padded
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampled_token_ids
:
torch
.
Tensor
,
requests
:
dict
[
str
,
CachedRequestState
],
gpu_input_batch
:
InputBatch
,
discard_request_indices
:
torch
.
Tensor
,
num_discarded_requests
:
int
)
->
\
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs
=
gpu_input_batch
.
num_reqs
self
.
backup_next_token_ids
.
np
[:
num_reqs
]
=
np
.
array
([
requests
[
gpu_input_batch
.
req_ids
[
i
]].
get_token_id
(
common_attn_metadata
.
seq_lens_cpu
[
i
].
item
())
for
i
in
range
(
num_reqs
)
])
self
.
backup_next_token_ids
.
copy_to_gpu
(
num_reqs
)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices
=
\
discard_request_indices
[:
num_discarded_requests
]
valid_sampled_token_ids_gpu
=
sampled_token_ids
.
clone
()
valid_sampled_token_ids_gpu
.
index_fill_
(
0
,
discard_sampled_tokens_req_indices
,
-
1
)
# Generate a mask for all valid tokens within those requests
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
valid_mask
=
torch
.
ones_like
(
valid_sampled_token_ids_gpu
,
dtype
=
torch
.
bool
)
else
:
valid_mask
=
(
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
(
valid_sampled_token_ids_gpu
<
gpu_input_batch
.
vocab_size
))
# Count the number of valid tokens in each request
valid_sampled_tokens_count
=
valid_mask
.
sum
(
dim
=
1
)
# Get the rightmost valid index per row
last_valid_indices
=
valid_sampled_tokens_count
-
1
last_valid_indices_safe
=
torch
.
clamp
(
last_valid_indices
,
min
=
0
)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens
=
torch
.
gather
(
valid_sampled_token_ids_gpu
,
1
,
last_valid_indices_safe
.
unsqueeze
(
1
)).
squeeze
(
1
)
# Use last token if valid, pre-computed backup if not
batch_size
=
valid_sampled_token_ids_gpu
.
shape
[
0
]
next_token_ids
=
torch
.
where
(
last_valid_indices
!=
-
1
,
selected_tokens
,
self
.
backup_next_token_ids
.
gpu
[:
batch_size
])
return
next_token_ids
,
valid_sampled_tokens_count
def
prepare_inputs_padded
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
spec_decode_metadata
:
SpecDecodeMetadata
,
valid_sampled_tokens_count
:
torch
.
Tensor
)
->
\
tuple
[
CommonAttentionMetadata
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_draft_tokens_gpu
=
torch
.
cat
([
spec_decode_metadata
.
cu_num_draft_tokens
[
0
:
1
],
spec_decode_metadata
.
cu_num_draft_tokens
[
1
:]
-
spec_decode_metadata
.
cu_num_draft_tokens
[:
-
1
]
])
num_rejected_tokens_gpu
=
torch
.
where
(
num_draft_tokens_gpu
>
0
,
num_draft_tokens_gpu
+
1
-
valid_sampled_tokens_count
,
torch
.
zeros_like
(
num_draft_tokens_gpu
))
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
new_query_len_per_req
=
(
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
])
total_num_tokens
=
query_start_loc_cpu
[
-
1
].
item
()
token_indices
=
self
.
arange
[:
total_num_tokens
]
spec_common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
,
num_computed_tokens_cpu
=
common_attn_metadata
.
num_computed_tokens_cpu
,
num_reqs
=
common_attn_metadata
.
num_reqs
,
num_actual_tokens
=
total_num_tokens
,
max_query_len
=
new_query_len_per_req
.
max
().
item
(),
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
[
token_indices
],
causal
=
True
,
)
token_indices_to_sample
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
\
-
num_rejected_tokens_gpu
return
spec_common_attn_metadata
,
token_indices
,
token_indices_to_sample
def
propose_tree
(
def
propose_tree
(
self
,
self
,
batch_size
:
int
,
batch_size
:
int
,
...
@@ -520,11 +701,11 @@ class EagleProposer:
...
@@ -520,11 +701,11 @@ class EagleProposer:
def
prepare_inputs
(
def
prepare_inputs
(
self
,
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
# [batch_size]
sampled_token_ids
:
list
[
list
[
int
]],
num_
rejected
_tokens
:
torch
.
Tensor
num_
draft
_tokens
:
list
[
int
],
)
->
tuple
[
CommonAttentionMetadata
,
torch
.
Tensor
]:
)
->
tuple
[
CommonAttentionMetadata
,
torch
.
Tensor
]:
"""
"""
This function is used to prepare the inputs for
the
spec decod
e
.
This function is used to prepare the inputs for spec
ulative
decod
ing
.
It updates to the common_attn_metadata to account for the rejected
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
of the tokens that should be fed to the speculator.
...
@@ -545,6 +726,13 @@ class EagleProposer:
...
@@ -545,6 +726,13 @@ class EagleProposer:
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
num_rejected_tokens
=
[
n
+
1
-
len
(
sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens
=
torch
.
tensor
(
num_rejected_tokens
,
dtype
=
torch
.
int32
)
device
=
common_attn_metadata
.
query_start_loc
.
device
device
=
common_attn_metadata
.
query_start_loc
.
device
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
new_seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
\
new_seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
\
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
b7433ca1
...
@@ -64,7 +64,10 @@ class CachedRequestState:
...
@@ -64,7 +64,10 @@ class CachedRequestState:
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
if
idx
<
self
.
num_prompt_tokens
:
if
idx
<
self
.
num_prompt_tokens
:
return
self
.
prompt_token_ids
[
idx
]
return
self
.
prompt_token_ids
[
idx
]
elif
idx
-
self
.
num_prompt_tokens
<
len
(
self
.
output_token_ids
):
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
return
self
.
output_token_ids
[
idx
-
self
.
num_prompt_tokens
]
else
:
return
-
1
class
InputBatch
:
class
InputBatch
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b7433ca1
...
@@ -344,6 +344,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -344,6 +344,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
hidden_size
,
self
.
hidden_size
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
numpy
=
False
)
numpy
=
False
)
self
.
discard_request_indices
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
self
.
num_discarded_requests
=
0
self
.
num_draft_tokens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
self
.
num_draft_tokens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
self
.
num_accepted_tokens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
self
.
num_accepted_tokens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
...
@@ -974,6 +978,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -974,6 +978,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs
]
max_seq_len
=
self
.
seq_lens
.
np
[:
num_reqs
].
max
().
item
()
max_seq_len
=
self
.
seq_lens
.
np
[:
num_reqs
].
max
().
item
()
num_tokens
=
[
self
.
requests
[
r
].
num_tokens
for
r
in
self
.
input_batch
.
req_ids
]
num_tokens_np
=
np
.
array
(
num_tokens
,
dtype
=
np
.
int32
)
# Record the index of requests that should not be sampled,
# so that we could clear the sampled tokens before returning
discard_requests_mask
=
self
.
seq_lens
.
np
[:
num_reqs
]
<
num_tokens_np
discard_request_indices
=
np
.
nonzero
(
discard_requests_mask
)[
0
]
self
.
num_discarded_requests
=
len
(
discard_request_indices
)
self
.
discard_request_indices
.
np
[:
self
.
num_discarded_requests
]
=
(
discard_request_indices
)
self
.
discard_request_indices
.
copy_to_gpu
(
self
.
num_discarded_requests
)
# Copy the tensors to the GPU.
# Copy the tensors to the GPU.
self
.
_prepare_input_ids
(
total_num_scheduled_tokens
,
cu_num_tokens
)
self
.
_prepare_input_ids
(
total_num_scheduled_tokens
,
cu_num_tokens
)
...
@@ -1973,23 +1992,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1973,23 +1992,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
:
if
envs
.
VLLM_COMPUTE_NANS_IN_LOGITS
:
num_nans_in_logits
=
self
.
_get_nans_in_logits
(
logits
)
num_nans_in_logits
=
self
.
_get_nans_in_logits
(
logits
)
# TODO(woosuk): The following loop can be slow since it iterates over
discard_sampled_tokens_req_indices
=
\
# the requests one by one. Optimize.
self
.
discard_request_indices
.
np
[:
self
.
num_discarded_requests
]
discard_sampled_tokens_req_indices
=
[]
for
i
in
discard_sampled_tokens_req_indices
:
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
gen
=
self
.
input_batch
.
generators
.
get
(
int
(
i
))
req_state
=
self
.
requests
[
req_id
]
if
gen
is
not
None
:
seq_len
=
(
req_state
.
num_computed_tokens
+
gen
.
set_offset
(
gen
.
get_offset
()
-
4
)
scheduler_output
.
num_scheduled_tokens
[
req_id
])
if
seq_len
<
req_state
.
num_tokens
:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator
=
self
.
input_batch
.
generators
.
get
(
i
)
if
generator
is
not
None
:
generator
.
set_offset
(
generator
.
get_offset
()
-
4
)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices
.
append
(
i
)
# Copy some objects so they don't get modified after returning.
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
# This is important when using async scheduling.
...
@@ -2026,10 +2034,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2026,10 +2034,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
# Mask out the sampled tokens that should not be sampled.
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
valid_sampled_token_ids
[
i
nt
(
i
)
].
clear
()
else
:
else
:
valid_sampled_token_ids
=
[]
valid_sampled_token_ids
=
[]
invalid_req_indices
=
list
(
discard_sampled_tokens_req_indices
)
invalid_req_indices
=
discard_sampled_tokens_req_indices
.
tolist
(
)
invalid_req_indices_set
=
set
(
invalid_req_indices
)
invalid_req_indices_set
=
set
(
invalid_req_indices
)
assert
sampled_token_ids
.
shape
[
-
1
]
==
1
assert
sampled_token_ids
.
shape
[
-
1
]
==
1
...
@@ -2229,6 +2237,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2229,6 +2237,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with
record_function_or_nullcontext
(
"Sample"
):
with
record_function_or_nullcontext
(
"Sample"
):
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
def
propose_draft_token_ids
(
sampled_token_ids
):
assert
spec_decode_common_attn_metadata
is
not
None
with
record_function_or_nullcontext
(
"Draft"
):
self
.
_draft_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
sampled_token_ids
,
self
.
input_batch
.
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
)
use_padded_batch_for_eagle
=
self
.
speculative_config
and
\
self
.
speculative_config
.
use_eagle
()
and
\
not
self
.
speculative_config
.
disable_padded_drafter_batch
if
use_padded_batch_for_eagle
:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids
(
sampler_output
.
sampled_token_ids
)
with
record_function_or_nullcontext
(
"Bookkeep"
):
with
record_function_or_nullcontext
(
"Bookkeep"
):
(
(
num_nans_in_logits
,
num_nans_in_logits
,
...
@@ -2242,19 +2272,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2242,19 +2272,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
,
hidden_states
,
logits
,
hidden_states
,
num_scheduled_tokens
)
num_scheduled_tokens
)
if
self
.
speculative_config
:
if
self
.
speculative_config
and
not
use_padded_batch_for_eagle
:
assert
spec_decode_common_attn_metadata
is
not
None
# ngram and other speculative decoding methods use the sampled
with
record_function_or_nullcontext
(
"Draft"
):
# tokens on the CPU, so they are run after bookkeeping.
self
.
_draft_token_ids
=
self
.
propose_draft_token_ids
(
propose_draft_token_ids
(
valid_sampled_token_ids
)
scheduler_output
,
valid_sampled_token_ids
,
self
.
input_batch
.
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
spec_decode_common_attn_metadata
,
)
with
record_function_or_nullcontext
(
"EPLB"
):
with
record_function_or_nullcontext
(
"EPLB"
):
self
.
eplb_step
()
self
.
eplb_step
()
...
@@ -2294,7 +2315,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2294,7 +2315,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
sampled_token_ids
:
list
[
list
[
int
]],
sampled_token_ids
:
Union
[
torch
.
Tensor
,
list
[
list
[
int
]]
]
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
...
@@ -2304,11 +2325,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2304,11 +2325,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
->
Union
[
list
[
list
[
int
]],
torch
.
Tensor
]:
)
->
Union
[
list
[
list
[
int
]],
torch
.
Tensor
]:
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
self
.
speculative_config
.
method
==
"ngram"
:
if
self
.
speculative_config
.
method
==
"ngram"
:
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
draft_token_ids
=
self
.
propose_ngram_draft_token_ids
(
draft_token_ids
=
self
.
propose_ngram_draft_token_ids
(
sampled_token_ids
)
sampled_token_ids
)
elif
self
.
speculative_config
.
method
==
"medusa"
:
elif
self
.
speculative_config
.
method
==
"medusa"
:
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
assert
isinstance
(
self
.
drafter
,
MedusaProposer
)
if
sample_hidden_states
.
shape
[
0
]
==
len
(
sampled_token_ids
):
if
sample_hidden_states
.
shape
[
0
]
==
len
(
sampled_token_ids
):
# The input to the target model does not include draft tokens.
# The input to the target model does not include draft tokens.
hidden_states
=
sample_hidden_states
hidden_states
=
sample_hidden_states
...
@@ -2329,27 +2353,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2329,27 +2353,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
)
elif
self
.
speculative_config
.
use_eagle
():
elif
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
req_ids
=
self
.
input_batch
.
req_ids
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
next_token_ids
:
list
[
int
]
=
[]
# When padded-batch is disabled, the sampled_token_ids should be
for
i
,
token_ids
in
enumerate
(
sampled_token_ids
):
# the cpu-side list[list[int]] of valid sampled tokens for each
if
token_ids
:
# request, with invalid requests having empty lists.
# Common case.
assert
isinstance
(
sampled_token_ids
,
list
),
\
next_token_id
=
token_ids
[
-
1
]
"sampled_token_ids should be a python list when"
\
"padded-batch is disabled."
next_token_ids
=
self
.
drafter
.
prepare_next_token_ids_cpu
(
sampled_token_ids
,
self
.
requests
,
self
.
input_batch
,
scheduler_output
.
num_scheduled_tokens
)
else
:
else
:
# Partial prefill (rare case).
# When using padded-batch, the sampled_token_ids should be
# Get the next token id from the request state.
# the gpu tensor of sampled tokens for each request, of shape
req_id
=
req_ids
[
i
]
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
req_state
=
self
.
requests
[
req_id
]
# value -1.
seq_len
=
(
req_state
.
num_computed_tokens
+
assert
isinstance
(
sampled_token_ids
,
torch
.
Tensor
),
\
scheduler_output
.
num_scheduled_tokens
[
req_id
])
"sampled_token_ids should be a torch.Tensor when"
\
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
"padded-batch is enabled."
next_token_ids
.
append
(
next_token_id
)
next_token_ids
,
valid_sampled_tokens_count
=
\
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
self
.
drafter
.
prepare_next_token_ids_padded
(
dtype
=
torch
.
int32
,
common_attn_metadata
,
device
=
self
.
device
)
sampled_token_ids
,
self
.
requests
,
self
.
input_batch
,
self
.
discard_request_indices
.
gpu
,
self
.
num_discarded_requests
)
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
token_indices_to_sample
=
None
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
.
gpu
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
.
gpu
[:
num_scheduled_tokens
]
# TODO(woosuk): Support M-RoPE.
# TODO(woosuk): Support M-RoPE.
...
@@ -2361,17 +2395,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2361,17 +2395,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else
:
else
:
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
target_hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
else
:
else
:
# TODO(woosuk): Refactor this.
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
num_draft_tokens
=
spec_decode_metadata
.
num_draft_tokens
token_indices_to_sample
=
None
num_rejected_tokens
=
[
n
+
1
-
len
(
sampled_token_ids
[
i
])
if
n
>
0
else
0
for
i
,
n
in
enumerate
(
num_draft_tokens
)
]
num_rejected_tokens_cpu
=
torch
.
tensor
(
num_rejected_tokens
,
dtype
=
torch
.
int32
)
common_attn_metadata
,
token_indices
=
\
common_attn_metadata
,
token_indices
=
\
self
.
drafter
.
prepare_inputs
(
self
.
drafter
.
prepare_inputs
(
common_attn_metadata
,
num_rejected_tokens_cpu
)
common_attn_metadata
,
sampled_token_ids
,
spec_decode_metadata
.
num_draft_tokens
)
else
:
common_attn_metadata
,
token_indices
,
\
token_indices_to_sample
=
\
self
.
drafter
.
prepare_inputs_padded
(
common_attn_metadata
,
spec_decode_metadata
,
valid_sampled_tokens_count
)
target_token_ids
=
self
.
input_ids
.
gpu
[
token_indices
]
target_token_ids
=
self
.
input_ids
.
gpu
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
# TODO(woosuk): Support M-RoPE.
...
@@ -2391,6 +2428,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2391,6 +2428,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_token_indices
=
token_indices_to_sample
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
mm_embeds
=
mm_embeds
,
mm_embeds
=
mm_embeds
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment