Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af3162d3
Unverified
Commit
af3162d3
authored
Feb 05, 2026
by
Benjamin Chislett
Committed by
GitHub
Feb 05, 2026
Browse files
[Spec Decode] Unified Parallel Drafting (#32887)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
5b2a9422
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1085 additions
and
392 deletions
+1085
-392
examples/offline_inference/spec_decode.py
examples/offline_inference/spec_decode.py
+3
-0
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+32
-63
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+408
-6
tests/v1/spec_decode/test_mtp.py
tests/v1/spec_decode/test_mtp.py
+1
-1
vllm/config/speculative.py
vllm/config/speculative.py
+7
-0
vllm/config/vllm.py
vllm/config/vllm.py
+18
-10
vllm/model_executor/models/llama_eagle3.py
vllm/model_executor/models/llama_eagle3.py
+39
-3
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+6
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+30
-6
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+0
-32
vllm/v1/spec_decode/draft_model.py
vllm/v1/spec_decode/draft_model.py
+23
-222
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+269
-47
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+248
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
examples/offline_inference/spec_decode.py
View file @
af3162d3
...
@@ -75,6 +75,7 @@ def parse_args():
...
@@ -75,6 +75,7 @@ def parse_args():
parser
.
add_argument
(
"--gpu-memory-utilization"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--gpu-memory-utilization"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--disable-padded-drafter-batch"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--disable-padded-drafter-batch"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--max-num-seqs"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--max-num-seqs"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--parallel-drafting"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--allowed-local-media-path"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--allowed-local-media-path"
,
type
=
str
,
default
=
""
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -121,6 +122,7 @@ def main(args):
...
@@ -121,6 +122,7 @@ def main(args):
"model"
:
eagle_dir
,
"model"
:
eagle_dir
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"disable_padded_drafter_batch"
:
args
.
disable_padded_drafter_batch
,
"disable_padded_drafter_batch"
:
args
.
disable_padded_drafter_batch
,
"parallel_drafting"
:
args
.
parallel_drafting
,
}
}
elif
args
.
method
==
"ngram"
:
elif
args
.
method
==
"ngram"
:
speculative_config
=
{
speculative_config
=
{
...
@@ -137,6 +139,7 @@ def main(args):
...
@@ -137,6 +139,7 @@ def main(args):
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
"enforce_eager"
:
args
.
enforce_eager
,
"enforce_eager"
:
args
.
enforce_eager
,
"max_model_len"
:
args
.
max_model_len
,
"max_model_len"
:
args
.
max_model_len
,
"parallel_drafting"
:
args
.
parallel_drafting
,
}
}
elif
args
.
method
==
"mtp"
:
elif
args
.
method
==
"mtp"
:
speculative_config
=
{
speculative_config
=
{
...
...
tests/v1/e2e/test_spec_decode.py
View file @
af3162d3
...
@@ -13,15 +13,12 @@ from vllm import LLM, SamplingParams
...
@@ -13,15 +13,12 @@ from vllm import LLM, SamplingParams
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.base
import
VLLM_S3_BUCKET_URL
from
vllm.assets.image
import
VLM_IMAGES_DIR
from
vllm.assets.image
import
VLM_IMAGES_DIR
from
vllm.benchmarks.datasets
import
InstructCoderDataset
from
vllm.benchmarks.datasets
import
InstructCoderDataset
from
vllm.config
.vllm
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.metrics.reader
import
Metric
from
vllm.v1.metrics.reader
import
Metric
from
vllm.v1.spec_decode.draft_model
import
(
from
vllm.v1.spec_decode.utils
import
create_vllm_config_for_draft_model
create_vllm_config_for_draft_model
,
merge_toks_kernel
,
)
MTP_SIMILARITY_RATE
=
0.8
MTP_SIMILARITY_RATE
=
0.8
...
@@ -625,6 +622,8 @@ class ArgsTest:
...
@@ -625,6 +622,8 @@ class ArgsTest:
expected_acceptance_rate
:
float
expected_acceptance_rate
:
float
expected_acceptance_len
:
float
expected_acceptance_len
:
float
# Defaults
# Defaults
enforce_eager
:
bool
=
True
parallel_drafting
:
bool
=
False
target_tensor_parallel_size
:
int
=
1
target_tensor_parallel_size
:
int
=
1
draft_tensor_parallel_size
:
int
=
1
draft_tensor_parallel_size
:
int
=
1
max_model_len
:
int
=
1024
max_model_len
:
int
=
1024
...
@@ -658,7 +657,8 @@ cases = [
...
@@ -658,7 +657,8 @@ cases = [
@
pytest
.
mark
.
parametrize
(
"args"
,
cases
)
@
pytest
.
mark
.
parametrize
(
"args"
,
cases
)
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
def
test_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
assert_draft_model_correctness
(
args
,
enforce_eager
)
args
.
enforce_eager
=
enforce_eager
assert_draft_model_correctness
(
args
)
def
test_draft_model_realistic_example
():
def
test_draft_model_realistic_example
():
...
@@ -668,11 +668,28 @@ def test_draft_model_realistic_example():
...
@@ -668,11 +668,28 @@ def test_draft_model_realistic_example():
dataset
=
"likaixin/InstructCoder"
,
dataset
=
"likaixin/InstructCoder"
,
num_speculative_tokens
=
3
,
num_speculative_tokens
=
3
,
sampling_config
=
greedy_sampling
(),
sampling_config
=
greedy_sampling
(),
enforce_eager
=
False
,
# values below are not derived, but just prevent a regression
# values below are not derived, but just prevent a regression
expected_acceptance_len
=
2.8
,
expected_acceptance_len
=
2.8
,
expected_acceptance_rate
=
0.55
,
expected_acceptance_rate
=
0.55
,
)
)
assert_draft_model_correctness
(
args
,
enforce_eager
=
False
)
assert_draft_model_correctness
(
args
)
def
test_draft_model_parallel_drafting
():
args
=
ArgsTest
(
target_model
=
"Qwen/Qwen3-1.7B"
,
draft_model
=
"amd/PARD-Qwen3-0.6B"
,
dataset
=
"likaixin/InstructCoder"
,
num_speculative_tokens
=
3
,
sampling_config
=
greedy_sampling
(),
parallel_drafting
=
True
,
enforce_eager
=
False
,
# values below are collected from a stable run, with ~5% tolerance
expected_acceptance_len
=
2.375
,
expected_acceptance_rate
=
0.45
,
)
assert_draft_model_correctness
(
args
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
...
@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
target_model
=
tgt_model
,
target_model
=
tgt_model
,
draft_model
=
draft_model
,
draft_model
=
draft_model
,
**
some_high_acceptance_metrics
(),
**
some_high_acceptance_metrics
(),
enforce_eager
=
enforce_eager
,
)
)
assert_draft_model_correctness
(
sd_case
,
enforce_eager
)
assert_draft_model_correctness
(
sd_case
)
def
test_draft_model_tensor_parallelism
():
def
test_draft_model_tensor_parallelism
():
...
@@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism():
...
@@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism():
draft_model
=
"Qwen/Qwen3-0.6B"
,
draft_model
=
"Qwen/Qwen3-0.6B"
,
draft_tensor_parallel_size
=
2
,
draft_tensor_parallel_size
=
2
,
**
some_high_acceptance_metrics
(),
**
some_high_acceptance_metrics
(),
enforce_eager
=
False
,
)
)
assert_draft_model_correctness
(
sd_case
,
enforce_eager
=
False
)
assert_draft_model_correctness
(
sd_case
)
def
test_draft_model_engine_args_tensor_parallelism
():
def
test_draft_model_engine_args_tensor_parallelism
():
...
@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
...
@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
engine_args
.
create_engine_config
()
engine_args
.
create_engine_config
()
def
assert_draft_model_correctness
(
args
:
ArgsTest
,
enforce_eager
:
bool
):
def
assert_draft_model_correctness
(
args
:
ArgsTest
):
"""Compare the outputs using and not using speculative decoding.
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts
:
list
[
Messages
]
=
get_messages
(
test_prompts
:
list
[
Messages
]
=
get_messages
(
...
@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
...
@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"method"
:
"draft_model"
,
"method"
:
"draft_model"
,
"num_speculative_tokens"
:
args
.
num_speculative_tokens
,
"num_speculative_tokens"
:
args
.
num_speculative_tokens
,
"max_model_len"
:
args
.
max_model_len
,
"max_model_len"
:
args
.
max_model_len
,
"enforce_eager"
:
enforce_eager
,
"enforce_eager"
:
args
.
enforce_eager
,
"draft_tensor_parallel_size"
:
args
.
draft_tensor_parallel_size
,
"draft_tensor_parallel_size"
:
args
.
draft_tensor_parallel_size
,
"parallel_drafting"
:
args
.
parallel_drafting
,
},
},
max_num_seqs
=
100
,
# limit cudagraph capture runtime
max_num_seqs
=
100
,
# limit cudagraph capture runtime
max_model_len
=
args
.
max_model_len
,
max_model_len
=
args
.
max_model_len
,
gpu_memory_utilization
=
args
.
gpu_memory_utilization
,
gpu_memory_utilization
=
args
.
gpu_memory_utilization
,
tensor_parallel_size
=
args
.
target_tensor_parallel_size
,
tensor_parallel_size
=
args
.
target_tensor_parallel_size
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
args
.
enforce_eager
,
disable_log_stats
=
False
,
# enables get_metrics()
disable_log_stats
=
False
,
# enables get_metrics()
)
)
# we don't check the outputs, only check the metrics
# we don't check the outputs, only check the metrics
...
@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict:
...
@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict:
}
}
def
test_merge_toks_kernel
():
device
=
"cuda"
merged_len
=
5
+
2
# len(target_toks) = 5, batch_size = 2
merged
=
torch
.
full
((
merged_len
,),
-
100
,
device
=
device
)
# -100 is arbitrary
is_rejected_tok
=
torch
.
full
((
merged_len
,),
True
,
device
=
device
)
grid
=
(
2
,)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
],
device
=
device
),
next_toks_ptr
=
torch
.
tensor
([
3
,
2
],
device
=
device
),
query_start_locs_ptr
=
torch
.
tensor
([
0
,
3
],
device
=
device
),
query_end_locs_ptr
=
torch
.
tensor
([
2
,
4
],
device
=
device
),
out_ptr_merged_toks
=
merged
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
5
,
rejected_tok_fill
=-
1
,
)
expected_merged
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
1
,
2
],
device
=
device
)
assert
torch
.
allclose
(
merged
,
expected_merged
)
expected_rejected_toks
=
torch
.
tensor
([
False
]
*
merged_len
,
device
=
device
)
assert
torch
.
allclose
(
is_rejected_tok
,
expected_rejected_toks
)
def
test_merge_toks_kernel_with_rejected_tokens
():
device
=
"cuda"
merged_size
=
9
+
2
# len(target_toks) = 9, batch_size = 2
merged
=
torch
.
full
((
merged_size
,),
-
100
,
device
=
device
)
is_rejected_tok
=
torch
.
full
((
merged_size
,),
True
,
device
=
device
)
grid
=
(
2
,)
merge_toks_kernel
[
grid
](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr
=
torch
.
tensor
([
0
,
1
,
2
,
13
,
14
,
15
,
0
,
1
,
22
],
device
=
device
),
next_toks_ptr
=
torch
.
tensor
([
3
,
2
],
device
=
device
),
query_start_locs_ptr
=
torch
.
tensor
([
0
,
6
],
device
=
device
),
query_end_locs_ptr
=
torch
.
tensor
([
2
,
7
],
device
=
device
),
out_ptr_merged_toks
=
merged
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
9
,
rejected_tok_fill
=-
1
,
)
expected_merged
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
-
1
,
-
1
,
-
1
,
0
,
1
,
2
,
-
1
],
device
=
device
)
assert
torch
.
allclose
(
merged
,
expected_merged
)
expected_rejected_toks
=
torch
.
tensor
(
[
False
,
False
,
False
,
False
,
True
,
True
,
True
,
False
,
False
,
False
,
True
],
device
=
device
,
)
assert
torch
.
allclose
(
is_rejected_tok
,
expected_rejected_toks
)
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
])
->
float
:
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
])
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
# type: ignore
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
# type: ignore
...
...
tests/v1/spec_decode/test_eagle.py
View file @
af3162d3
...
@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
...
@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
@@ -34,6 +35,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...
@@ -34,6 +35,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
eagle3_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir
=
"amd/PARD-Llama-3.2-1B"
# Compatible with parallel and AR drafting
def
_create_proposer
(
def
_create_proposer
(
...
@@ -41,11 +43,19 @@ def _create_proposer(
...
@@ -41,11 +43,19 @@ def _create_proposer(
num_speculative_tokens
:
int
,
num_speculative_tokens
:
int
,
attention_backend
:
str
|
None
=
None
,
attention_backend
:
str
|
None
=
None
,
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
parallel_drafting
:
bool
=
False
,
)
->
EagleProposer
:
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
# Choose model directory based on method
# Method-dependent setup
draft_model_dir
=
eagle_dir
if
method
==
"eagle"
else
eagle3_dir
if
method
==
"eagle"
:
draft_model_dir
=
eagle_dir
elif
method
==
"eagle3"
:
draft_model_dir
=
eagle3_dir
elif
method
==
"draft_model"
:
draft_model_dir
=
ar_draft_model_dir
else
:
raise
ValueError
(
f
"Unknown method:
{
method
}
"
)
spec_token_tree_str
=
None
spec_token_tree_str
=
None
if
speculative_token_tree
is
not
None
:
if
speculative_token_tree
is
not
None
:
...
@@ -59,13 +69,18 @@ def _create_proposer(
...
@@ -59,13 +69,18 @@ def _create_proposer(
method
=
method
,
method
=
method
,
num_speculative_tokens
=
num_speculative_tokens
,
num_speculative_tokens
=
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree_str
,
speculative_token_tree
=
spec_token_tree_str
,
parallel_drafting
=
parallel_drafting
,
)
)
if
parallel_drafting
:
# Overwrite pard_token to avoid crash during init
speculative_config
.
draft_model_config
.
hf_config
.
pard_token
=
0
device
=
current_platform
.
device_type
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
cache_config
=
CacheConfig
(),
speculative_config
=
speculative_config
,
speculative_config
=
speculative_config
,
device_config
=
DeviceConfig
(
device
=
current_platform
.
device_typ
e
),
device_config
=
DeviceConfig
(
device
=
devic
e
),
parallel_config
=
ParallelConfig
(),
parallel_config
=
ParallelConfig
(),
load_config
=
LoadConfig
(),
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
...
@@ -75,7 +90,10 @@ def _create_proposer(
...
@@ -75,7 +90,10 @@ def _create_proposer(
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
)
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
if
"eagle"
in
method
:
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
else
:
return
DraftModelProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
def
test_prepare_next_token_ids
():
def
test_prepare_next_token_ids
():
...
@@ -321,6 +339,390 @@ def test_prepare_inputs_padded():
...
@@ -321,6 +339,390 @@ def test_prepare_inputs_padded():
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
def
test_set_inputs_first_pass_default_eagle
():
"""
Test for set_inputs_first_pass without extra input slots (default EAGLE).
This tests the path where needs_extra_input_slots=False, which is the
default EAGLE pathway. In this case:
- Input IDs are rotated (shifted by one)
- The next_token_ids are inserted at the last position of each request
- Positions are copied as-is
- Hidden states are copied as-is
- The CommonAttentionMetadata is returned unchanged
Setup:
- 3 requests with query_lens [3, 2, 4]
- Tokens: [a1, a2, a3, b1, b2, c1, c2, c3, c4]
- After rotation: [a2, a3, -, b2, -, c2, c3, c4, -]
- After inserting next_tokens [100, 200, 300]:
[a2, a3, 100, b2, 200, c2, c3, c4, 300]
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
num_speculative_tokens
=
3
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
# Setup batch with 3 requests
batch_spec
=
BatchSpec
(
seq_lens
=
[
10
,
8
,
12
],
# Arbitrary context lengths
query_lens
=
[
3
,
2
,
4
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
16
,
device
=
device
,
)
# Input tensors
# Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
# Request 1: tokens [20, 21] at positions [6, 7]
# Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
target_token_ids
=
torch
.
tensor
(
[
10
,
11
,
12
,
20
,
21
,
30
,
31
,
32
,
33
],
dtype
=
torch
.
int32
,
device
=
device
)
target_positions
=
torch
.
tensor
(
[
7
,
8
,
9
,
6
,
7
,
8
,
9
,
10
,
11
],
dtype
=
torch
.
int64
,
device
=
device
)
target_hidden_states
=
torch
.
randn
(
9
,
proposer
.
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
)
next_token_ids
=
torch
.
tensor
([
100
,
200
,
300
],
dtype
=
torch
.
int32
,
device
=
device
)
num_tokens
,
token_indices_to_sample
,
output_cad
=
proposer
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
None
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
None
,
)
assert
num_tokens
==
9
# Total tokens unchanged
expected_token_indices_to_sample
=
torch
.
tensor
(
[
2
,
4
,
8
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
assert
output_cad
is
common_attn_metadata
# Verify input_ids are rotated and next_tokens inserted
# Original: [10, 11, 12, 20, 21, 30, 31, 32, 33]
# After shift by 1: [11, 12, -, 21, -, 31, 32, 33, -] ("-" = last slot of each request, overwritten next)
# After inserting at last indices [2, 4, 8]: [11, 12, 100, 21, 200, 31, 32, 33, 300]
expected_input_ids
=
torch
.
tensor
(
[
11
,
12
,
100
,
21
,
200
,
31
,
32
,
33
,
300
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
proposer
.
input_ids
[:
num_tokens
],
expected_input_ids
)
# Verify positions are copied as-is
assert
torch
.
equal
(
proposer
.
positions
[:
num_tokens
],
target_positions
)
# Verify hidden states are copied as-is
assert
torch
.
equal
(
proposer
.
hidden_states
[:
num_tokens
],
target_hidden_states
)
def
test_set_inputs_first_pass_draft_model
():
"""
Test for set_inputs_first_pass with a draft model (extra input slots,
no shift).
This tests the path where needs_extra_input_slots=True and
shift_input_ids=False (draft model case). In this case:
- Input IDs are NOT shifted
- Each request gets extra_slots_per_request (1) new slots
- The kernel handles copying tokens and inserting bonus/padding tokens
- A new CommonAttentionMetadata is returned with updated query_start_loc
Setup:
- 2 requests
- Request 0: tokens [10, 11, 12] at positions [0, 1, 2]
- Only tokens [10, 11] are "valid" (query_end_loc=1),
token 12 is a rejected token from previous speculation
- Request 1: tokens [20, 21] at positions [0, 1], both valid.
- Note: this request is shorter (2 tokens) than request 0 (3 tokens) to
ensure we handle variable lengths correctly.
- next_token_ids: [100, 200] (bonus tokens)
With extra_slots_per_request=1 and shift=False:
Expected output layout:
Request 0 (indices 0-3):
- idx 0: token 10, pos 0
- idx 1: token 11, pos 1
- idx 2: token 100, pos 2 (bonus token)
- idx 3: padding_token_id, is_rejected=True
Request 1 (indices 4-6):
- idx 4: token 20, pos 0
- idx 5: token 21, pos 1
- idx 6: token 200, pos 2 (bonus token)
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
num_speculative_tokens
=
2
block_size
=
16
# Create a proposer configured as a draft model (pass_hidden_states=False)
# _create_proposer returns a DraftModelProposer when method is "draft_model"
proposer
=
_create_proposer
(
"draft_model"
,
num_speculative_tokens
)
proposer
.
parallel_drafting_token_id
=
0
proposer
.
is_rejected_token_mask
=
torch
.
zeros
(
proposer
.
max_num_tokens
,
dtype
=
torch
.
bool
,
device
=
device
)
proposer
.
is_masked_token_mask
=
torch
.
zeros
(
proposer
.
max_num_tokens
,
dtype
=
torch
.
bool
,
device
=
device
)
# Mock the attn_metadata_builder to avoid needing the full model setup
mock_kv_cache_spec
=
mock
.
MagicMock
()
mock_kv_cache_spec
.
block_size
=
block_size
mock_builder
=
mock
.
MagicMock
()
mock_builder
.
kv_cache_spec
=
mock_kv_cache_spec
proposer
.
attn_metadata_builder
=
mock_builder
# Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
batch_spec
=
BatchSpec
(
seq_lens
=
[
3
,
2
],
query_lens
=
[
3
,
2
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
block_size
,
device
=
device
,
arange_block_indices
=
True
,
# Use predictable block indices
)
# Input tensors
target_token_ids
=
torch
.
tensor
(
[
10
,
11
,
12
,
20
,
21
],
dtype
=
torch
.
int32
,
device
=
device
)
target_positions
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
],
dtype
=
torch
.
int64
,
device
=
device
)
target_hidden_states
=
torch
.
randn
(
5
,
proposer
.
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
)
next_token_ids
=
torch
.
tensor
([
100
,
200
],
dtype
=
torch
.
int32
,
device
=
device
)
num_rejected_tokens_gpu
=
torch
.
tensor
([
1
,
0
],
dtype
=
torch
.
int32
,
device
=
device
)
num_tokens
,
token_indices_to_sample
,
output_cad
=
proposer
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
None
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
assert
proposer
.
net_num_new_slots_per_request
==
1
assert
proposer
.
needs_extra_input_slots
# total_output_tokens = total_input_tokens + net_num_new_slots * batch_size
assert
num_tokens
==
7
# Request 0: [10, 11, 100, padding_token (0)]
# Request 1: [20, 21, 200]
# Combined: [10, 11, 100, 0, 20, 21, 200]
expected_input_ids
=
torch
.
tensor
(
[
10
,
11
,
100
,
0
,
20
,
21
,
200
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
proposer
.
input_ids
[:
num_tokens
],
expected_input_ids
)
# Verify positions
# Request 0: [0, 1, 2, 0 (don't care)]
# Request 1: [0, 1, 2]
# Combined: [0, 1, 2, 0, 0, 1, 2]
expected_positions
=
torch
.
tensor
(
[
0
,
1
,
2
,
0
,
0
,
1
,
2
],
dtype
=
torch
.
int64
,
device
=
device
)
assert
torch
.
equal
(
proposer
.
positions
[:
num_tokens
],
expected_positions
,
)
# Verify rejection mask
expected_is_rejected
=
torch
.
zeros
(
7
,
dtype
=
torch
.
bool
,
device
=
device
)
expected_is_rejected
[
3
]
=
True
# padding token at index 3
assert
torch
.
equal
(
proposer
.
is_rejected_token_mask
[:
num_tokens
],
expected_is_rejected
)
# Verify masked token mask (should all be False for non-parallel drafting)
expected_is_masked
=
torch
.
zeros
(
7
,
dtype
=
torch
.
bool
,
device
=
device
)
assert
torch
.
equal
(
proposer
.
is_masked_token_mask
[:
num_tokens
],
expected_is_masked
)
# Verify token_indices_to_sample (bonus tokens at indices 2 and 6)
expected_token_indices_to_sample
=
torch
.
tensor
(
[
2
,
6
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
# Verify the new CAD has updated query_start_loc
# Original: [0, 3, 5] -> New: [0, 4, 7] (each request gains 1 slot)
expected_query_start_loc
=
torch
.
tensor
([
0
,
4
,
7
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
output_cad
.
query_start_loc
,
expected_query_start_loc
)
def
test_set_inputs_first_pass_parallel_drafting
():
"""
Test for set_inputs_first_pass with parallel drafting (extra input slots,
with shift).
This tests the path where needs_extra_input_slots=True and
shift_input_ids=True (parallel drafting case). In this case:
- Input IDs ARE shifted (like default EAGLE)
- Each request gets extra_slots_per_request (3) new slots
- Parallel drafting tokens are inserted and marked as masked
- Hidden states are mapped correctly
Setup:
- 2 requests with query_lens [4, 4] (1 bonus + 3 spec tokens each)
- Request 0: tokens [10, 11, 12, 13] at positions [5, 6, 7, 8]
- Only tokens [10, 11, 12] are "valid", token 13 is rejected
- Request 1: tokens [20, 21, 22, 23] at positions [10, 11, 12, 13], all valid.
- next_token_ids: [100, 200] (bonus tokens)
With shift_input_ids=True, extra_slots_per_request=3:
Expected output layout:
Request 0 (6 output slots = 4 - 1 + 3):
- idx 0-2: shifted tokens [11, 12, 100]
- idx 3-4: parallel_drafting_tokens, is_masked=True
- idx 5: padding_token, is_rejected=True
Request 1 (6 output slots = 4 - 1 + 3):
- idx 6-8: shifted tokens [21, 22, 23]
- idx 9: bonus token 200
- idx 10-11: parallel_drafting_tokens, is_masked=True
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
num_speculative_tokens
=
3
block_size
=
16
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
parallel_drafting
=
True
)
# Override to simulate parallel drafting behavior
proposer
.
parallel_drafting_token_id
=
-
2
proposer
.
parallel_drafting_hidden_state_tensor
=
torch
.
zeros
(
proposer
.
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
)
proposer
.
is_rejected_token_mask
=
torch
.
zeros
(
proposer
.
max_num_tokens
,
dtype
=
torch
.
bool
,
device
=
device
)
proposer
.
is_masked_token_mask
=
torch
.
zeros
(
proposer
.
max_num_tokens
,
dtype
=
torch
.
bool
,
device
=
device
)
# Mock the attn_metadata_builder
mock_kv_cache_spec
=
mock
.
MagicMock
()
mock_kv_cache_spec
.
block_size
=
block_size
mock_builder
=
mock
.
MagicMock
()
mock_builder
.
kv_cache_spec
=
mock_kv_cache_spec
proposer
.
attn_metadata_builder
=
mock_builder
# Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
batch_spec
=
BatchSpec
(
seq_lens
=
[
9
,
14
],
query_lens
=
[
4
,
4
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
block_size
,
device
=
device
,
arange_block_indices
=
True
,
)
# Input tensors
target_token_ids
=
torch
.
tensor
(
[
10
,
11
,
12
,
13
,
20
,
21
,
22
,
23
],
dtype
=
torch
.
int32
,
device
=
device
)
target_positions
=
torch
.
tensor
(
[
5
,
6
,
7
,
8
,
10
,
11
,
12
,
13
],
dtype
=
torch
.
int64
,
device
=
device
)
target_hidden_states
=
torch
.
arange
(
8
*
proposer
.
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
).
view
(
8
,
proposer
.
hidden_size
)
next_token_ids
=
torch
.
tensor
([
100
,
200
],
dtype
=
torch
.
int32
,
device
=
device
)
num_rejected_tokens_gpu
=
torch
.
tensor
([
1
,
0
],
dtype
=
torch
.
int32
,
device
=
device
)
num_tokens
,
token_indices_to_sample
,
output_cad
=
proposer
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
None
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
# total_output_tokens = total_input_tokens + net_num_new_slots * batch_size
# = 8 + 2 * 2 = 12
assert
num_tokens
==
12
# Request 0: [11, 12, 100, -2, -2, 0(padding)]
# Request 1: [21, 22, 23, 200, -2, -2]
expected_input_ids
=
torch
.
tensor
(
[
11
,
12
,
100
,
-
2
,
-
2
,
0
,
21
,
22
,
23
,
200
,
-
2
,
-
2
],
dtype
=
torch
.
int32
,
device
=
device
,
)
assert
torch
.
equal
(
proposer
.
input_ids
[:
num_tokens
],
expected_input_ids
)
# Verify positions
# Request 0: [5, 6, 7, 8, 9, 0 (don't care)]
# Request 1: [10, 11, 12, 13, 14, 15]
expected_positions
=
torch
.
tensor
(
[
5
,
6
,
7
,
8
,
9
,
0
,
10
,
11
,
12
,
13
,
14
,
15
],
dtype
=
torch
.
int64
,
device
=
device
)
assert
torch
.
equal
(
proposer
.
positions
[:
num_tokens
],
expected_positions
,
)
# Verify rejection mask
expected_is_rejected
=
torch
.
zeros
(
12
,
dtype
=
torch
.
bool
,
device
=
device
)
expected_is_rejected
[
5
]
=
True
assert
torch
.
equal
(
proposer
.
is_rejected_token_mask
[:
num_tokens
],
expected_is_rejected
)
# Verify masked token mask (parallel drafting slots should be masked)
expected_is_masked
=
torch
.
zeros
(
12
,
dtype
=
torch
.
bool
,
device
=
device
)
expected_is_masked
[
3
]
=
True
expected_is_masked
[
4
]
=
True
expected_is_masked
[
10
]
=
True
expected_is_masked
[
11
]
=
True
assert
torch
.
equal
(
proposer
.
is_masked_token_mask
[:
num_tokens
],
expected_is_masked
)
# Verify token_indices_to_sample (bonus + parallel drafting tokens)
# Request 0: bonus at 2, parallel at 3, 4
# Request 1: bonus at 9, parallel at 10, 11
expected_token_indices_to_sample
=
torch
.
tensor
(
[
2
,
3
,
4
,
9
,
10
,
11
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
# Verify the new CAD has updated query_start_loc
# Original query_lens: [4, 4] -> Output: [6, 6]
expected_query_start_loc
=
torch
.
tensor
(
[
0
,
6
,
12
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
output_cad
.
query_start_loc
,
expected_query_start_loc
)
# Verify masked positions have the parallel drafting hidden state (zeros)
parallel_drafting_hs
=
proposer
.
parallel_drafting_hidden_state_tensor
for
i
in
range
(
num_tokens
):
if
expected_is_masked
[
i
]:
assert
torch
.
equal
(
proposer
.
hidden_states
[
i
],
parallel_drafting_hs
),
(
f
"Masked position
{
i
}
should have parallel drafting hidden state"
)
@
pytest
.
mark
.
parametrize
(
"method"
,
[
"eagle"
,
"eagle3"
])
@
pytest
.
mark
.
parametrize
(
"method"
,
[
"eagle"
,
"eagle3"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
@
pytest
.
mark
.
parametrize
(
"pp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"pp_size"
,
[
1
,
2
])
...
@@ -579,7 +981,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -579,7 +981,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_
token_indices
=
None
,
token_indices
_to_sample
=
None
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
...
@@ -737,7 +1139,7 @@ def test_propose_tree(spec_token_tree):
...
@@ -737,7 +1139,7 @@ def test_propose_tree(spec_token_tree):
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_
token_indices
=
None
,
token_indices
_to_sample
=
None
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
...
...
tests/v1/spec_decode/test_mtp.py
View file @
af3162d3
...
@@ -204,7 +204,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
...
@@ -204,7 +204,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_
token_indices
=
None
,
token_indices
_to_sample
=
None
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
...
...
vllm/config/speculative.py
View file @
af3162d3
...
@@ -116,9 +116,16 @@ class SpeculativeConfig:
...
@@ -116,9 +116,16 @@ class SpeculativeConfig:
"""Minimum size of ngram token window when using Ngram proposer, if
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
provided. Defaults to 1."""
# Alternative drafting strategies
speculative_token_tree
:
str
|
None
=
None
speculative_token_tree
:
str
|
None
=
None
"""Specifies the tree structure for speculative token generation.
"""Specifies the tree structure for speculative token generation.
"""
"""
parallel_drafting
:
bool
=
False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine
# required configuration params passed from engine
target_model_config
:
SkipValidation
[
ModelConfig
]
=
None
# type: ignore
target_model_config
:
SkipValidation
[
ModelConfig
]
=
None
# type: ignore
"""The configuration of the target model."""
"""The configuration of the target model."""
...
...
vllm/config/vllm.py
View file @
af3162d3
...
@@ -604,10 +604,13 @@ class VllmConfig:
...
@@ -604,10 +604,13 @@ class VllmConfig:
# Currently, async scheduling only support eagle speculative
# Currently, async scheduling only support eagle speculative
# decoding.
# decoding.
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
):
if
(
self
.
speculative_config
.
method
not
in
get_args
(
EagleModelTypes
)
and
self
.
speculative_config
.
method
!=
"draft_model"
):
raise
ValueError
(
raise
ValueError
(
"Currently, async scheduling is only supported "
"Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding."
"with EAGLE/MTP
/Draft Model
kind of speculative decoding."
)
)
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1298,16 +1301,21 @@ class VllmConfig:
...
@@ -1298,16 +1301,21 @@ class VllmConfig:
computed_compile_ranges_split_points
=
[]
computed_compile_ranges_split_points
=
[]
# The upper bound of the compile ranges is the max_num_batched_tokens.
# The upper bound of the compile ranges is the max_num_batched_tokens.
# For speculative decoding with draft model, the compile range must be extended
# For speculative decoding, the compile range must be extended
# by 1 for each sequence.
# - Sequential: + 1 * max_num_seqs (one draft token per iteration)
# - Parallel draft: + num_speculative_tokens * max_num_seqs
compile_range_end
=
self
.
scheduler_config
.
max_num_batched_tokens
compile_range_end
=
self
.
scheduler_config
.
max_num_batched_tokens
if
compile_range_end
is
not
None
:
if
compile_range_end
is
not
None
:
do_extend
:
bool
=
(
if
self
.
speculative_config
is
not
None
and
(
self
.
speculative_config
is
not
None
self
.
speculative_config
.
uses_draft_model
()
and
self
.
speculative_config
.
uses_draft_model
()
or
self
.
speculative_config
.
use_eagle
()
):
multiplier
=
(
self
.
speculative_config
.
num_speculative_tokens
if
self
.
speculative_config
.
parallel_drafting
else
1
)
)
if
do_extend
:
compile_range_end
+=
multiplier
*
self
.
scheduler_config
.
max_num_seqs
compile_range_end
+=
self
.
scheduler_config
.
max_num_seqs
computed_compile_ranges_split_points
.
append
(
compile_range_end
)
computed_compile_ranges_split_points
.
append
(
compile_range_end
)
...
...
vllm/model_executor/models/llama_eagle3.py
View file @
af3162d3
...
@@ -52,13 +52,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
...
@@ -52,13 +52,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
# Subsequent layers use hidden_size (only hidden_states, no embeds)
# Subsequent layers use hidden_size (only hidden_states, no embeds)
qkv_input_size
=
2
*
self
.
hidden_size
if
layer_idx
==
0
else
self
.
hidden_size
qkv_input_size
=
2
*
self
.
hidden_size
if
layer_idx
==
0
else
self
.
hidden_size
# override qkv
# Parallel drafting checkpoints may have attention bias enabled
qkv_bias
=
getattr
(
config
,
"attention_bias"
,
False
)
# Override qkv_proj with correct input size and bias setting
self
.
self_attn
.
qkv_proj
=
QKVParallelLinear
(
self
.
self_attn
.
qkv_proj
=
QKVParallelLinear
(
qkv_input_size
,
qkv_input_size
,
self
.
self_attn
.
head_dim
,
self
.
self_attn
.
head_dim
,
self
.
self_attn
.
total_num_heads
,
self
.
self_attn
.
total_num_heads
,
self
.
self_attn
.
total_num_kv_heads
,
self
.
self_attn
.
total_num_kv_heads
,
bias
=
False
,
bias
=
qkv_bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"qkv_proj"
),
prefix
=
maybe_prefix
(
prefix
,
"qkv_proj"
),
)
)
...
@@ -293,6 +296,19 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
...
@@ -293,6 +296,19 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad
=
False
,
requires_grad
=
False
,
)
)
self
.
use_parallel_drafting
=
vllm_config
.
speculative_config
.
parallel_drafting
if
self
.
use_parallel_drafting
:
self
.
register_buffer
(
"mask_hidden"
,
torch
.
zeros
(
1
,
(
3
if
self
.
model
.
use_aux_hidden_state
else
1
)
*
self
.
config
.
hidden_size
,
),
persistent
=
False
,
)
def
embed_input_ids
(
def
embed_input_ids
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -347,12 +363,25 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
...
@@ -347,12 +363,25 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
model_weights
=
{}
model_weights
=
{}
includes_draft_id_mapping
=
False
includes_draft_id_mapping
=
False
includes_embed_tokens
=
False
includes_embed_tokens
=
False
includes_mask_hidden
=
False
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"t2d"
in
name
:
if
"t2d"
in
name
:
continue
continue
if
"d2t"
in
name
:
if
"d2t"
in
name
:
name
=
name
.
replace
(
"d2t"
,
"draft_id_to_target_id"
)
name
=
name
.
replace
(
"d2t"
,
"draft_id_to_target_id"
)
includes_draft_id_mapping
=
True
includes_draft_id_mapping
=
True
elif
"mask_hidden"
in
name
:
# Load mask_hidden directly into buffer
if
not
self
.
use_parallel_drafting
:
logger
.
warning
(
"mask_hidden found in weights but "
"model is not configured for parallel drafting. "
"Skipping loading mask_hidden."
)
continue
self
.
mask_hidden
.
copy_
(
loaded_weight
.
view
(
1
,
-
1
))
includes_mask_hidden
=
True
continue
elif
"lm_head"
not
in
name
:
elif
"lm_head"
not
in
name
:
name
=
"model."
+
name
name
=
"model."
+
name
if
"embed_tokens"
in
name
:
if
"embed_tokens"
in
name
:
...
@@ -360,7 +389,14 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
...
@@ -360,7 +389,14 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
model_weights
[
name
]
=
loaded_weight
model_weights
[
name
]
=
loaded_weight
process_eagle_weight
(
self
,
name
)
process_eagle_weight
(
self
,
name
)
skip_substrs
=
[]
if
not
includes_mask_hidden
and
self
.
use_parallel_drafting
:
raise
ValueError
(
"mask_hidden not found in weights but "
"model is configured for parallel drafting. "
"Please provide mask_hidden in the weights."
)
skip_substrs
=
[
"mask_hidden"
]
if
not
includes_draft_id_mapping
:
if
not
includes_draft_id_mapping
:
skip_substrs
.
append
(
"draft_id_to_target_id"
)
skip_substrs
.
append
(
"draft_id_to_target_id"
)
if
not
includes_embed_tokens
:
if
not
includes_embed_tokens
:
...
...
vllm/v1/attention/backend.py
View file @
af3162d3
...
@@ -480,9 +480,14 @@ class AttentionMetadataBuilder(ABC, Generic[M]):
...
@@ -480,9 +480,14 @@ class AttentionMetadataBuilder(ABC, Generic[M]):
speculative_config
is
not
None
speculative_config
is
not
None
and
speculative_config
.
num_speculative_tokens
is
not
None
and
speculative_config
.
num_speculative_tokens
is
not
None
):
):
max_num_queries_for_spec
=
(
1
+
(
2
if
speculative_config
.
parallel_drafting
else
1
)
*
speculative_config
.
num_speculative_tokens
)
self
.
reorder_batch_threshold
=
max
(
self
.
reorder_batch_threshold
=
max
(
self
.
reorder_batch_threshold
,
self
.
reorder_batch_threshold
,
1
+
speculative_config
.
num_speculative_tokens
,
max_num_queries_for_spec
,
)
)
if
(
if
(
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
af3162d3
...
@@ -60,7 +60,7 @@ from vllm.v1.attention.backends.utils import (
...
@@ -60,7 +60,7 @@ from vllm.v1.attention.backends.utils import (
)
)
from
vllm.v1.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.v1.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.v1.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.v1.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
UniformTypeKVCacheSpecs
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
...
@@ -658,12 +658,36 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -658,12 +658,36 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
kv_cache_spec
:
AttentionSpec
,
kv_cache_spec
:
AttentionSpec
,
)
->
AttentionCGSupport
:
)
->
AttentionCGSupport
:
has_trtllm_support
=
can_use_trtllm_attention
(
"""Get the cudagraph support level for FlashInfer attention.
num_qo_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
This depends on whether we can use TRTLLM attention for decodes, since we can
only do UNIFORM_SINGLE_TOKEN_DECODE if it is unavailable.
To check this, we must call can_use_trtllm_attention with the number of KV
heads from the kv_cache_spec. We check all available KV cache specs and
only return UNIFORM_BATCH if all of them support TRTLLM attention.
"""
# For UniformTypeKVCacheSpecs, check all contained specs
kv_specs
=
(
kv_cache_spec
.
kv_cache_specs
.
values
()
if
isinstance
(
kv_cache_spec
,
UniformTypeKVCacheSpecs
)
else
[
kv_cache_spec
]
)
num_qo_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
vllm_config
.
parallel_config
),
num_kv_heads
=
kv_cache_spec
.
num_kv_heads
,
)
)
has_trtllm_support
:
bool
=
len
(
kv_specs
)
>
0
for
spec
in
kv_specs
:
if
not
isinstance
(
spec
,
AttentionSpec
):
# FlashInfer only applies to attention, so we don't consider other types
# of KV spec (e.g. Mamba) here. This is mostly for type checking.
continue
if
not
can_use_trtllm_attention
(
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
spec
.
num_kv_heads
,
):
has_trtllm_support
=
False
break
if
has_trtllm_support
:
if
has_trtllm_support
:
return
AttentionCGSupport
.
UNIFORM_BATCH
return
AttentionCGSupport
.
UNIFORM_BATCH
else
:
else
:
...
...
vllm/v1/attention/backends/utils.py
View file @
af3162d3
...
@@ -825,38 +825,6 @@ def get_dcp_local_seq_lens(
...
@@ -825,38 +825,6 @@ def get_dcp_local_seq_lens(
return
dcp_local_seq_lens
.
squeeze
(
1
)
return
dcp_local_seq_lens
.
squeeze
(
1
)
def
extend_all_queries_by_1
(
common_attn_metadata
:
CommonAttentionMetadata
,
arange
:
torch
.
Tensor
,
new_slot_mapping
:
torch
.
Tensor
,
)
->
CommonAttentionMetadata
:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by 1.
Also all seq lens are increased by 1.
This is useful e.g. in speculative decoding with draft models, where we
extend each sequence by 1 token.
The slot mapping is computed externally, as it requires more information.
"""
cad
=
common_attn_metadata
# query start loc must be increased by [+0, +1, +2, ..., +batch_size]
new_query_start_loc
=
cad
.
query_start_loc
+
arange
[:
len
(
cad
.
query_start_loc
)]
new_query_start_loc_cpu
=
cad
.
query_start_loc_cpu
+
torch
.
arange
(
len
(
cad
.
query_start_loc_cpu
),
dtype
=
torch
.
int32
)
new_cad
=
cad
.
replace
(
query_start_loc
=
new_query_start_loc
,
query_start_loc_cpu
=
new_query_start_loc_cpu
,
seq_lens
=
cad
.
seq_lens
+
1
,
# each request is extended by 1 token -> batch_size tokens are added
num_actual_tokens
=
cad
.
num_actual_tokens
+
cad
.
batch_size
(),
# All query lens increase by 1, so max query len increases by 1
max_query_len
=
cad
.
max_query_len
+
1
,
max_seq_len
=
cad
.
max_seq_len
+
1
,
slot_mapping
=
new_slot_mapping
,
)
return
new_cad
def
mamba_get_block_table_tensor
(
def
mamba_get_block_table_tensor
(
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
...
...
vllm/v1/spec_decode/draft_model.py
View file @
af3162d3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
import
torch
import
torch.nn
as
nn
from
typing_extensions
import
override
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
,
replace
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.spec_decode.eagle
import
SpecDecodeBaseProposer
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.spec_decode.utils
import
create_vllm_config_for_draft_model
CommonAttentionMetadata
,
extend_all_queries_by_1
,
)
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
SpecDecodeBaseProposer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -31,37 +27,9 @@ class DraftModelProposer(SpecDecodeBaseProposer):
...
@@ -31,37 +27,9 @@ class DraftModelProposer(SpecDecodeBaseProposer):
pass_hidden_states_to_model
=
False
,
pass_hidden_states_to_model
=
False
,
runner
=
runner
,
runner
=
runner
,
)
)
self
.
_raise_if_multimodal
()
self
.
_raise_if_mrope
()
self
.
_raise_if_padded_drafter_batch_disabled
()
self
.
_raise_if_vocab_size_mismatch
()
self
.
_raise_if_vocab_size_mismatch
()
self
.
_raise_if_draft_tp_mismatch
()
self
.
_raise_if_draft_tp_mismatch
()
def
_block_size
(
self
)
->
int
:
builder
=
self
.
_get_attention_metadata_builder
()
return
builder
.
kv_cache_spec
.
block_size
def
_raise_if_multimodal
(
self
):
if
self
.
supports_mm_inputs
:
raise
NotImplementedError
(
"Speculative Decoding with draft models "
"does not support multimodal models yet"
)
def
_raise_if_mrope
(
self
):
if
self
.
draft_model_config
.
uses_mrope
:
raise
NotImplementedError
(
"Speculative Decoding with draft models does not support M-RoPE yet"
)
def
_raise_if_padded_drafter_batch_disabled
(
self
):
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
raise
NotImplementedError
(
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
" in the speculative_config."
)
def
_raise_if_vocab_size_mismatch
(
self
):
def
_raise_if_vocab_size_mismatch
(
self
):
self
.
speculative_config
.
verify_equal_vocab_size_if_draft_model
()
self
.
speculative_config
.
verify_equal_vocab_size_if_draft_model
()
...
@@ -82,193 +50,26 @@ class DraftModelProposer(SpecDecodeBaseProposer):
...
@@ -82,193 +50,26 @@ class DraftModelProposer(SpecDecodeBaseProposer):
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)
)
def
set_inputs_first_pass
(
@
override
self
,
def
_get_model
(
self
)
->
nn
.
Module
:
target_token_ids
:
torch
.
Tensor
,
# Draft models may be quantized or on different parallelism,
next_token_ids
:
torch
.
Tensor
,
# so we load them with a modified vllm config
target_positions
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
|
None
,
cad
:
CommonAttentionMetadata
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
batch_size
=
cad
.
batch_size
()
grid
=
(
batch_size
,)
start_locs
=
cad
.
query_start_loc
[:
-
1
]
end_locs
=
cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
end_locs
-=
num_rejected_tokens_gpu
num_tokens
=
target_token_ids
.
shape
[
0
]
+
batch_size
is_rejected_tok
=
torch
.
empty
(
(
num_tokens
,),
device
=
self
.
input_ids
.
device
,
dtype
=
torch
.
bool
)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
target_token_ids
,
next_toks_ptr
=
next_token_ids
,
query_start_locs_ptr
=
start_locs
,
query_end_locs_ptr
=
end_locs
,
out_ptr_merged_toks
=
self
.
input_ids
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
target_token_ids
.
shape
[
0
],
# passing a negative rejected_tok_fill value will raise an error
# when the value is used to index into embeddings.
# Therefore, we pass a valid integer, e.g. 0.
rejected_tok_fill
=
0
,
)
merge_toks_kernel
[
grid
](
target_toks_ptr
=
target_positions
,
next_toks_ptr
=
target_positions
[
end_locs
]
+
1
,
query_start_locs_ptr
=
start_locs
,
query_end_locs_ptr
=
end_locs
,
out_ptr_merged_toks
=
self
.
positions
,
out_ptr_is_rejected_tok
=
is_rejected_tok
,
target_toks_size
=
target_positions
.
shape
[
0
],
rejected_tok_fill
=
0
,
)
# recompute slot mapping
new_slot_mapping
=
compute_new_slot_mapping
(
cad
=
cad
,
new_positions
=
self
.
positions
[:
num_tokens
],
is_rejected_token_mask
=
is_rejected_tok
,
block_size
=
self
.
_block_size
(),
max_model_len
=
self
.
max_model_len
,
)
# update common_attn_metadata
new_cad
:
CommonAttentionMetadata
=
extend_all_queries_by_1
(
cad
,
arange
=
self
.
arange
,
new_slot_mapping
=
new_slot_mapping
,
)
new_last_token_indices
=
new_cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
new_last_token_indices
-=
num_rejected_tokens_gpu
return
num_tokens
,
new_last_token_indices
,
new_cad
def
load_model
(
self
,
target_model
:
Any
)
->
None
:
"""Takes target_model to satisfy the type checker."""
# This must be computed before loading the draft model
# because that mutates the forward_context of the vllm_config
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
)
from
vllm.compilation.backends
import
set_model_tag
from
vllm.compilation.backends
import
set_model_tag
draft_vllm_config
:
VllmConfig
=
create_vllm_config_for_draft_model
(
temp_vllm_config
=
create_vllm_config_for_draft_model
(
self
.
vllm_config
)
target_model_vllm_config
=
self
.
vllm_config
)
logger
.
info
(
"Starting to load draft model %s. TP=%d, rank=%d"
,
draft_vllm_config
.
model_config
.
model
,
draft_vllm_config
.
parallel_config
.
tensor_parallel_size
,
draft_vllm_config
.
parallel_config
.
rank
,
)
with
set_model_tag
(
"draft_model"
):
with
set_model_tag
(
"draft_model"
):
self
.
model
=
get_model
(
vllm_config
=
draft_vllm_config
,
prefix
=
"draft_model"
)
model
=
get_model
(
vllm_config
=
temp_vllm_config
,
# This must be computed after loading the draft model
prefix
=
"draft_model"
,
# because that mutates the forward_context of the vllm_config
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
).
keys
()
-
target_attn_layer_names
)
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
)
def
create_vllm_config_for_draft_model
(
target_model_vllm_config
:
VllmConfig
,
)
->
VllmConfig
:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the draft model.
The vllm_config is useful when loading the draft model with get_model().
"""
old
=
target_model_vllm_config
assert
old
.
speculative_config
is
not
None
,
"speculative_config is not set"
old_spec_config
=
old
.
speculative_config
new_parallel_config
=
replace
(
old_spec_config
.
draft_parallel_config
,
rank
=
old
.
parallel_config
.
rank
,
)
new
:
VllmConfig
=
replace
(
old
,
quant_config
=
None
,
# quant_config is recomputed in __init__()
model_config
=
old_spec_config
.
draft_model_config
,
parallel_config
=
new_parallel_config
,
)
return
new
def
compute_new_slot_mapping
(
cad
:
CommonAttentionMetadata
,
new_positions
:
torch
.
Tensor
,
is_rejected_token_mask
:
torch
.
Tensor
,
block_size
:
int
,
max_model_len
:
int
,
):
batch_size
,
n_blocks_per_req
=
cad
.
block_table_tensor
.
shape
req_indices
=
torch
.
arange
(
batch_size
,
device
=
cad
.
query_start_loc
.
device
)
req_indices
=
torch
.
repeat_interleave
(
req_indices
,
cad
.
naive_query_lens
()
+
1
,
output_size
=
len
(
new_positions
)
)
)
# Clamp the positions to prevent an out-of-bounds error when indexing
return
model
# into block_table_tensor.
clamped_positions
=
torch
.
clamp
(
new_positions
,
max
=
max_model_len
-
1
)
block_table_indices
=
(
req_indices
*
n_blocks_per_req
+
clamped_positions
//
block_size
)
block_nums
=
cad
.
block_table_tensor
.
view
(
-
1
)[
block_table_indices
]
block_offsets
=
clamped_positions
%
block_size
new_slot_mapping
=
block_nums
*
block_size
+
block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len
=
new_positions
>=
max_model_len
new_slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping
.
masked_fill_
(
is_rejected_token_mask
,
PADDING_SLOT_ID
)
return
new_slot_mapping
@
triton
.
jit
@
override
def
merge_toks_kernel
(
def
_maybe_share_embeddings
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
target_toks_ptr
,
# Draft models don't share embeddings with the target model
next_toks_ptr
,
pass
query_start_locs_ptr
,
query_end_locs_ptr
,
out_ptr_merged_toks
,
out_ptr_is_rejected_tok
,
target_toks_size
,
rejected_tok_fill
,
):
"""
Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor
called `out_ptr_merged_toks`. Rejected tokens are those after the
`query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the
rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask
of the rejected tokens in `out_ptr_is_rejected_tok`.
"""
pid
=
tl
.
program_id
(
0
)
start_loc
=
tl
.
load
(
query_start_locs_ptr
+
pid
)
is_last_program
=
pid
==
tl
.
num_programs
(
0
)
-
1
if
is_last_program
:
next_start_loc
=
target_toks_size
.
to
(
tl
.
int32
)
else
:
next_start_loc
=
tl
.
load
(
query_start_locs_ptr
+
pid
+
1
).
to
(
tl
.
int32
)
end_loc
=
tl
.
load
(
query_end_locs_ptr
+
pid
)
@
override
new_val
=
tl
.
load
(
next_toks_ptr
+
pid
)
def
_maybe_share_lm_head
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
for
i
in
range
(
start_loc
,
next_start_loc
+
1
):
# Draft models don't share lm_head with the target model
if
i
<=
end_loc
:
# copy existing tokens
pass
old_val
=
tl
.
load
(
target_toks_ptr
+
i
)
tl
.
store
(
out_ptr_merged_toks
+
pid
+
i
,
old_val
)
tl
.
store
(
out_ptr_is_rejected_tok
+
pid
+
i
,
False
)
elif
i
==
end_loc
+
1
:
# copy bonus token
tl
.
store
(
out_ptr_merged_toks
+
pid
+
i
,
new_val
)
tl
.
store
(
out_ptr_is_rejected_tok
+
pid
+
i
,
False
)
else
:
# fill rejected tokens
tl
.
store
(
out_ptr_merged_toks
+
pid
+
i
,
rejected_tok_fill
)
tl
.
store
(
out_ptr_is_rejected_tok
+
pid
+
i
,
True
)
vllm/v1/spec_decode/eagle.py
View file @
af3162d3
...
@@ -43,8 +43,12 @@ from vllm.v1.sample.metadata import SamplingMetadata
...
@@ -43,8 +43,12 @@ from vllm.v1.sample.metadata import SamplingMetadata
from
vllm.v1.sample.sampler
import
_SAMPLING_EPS
from
vllm.v1.sample.sampler
import
_SAMPLING_EPS
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.utils
import
(
from
vllm.v1.spec_decode.utils
import
(
PADDING_SLOT_ID
,
compute_new_slot_mapping
,
copy_and_expand_eagle_inputs_kernel
,
eagle_prepare_inputs_padded_kernel
,
eagle_prepare_inputs_padded_kernel
,
eagle_prepare_next_token_padded_kernel
,
eagle_prepare_next_token_padded_kernel
,
extend_all_queries_by_N
,
)
)
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
...
@@ -52,8 +56,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...
@@ -52,8 +56,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
PADDING_SLOT_ID
=
-
1
class
SpecDecodeBaseProposer
:
class
SpecDecodeBaseProposer
:
def
__init__
(
def
__init__
(
...
@@ -76,18 +78,35 @@ class SpecDecodeBaseProposer:
...
@@ -76,18 +78,35 @@ class SpecDecodeBaseProposer:
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# We need to get the hidden size from the draft model config because
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
# hidden size (e.g., Llama 3.3 70B).
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
# Unifying eagle, draft model, and parallel drafting support
self
.
parallel_drafting
:
bool
=
self
.
speculative_config
.
parallel_drafting
self
.
extra_slots_per_request
=
(
1
if
not
self
.
parallel_drafting
else
self
.
num_speculative_tokens
)
self
.
net_num_new_slots_per_request
=
self
.
extra_slots_per_request
-
(
1
if
self
.
pass_hidden_states_to_model
else
0
)
self
.
needs_extra_input_slots
=
self
.
net_num_new_slots_per_request
>
0
self
.
parallel_drafting_token_id
:
int
=
0
self
.
parallel_drafting_hidden_state_tensor
:
torch
.
Tensor
|
None
=
None
if
self
.
parallel_drafting
:
self
.
_init_parallel_drafting_params
()
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
(
self
.
net_num_new_slots_per_request
*
max_batch_size
)
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# Multi-modal data support
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
...
@@ -155,6 +174,26 @@ class SpecDecodeBaseProposer:
...
@@ -155,6 +174,26 @@ class SpecDecodeBaseProposer:
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
)
)
if
self
.
needs_extra_input_slots
:
self
.
_raise_if_padded_drafter_batch_disabled
()
self
.
_raise_if_multimodal
()
self
.
_raise_if_mrope
()
self
.
is_rejected_token_mask
:
torch
.
Tensor
|
None
=
None
self
.
is_masked_token_mask
:
torch
.
Tensor
|
None
=
None
if
self
.
needs_extra_input_slots
:
# For draft models and parallel drafting, we need to keep track of
# which tokens are rejected to update the slot mapping with padding slots.
self
.
is_rejected_token_mask
=
torch
.
zeros
(
(
self
.
max_num_tokens
,),
dtype
=
torch
.
bool
,
device
=
device
)
# For parallel drafting, we also need to keep track of which tokens
# are parallel-padding tokens used to sample at later positions.
# We populate this tensor even when using draft models for simplicity.
self
.
is_masked_token_mask
=
torch
.
zeros
(
(
self
.
max_num_tokens
,),
dtype
=
torch
.
bool
,
device
=
device
)
self
.
inputs_embeds
=
torch
.
zeros
(
self
.
inputs_embeds
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
inputs_embeds_size
),
(
self
.
max_num_tokens
,
self
.
inputs_embeds_size
),
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -231,6 +270,49 @@ class SpecDecodeBaseProposer:
...
@@ -231,6 +270,49 @@ class SpecDecodeBaseProposer:
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
).
repeat
(
max_batch_size
,
1
)
).
repeat
(
max_batch_size
,
1
)
def _raise_if_padded_drafter_batch_disabled(self):
    """Reject configurations that turn off the padded drafter batch.

    Raises:
        NotImplementedError: if
            ``self.speculative_config.disable_padded_drafter_batch`` is truthy.
    """
    padded_batch_disabled = self.speculative_config.disable_padded_drafter_batch
    if not padded_batch_disabled:
        return
    raise NotImplementedError(
        "Speculative Decoding with draft models or parallel drafting only "
        "supports padded drafter batch. Please unset "
        "disable_padded_drafter_batch in the speculative_config."
    )
def _raise_if_multimodal(self):
    """Reject models for which ``self.supports_mm_inputs`` is truthy.

    Raises:
        NotImplementedError: if ``self.supports_mm_inputs`` is truthy.
    """
    if not self.supports_mm_inputs:
        return
    raise NotImplementedError(
        "Speculative Decoding with draft models or parallel drafting "
        "does not support multimodal models yet"
    )
def _raise_if_mrope(self):
    """Reject draft models whose config reports M-RoPE usage.

    Raises:
        NotImplementedError: if ``self.draft_model_config.uses_mrope`` is
            truthy.
    """
    draft_uses_mrope = self.draft_model_config.uses_mrope
    if not draft_uses_mrope:
        return
    raise NotImplementedError(
        "Speculative Decoding with draft models or parallel drafting "
        "does not support M-RoPE yet"
    )
def _init_parallel_drafting_params(self):
    """Set up the per-model constants used to fill masked (parallel) slots.

    Reads the masked-slot token ID from the draft model's HF config, taking
    ``pard_token`` if that attribute exists and otherwise ``ptd_token_id``,
    and stores it in ``self.parallel_drafting_token_id``.

    When ``self.pass_hidden_states_to_model`` is truthy, this also allocates
    an uninitialized 1-D tensor of ``self.hidden_size`` elements in
    ``self.parallel_drafting_hidden_state_tensor``; it is only allocated
    here, not filled.

    Raises:
        ValueError: if the HF config has neither attribute.
    """
    hf_config = self.draft_model_config.hf_config

    # First matching attribute name wins; `pard_token` has priority.
    token_attr = next(
        (
            name
            for name in ("pard_token", "ptd_token_id")
            if hasattr(hf_config, name)
        ),
        None,
    )
    if token_attr is None:
        raise ValueError(
            "For parallel drafting, the draft model config must have "
            "`pard_token` or `ptd_token_id` specified in its config.json."
        )
    self.parallel_drafting_token_id = getattr(hf_config, token_attr)

    if not self.pass_hidden_states_to_model:
        return

    # Placeholder hidden state for the masked slots.
    self.parallel_drafting_hidden_state_tensor = torch.empty(
        self.hidden_size, dtype=self.dtype, device=self.device
    )
def
_get_positions
(
self
,
num_tokens
:
int
):
def
_get_positions
(
self
,
num_tokens
:
int
):
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
return
self
.
mrope_positions
[:,
:
num_tokens
]
return
self
.
mrope_positions
[:,
:
num_tokens
]
...
@@ -296,7 +378,7 @@ class SpecDecodeBaseProposer:
...
@@ -296,7 +378,7 @@ class SpecDecodeBaseProposer:
target_hidden_states
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
# [batch_size]
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
last_
token_indices
:
torch
.
Tensor
|
None
,
token_indices
_to_sample
:
torch
.
Tensor
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
...
@@ -314,12 +396,13 @@ class SpecDecodeBaseProposer:
...
@@ -314,12 +396,13 @@ class SpecDecodeBaseProposer:
)
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
num_tokens
,
last_
token_indices
,
common_attn_metadata
=
(
num_tokens
,
token_indices
_to_sample
,
common_attn_metadata
=
(
self
.
set_inputs_first_pass
(
self
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
last_token_indices
=
last_token_indices
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
token_indices_to_sample
,
cad
=
common_attn_metadata
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
)
...
@@ -366,11 +449,6 @@ class SpecDecodeBaseProposer:
...
@@ -366,11 +449,6 @@ class SpecDecodeBaseProposer:
if
num_tokens_across_dp
is
not
None
:
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
if
self
.
pass_hidden_states_to_model
:
# target_hidden_states and self.hidden_states can have different
# hidden dims. E.g. large target model and small draft model.
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
...
@@ -411,27 +489,27 @@ class SpecDecodeBaseProposer:
...
@@ -411,27 +489,27 @@ class SpecDecodeBaseProposer:
else
:
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_
token_indices
]
sample_hidden_states
=
last_hidden_states
[
token_indices
_to_sample
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
# Early exit if there is only one draft token to be generated.
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
if
self
.
num_speculative_tokens
==
1
or
self
.
parallel_drafting
:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
return
draft_token_ids
.
view
(
-
1
,
1
)
return
draft_token_ids
.
view
(
-
1
,
self
.
num_speculative_tokens
)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
last_
token_indices
]
positions
=
self
.
mrope_positions
[:,
token_indices
_to_sample
]
else
:
else
:
positions
=
self
.
positions
[
last_
token_indices
]
positions
=
self
.
positions
[
token_indices
_to_sample
]
if
self
.
method
in
(
if
self
.
method
in
(
"deepseek_mtp"
,
"deepseek_mtp"
,
"ernie_mtp"
,
"ernie_mtp"
,
"longcat_flash_mtp"
,
"longcat_flash_mtp"
,
"pangu_ultra_moe_mtp"
,
"pangu_ultra_moe_mtp"
,
):
):
hidden_states
=
self
.
hidden_states
[
last_
token_indices
]
hidden_states
=
self
.
hidden_states
[
token_indices
_to_sample
]
else
:
else
:
hidden_states
=
hidden_states
[
last_
token_indices
]
hidden_states
=
hidden_states
[
token_indices
_to_sample
]
if
isinstance
(
attn_metadata
,
TreeAttentionMetadata
):
if
isinstance
(
attn_metadata
,
TreeAttentionMetadata
):
# Draft using tree attention.
# Draft using tree attention.
...
@@ -624,12 +702,17 @@ class SpecDecodeBaseProposer:
...
@@ -624,12 +702,17 @@ class SpecDecodeBaseProposer:
target_token_ids
:
torch
.
Tensor
,
target_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
|
None
,
target_hidden_states
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
|
None
,
cad
:
CommonAttentionMetadata
,
cad
:
CommonAttentionMetadata
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
if
last_token_indices
is
None
:
if
not
self
.
needs_extra_input_slots
:
last_token_indices
=
cad
.
query_start_loc
[
1
:]
-
1
# Default EAGLE pathway: no reshaping of input tensors needed.
# Simply rotate the input ids and leave the positions unchanged,
# Inserting the next token ids at the last slot in each request.
if
token_indices_to_sample
is
None
:
token_indices_to_sample
=
cad
.
query_start_loc
[
1
:]
-
1
num_tokens
=
target_token_ids
.
shape
[
0
]
num_tokens
=
target_token_ids
.
shape
[
0
]
# Shift the input ids by one token.
# Shift the input ids by one token.
...
@@ -637,14 +720,121 @@ class SpecDecodeBaseProposer:
...
@@ -637,14 +720,121 @@ class SpecDecodeBaseProposer:
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_
token_indices
]
=
next_token_ids
self
.
input_ids
[
token_indices
_to_sample
]
=
next_token_ids
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
if
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
==
0
:
if
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
==
0
:
target_positions
=
target_positions
[
0
]
target_positions
=
target_positions
[
0
]
self
.
_set_positions
(
num_tokens
,
target_positions
)
self
.
_set_positions
(
num_tokens
,
target_positions
)
return
num_tokens
,
last_token_indices
,
cad
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
return
num_tokens
,
token_indices_to_sample
,
cad
else
:
assert
self
.
is_rejected_token_mask
is
not
None
assert
self
.
is_masked_token_mask
is
not
None
# 1.
# Call a custom triton kernel to copy input_ids and positions
# into the correct slots in the preallocated buffers self.input_ids,
# self.positions.
batch_size
=
cad
.
batch_size
()
# Since we might have to copy a lot of data for prefills, we select the
# block size based on the max query length and limit to max 256 slots/block.
max_num_tokens_per_request
=
(
cad
.
max_query_len
+
self
.
net_num_new_slots_per_request
)
BLOCK_SIZE_TOKENS
=
min
(
256
,
triton
.
next_power_of_2
(
max_num_tokens_per_request
)
)
num_blocks
=
(
max_num_tokens_per_request
+
BLOCK_SIZE_TOKENS
-
1
)
//
BLOCK_SIZE_TOKENS
total_num_input_tokens
=
target_token_ids
.
shape
[
0
]
total_num_output_tokens
=
total_num_input_tokens
+
(
self
.
net_num_new_slots_per_request
*
batch_size
)
token_indices_to_sample
=
torch
.
empty
(
batch_size
*
self
.
extra_slots_per_request
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# Destination indices to write target_hidden_states into drafting buffer.
out_hidden_state_mapping
=
torch
.
empty
(
total_num_input_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Kernel grid: one program per request (row)
grid
=
(
batch_size
,
num_blocks
)
query_start_loc
=
cad
.
query_start_loc
query_end_loc
=
cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
query_end_loc
=
query_end_loc
-
num_rejected_tokens_gpu
copy_and_expand_eagle_inputs_kernel
[
grid
](
# (Padded) Inputs from the target model
target_token_ids_ptr
=
target_token_ids
,
target_positions_ptr
=
target_positions
,
next_token_ids_ptr
=
next_token_ids
,
# sampled tokens, one per request
# Outputs to the drafting buffers
out_input_ids_ptr
=
self
.
input_ids
,
out_positions_ptr
=
self
.
positions
,
# Doesn't support mrope for now
out_is_rejected_token_mask_ptr
=
self
.
is_rejected_token_mask
,
out_is_masked_token_mask_ptr
=
self
.
is_masked_token_mask
,
out_new_token_indices_ptr
=
token_indices_to_sample
,
out_hidden_state_mapping_ptr
=
out_hidden_state_mapping
,
# Input metadata
query_start_loc_ptr
=
query_start_loc
,
query_end_loc_ptr
=
query_end_loc
,
padding_token_id
=
0
,
parallel_drafting_token_id
=
self
.
parallel_drafting_token_id
,
# Sizing info
# Note that we can deduce batch_size for free from the grid size
total_input_tokens
=
total_num_input_tokens
,
num_padding_slots_per_request
=
self
.
extra_slots_per_request
,
shift_input_ids
=
self
.
pass_hidden_states_to_model
,
BLOCK_SIZE_TOKENS
=
BLOCK_SIZE_TOKENS
,
)
if
self
.
pass_hidden_states_to_model
:
assert
self
.
parallel_drafting_hidden_state_tensor
is
not
None
self
.
hidden_states
[
out_hidden_state_mapping
]
=
target_hidden_states
# Use torch.where to avoid DtoH sync from boolean indexing
mask
=
self
.
is_masked_token_mask
[:
total_num_output_tokens
]
torch
.
where
(
mask
.
unsqueeze
(
1
),
self
.
parallel_drafting_hidden_state_tensor
,
self
.
hidden_states
[:
total_num_output_tokens
],
out
=
self
.
hidden_states
[:
total_num_output_tokens
],
)
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
builder
=
(
self
.
_get_attention_metadata_builder
()
if
self
.
attn_metadata_builder
is
None
else
self
.
attn_metadata_builder
)
new_slot_mapping
=
compute_new_slot_mapping
(
cad
=
cad
,
new_positions
=
self
.
positions
[:
total_num_output_tokens
],
is_rejected_token_mask
=
self
.
is_rejected_token_mask
[
:
total_num_output_tokens
],
block_size
=
builder
.
kv_cache_spec
.
block_size
,
num_new_tokens
=
self
.
net_num_new_slots_per_request
,
max_model_len
=
self
.
max_model_len
,
)
# 3. Update the common attention metadata with the new (meta)data
new_cad
=
extend_all_queries_by_N
(
cad
,
N
=
self
.
net_num_new_slots_per_request
,
arange
=
self
.
arange
,
new_slot_mapping
=
new_slot_mapping
,
)
return
total_num_output_tokens
,
token_indices_to_sample
,
new_cad
def
model_returns_tuple
(
self
)
->
bool
:
def
model_returns_tuple
(
self
)
->
bool
:
return
self
.
method
not
in
(
"mtp"
,
"draft_model"
)
return
self
.
method
not
in
(
"mtp"
,
"draft_model"
)
...
@@ -1081,8 +1271,21 @@ class SpecDecodeBaseProposer:
...
@@ -1081,8 +1271,21 @@ class SpecDecodeBaseProposer:
model
=
model
.
module
model
=
model
.
module
return
model
.
__class__
.
__name__
return
model
.
__class__
.
__name__
def _get_model(self) -> nn.Module:
    """
    Load the draft model through get_model() and return it.

    The load runs inside the ``set_model_tag("eagle_head")`` context, using
    ``self.vllm_config`` and the draft model config taken from
    ``self.speculative_config``. Subclasses that need custom model loading
    can override this method.
    """
    # Imported at call time, as in the original implementation.
    from vllm.compilation.backends import set_model_tag

    with set_model_tag("eagle_head"):
        return get_model(
            vllm_config=self.vllm_config,
            model_config=self.speculative_config.draft_model_config,
        )
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
draft_model_config
=
self
.
speculative_config
.
draft_model_config
target_attn_layer_names
=
set
(
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
self
.
vllm_config
,
...
@@ -1096,12 +1299,7 @@ class SpecDecodeBaseProposer:
...
@@ -1096,12 +1299,7 @@ class SpecDecodeBaseProposer:
).
keys
()
).
keys
()
)
)
from
vllm.compilation.backends
import
set_model_tag
self
.
model
=
self
.
_get_model
()
with
set_model_tag
(
"eagle_head"
):
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
draft_model_config
)
draft_attn_layer_names
=
(
draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
get_layers_from_vllm_config
(
...
@@ -1170,7 +1368,26 @@ class SpecDecodeBaseProposer:
...
@@ -1170,7 +1368,26 @@ class SpecDecodeBaseProposer:
else
:
else
:
target_language_model
=
target_model
target_language_model
=
target_model
# share embed_tokens with the target model if needed
self
.
_maybe_share_embeddings
(
target_language_model
)
self
.
_maybe_share_lm_head
(
target_language_model
)
if
self
.
parallel_drafting
and
self
.
pass_hidden_states_to_model
:
assert
self
.
parallel_drafting_hidden_state_tensor
is
not
None
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
self
.
model
.
combine_hidden_states
(
self
.
model
.
mask_hidden
.
view
(
3
*
self
.
hidden_size
)
)
if
self
.
eagle3_use_aux_hidden_state
else
self
.
model
.
mask_hidden
.
view
(
self
.
hidden_size
)
)
def
_maybe_share_embeddings
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
"""
Some draft models may not have their own embedding layers, and some may
have a duplicate copy of the target model's embedding layers. In these cases,
we share the target model's embedding layers with the draft model to save
memory.
"""
if
get_pp_group
().
world_size
==
1
:
if
get_pp_group
().
world_size
==
1
:
inner_model
=
getattr
(
target_language_model
,
"model"
,
None
)
inner_model
=
getattr
(
target_language_model
,
"model"
,
None
)
if
inner_model
is
None
:
if
inner_model
is
None
:
...
@@ -1233,7 +1450,12 @@ class SpecDecodeBaseProposer:
...
@@ -1233,7 +1450,12 @@ class SpecDecodeBaseProposer:
" from the target model."
" from the target model."
)
)
# share lm_head with the target model if needed
def
_maybe_share_lm_head
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
"""
Some draft models may not have their own LM head, and some may have a
duplicate copy of the target model's LM head. In these cases, we share
the target model's LM head with the draft model to save memory.
"""
share_lm_head
=
False
share_lm_head
=
False
if
hasattr
(
self
.
model
,
"has_own_lm_head"
):
if
hasattr
(
self
.
model
,
"has_own_lm_head"
):
# EAGLE model
# EAGLE model
...
...
vllm/v1/spec_decode/utils.py
View file @
af3162d3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.config
import
VllmConfig
,
replace
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
)
PADDING_SLOT_ID
=
-
1
@
triton
.
jit
@
triton
.
jit
...
@@ -107,3 +115,243 @@ def eagle_prepare_next_token_padded_kernel(
...
@@ -107,3 +115,243 @@ def eagle_prepare_next_token_padded_kernel(
tl
.
store
(
next_token_ids_ptr
+
req_idx
,
backup_token
)
tl
.
store
(
next_token_ids_ptr
+
req_idx
,
backup_token
)
tl
.
store
(
valid_sampled_tokens_count_ptr
+
req_idx
,
valid_count
)
tl
.
store
(
valid_sampled_tokens_count_ptr
+
req_idx
,
valid_count
)
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    num_new_tokens: int,
    max_model_len: int,
):
    """Compute the KV-cache slot for every token of the extended batch.

    Each request's query is treated as ``num_new_tokens`` longer than in
    ``cad``. Every token position is translated to
    ``block_number * block_size + offset`` through ``cad.block_table_tensor``.

    Args:
        cad: Attention metadata for the un-extended batch; supplies
            ``block_table_tensor``, ``query_start_loc`` (device only) and
            ``naive_query_lens()``.
        new_positions: Position of every token in the extended batch.
        is_rejected_token_mask: Mask of tokens that must not be written to
            the KV cache (presumably a bool tensor aligned with
            ``new_positions`` - TODO confirm against callers).
        block_size: Number of slots per KV-cache block.
        num_new_tokens: Number of tokens added to every request's query.
        max_model_len: Positions at or beyond this value get the padding slot.

    Returns:
        The slot mapping tensor, with ``PADDING_SLOT_ID`` for rejected tokens
        and for positions ``>= max_model_len``.
    """
    num_reqs, blocks_per_req = cad.block_table_tensor.shape

    # Owning request index for each token of the extended batch.
    extended_query_lens = cad.naive_query_lens() + num_new_tokens
    token_to_req = torch.repeat_interleave(
        torch.arange(num_reqs, device=cad.query_start_loc.device),
        extended_query_lens,
        output_size=len(new_positions),
    )

    # Clamp so that the block-table lookup below never indexes out of bounds;
    # the over-long positions are masked out afterwards.
    safe_positions = torch.clamp(new_positions, max=max_model_len - 1)

    flat_block_idx = token_to_req * blocks_per_req + safe_positions // block_size
    physical_blocks = cad.block_table_tensor.view(-1)[flat_block_idx]
    new_slot_mapping = physical_blocks * block_size + safe_positions % block_size

    # Positions past the max model length, then rejected tokens, are sent to
    # the padding slot so that nothing is saved to the KV cache for them.
    too_long = new_positions >= max_model_len
    new_slot_mapping.masked_fill_(too_long, PADDING_SLOT_ID).masked_fill_(
        is_rejected_token_mask, PADDING_SLOT_ID
    )
    return new_slot_mapping
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """Derive a VllmConfig for the drafter from the target model's config.

    The incoming config describes the target model (e.g. its quant_config and
    parallel_config), while the draft model may be quantized differently and
    may use a different tensor_parallel_size. The returned copy therefore:

    * uses the speculative config's ``draft_parallel_config``, with ``rank``
      taken from the target's parallel config,
    * uses the speculative config's ``draft_model_config`` as ``model_config``,
    * has ``quant_config`` reset to ``None``.

    The result is meant to be passed to get_model() when loading the drafter.
    """
    spec_config = target_model_vllm_config.speculative_config
    assert spec_config is not None, "speculative_config is not set"

    draft_parallel_config = replace(
        spec_config.draft_parallel_config,
        rank=target_model_vllm_config.parallel_config.rank,
    )
    return replace(
        target_model_vllm_config,
        quant_config=None,
        parallel_config=draft_parallel_config,
        model_config=spec_config.draft_model_config,
    )
def extend_all_queries_by_N(
    common_attn_metadata: CommonAttentionMetadata,
    N: int,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Return a copy of the metadata in which every query is N tokens longer.

    Sequence lengths grow by N as well. Used e.g. for speculative decoding
    with parallel drafting, where each sequence is extended by N tokens and
    all of them are predicted in a single pass.

    Args:
        common_attn_metadata: Metadata of the un-extended batch.
        N: Number of tokens appended to every request.
        arange: Pre-built ``[0, 1, 2, ...]`` tensor; it is sliced to the
            length of ``query_start_loc`` and added to it, so it presumably
            lives on the same device and has at least that many entries -
            TODO confirm against callers.
        new_slot_mapping: Slot mapping for the extended batch. It is computed
            by the caller because it needs more information than is
            available here.
    """
    cad = common_attn_metadata

    # Request i starts i * N tokens later than before:
    # offsets are [+0, +N, +2N, ..., +batch_size * N].
    device_offsets = N * arange[: len(cad.query_start_loc)]
    cpu_offsets = N * torch.arange(len(cad.query_start_loc_cpu), dtype=torch.int32)

    return cad.replace(
        query_start_loc=cad.query_start_loc + device_offsets,
        query_start_loc_cpu=cad.query_start_loc_cpu + cpu_offsets,
        seq_lens=cad.seq_lens + N,
        # N extra tokens per request -> batch_size * N extra tokens overall.
        num_actual_tokens=cad.num_actual_tokens + cad.batch_size() * N,
        # Every query grows by N, hence so do the maxima.
        max_query_len=cad.max_query_len + N,
        max_seq_len=cad.max_seq_len + N,
        slot_mapping=new_slot_mapping,
    )
# Unified copy/expand kernel
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
    # (Padded) Inputs from the target model
    target_token_ids_ptr,  # [total_tokens_in_batch]
    target_positions_ptr,  # [total_tokens_in_batch]
    next_token_ids_ptr,  # [num_reqs]
    # Outputs to the drafting buffers
    out_input_ids_ptr,  # [total_draft_tokens_in_batch] (output)
    out_positions_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_rejected_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_masked_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_new_token_indices_ptr,  # [num_padding_slots_per_request * num_reqs] (output)
    out_hidden_state_mapping_ptr,  # [total_tokens_in_batch]
    # Input metadata
    query_start_loc_ptr,  # [num_reqs + 1], last value is the total num input tokens
    query_end_loc_ptr,  # [num_reqs]
    padding_token_id,  # tl.int32
    parallel_drafting_token_id,  # tl.int32
    # Sizing info
    total_input_tokens,  # tl.int32
    num_padding_slots_per_request,  # tl.int32
    shift_input_ids,  # tl.bool
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Blocks along token dim to handle prefills
):
    """
    Copy and expand inputs from the target model to the drafting buffers for Eagle
    speculative decoding. This kernel handles padding slots and parallel drafting
    tokens, if enabled.

    Grid layout: program axis 0 selects the request, program axis 1 selects a
    chunk of BLOCK_SIZE_TOKENS consecutive output slots of that request.

    Per-request output layout, indexed by the local slot index j:
      * [0, num_valid_tokens): token ids copied from target_token_ids
        (starting one token later when shift_input_ids is set),
      * j == num_valid_tokens: the request's entry of next_token_ids,
      * (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
        parallel_drafting_token_id, flagged in the masked-token mask,
      * remaining slots up to total_output_tokens: padding_token_id with
        position 0, flagged in the rejected-token mask.

    The output indices of the bonus slot and the parallel drafting slots are
    written to out_new_token_indices (num_padding_slots_per_request entries per
    request). out_hidden_state_mapping is only written when shift_input_ids is
    set; it maps each input token index of the request to an output index.
    """
    request_idx = tl.program_id(axis=0)
    token_batch_idx = tl.program_id(axis=1)

    # Load query locations
    query_start_loc = tl.load(query_start_loc_ptr + request_idx)
    next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1)
    query_end_loc = tl.load(query_end_loc_ptr + request_idx)

    # Calculate number of valid tokens to copy and input offset
    # With shift_input_ids=True, we skip the first token
    # Output layout: each request gets (input_len + num_padding_slots_per_request) slots
    # But with shift, we lose one token per request
    if shift_input_ids:
        num_valid_tokens = query_end_loc - query_start_loc
        input_offset = 1
        # Each preceding request contributed (num_padding_slots_per_request - 1)
        # net extra slots, hence the per-request output shift.
        output_start = query_start_loc + request_idx * (
            num_padding_slots_per_request - 1
        )
    else:
        num_valid_tokens = query_end_loc - query_start_loc + 1
        input_offset = 0
        output_start = query_start_loc + request_idx * num_padding_slots_per_request

    # Number of rejected tokens from previous speculation
    # (input tokens of this request located after query_end_loc)
    num_rejected = next_query_start_loc - query_end_loc - 1

    # Total output tokens for this request
    total_output_tokens = (
        num_valid_tokens + num_padding_slots_per_request + num_rejected
    )

    # Process tokens in this block
    # j is the vector of local output slot indices handled by this program.
    j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)

    # Compute masks for different output regions:
    # [0, num_valid_tokens): valid tokens copied from input
    # [num_valid_tokens]: bonus token from next_token_ids
    # (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
    #   parallel drafting slots
    # [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
    #   rejected slots
    in_bounds = j < total_output_tokens
    is_valid_region = j < num_valid_tokens
    is_bonus_region = j == num_valid_tokens
    is_parallel_draft_region = (j > num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    # Not bounded above here; every use below is combined with in_bounds or
    # feeds a store that is masked by in_bounds.
    is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request

    # Compute output indices
    out_idx = output_start + j

    # For valid tokens, compute input index
    in_idx = query_start_loc + input_offset + j
    # Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
    in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1)

    # Load input tokens (masked to valid region)
    token_ids = tl.load(
        target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0
    )

    # Load the starting position for this request (first position in the sequence)
    start_pos = tl.load(target_positions_ptr + query_start_loc)

    # Load bonus token for this request
    bonus_token = tl.load(next_token_ids_ptr + request_idx)

    # Build final token_ids based on region
    # The three region masks are mutually exclusive, so the order of these
    # selects does not matter.
    token_ids = tl.where(is_bonus_region, bonus_token, token_ids)
    token_ids = tl.where(
        is_parallel_draft_region, parallel_drafting_token_id, token_ids
    )
    token_ids = tl.where(is_rejected_region, padding_token_id, token_ids)

    # Build final positions:
    # Positions are NOT shifted - they start from the first input position and increment
    # Output position j gets start_pos + j
    # (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
    positions = start_pos + j
    # Rejected positions are don't-care, set to 0
    positions = tl.where(is_rejected_region, 0, positions)

    # Compute output masks
    is_rejected_out = is_rejected_region & in_bounds
    is_masked_out = is_parallel_draft_region & in_bounds

    # Compute indices of new tokens (bonus + parallel drafting) for sampling
    # New tokens are at positions
    # [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
    is_new_token_region = (j >= num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    # Negative for j < num_valid_tokens; those lanes are excluded by the
    # is_new_token_region mask on the store at the end of the kernel.
    new_token_local_idx = (
        j - num_valid_tokens
    )  # 0 for bonus, 1, 2, ... for parallel drafting
    new_token_out_idx = (
        request_idx * num_padding_slots_per_request + new_token_local_idx
    )

    # Compute hidden state mapping (source index -> destination index)
    # This maps each input position to its corresponding output position
    # Hidden states don't get shifted, so we map all input tokens (including rejected)
    if shift_input_ids:
        num_input_tokens_this_request = next_query_start_loc - query_start_loc
        is_input_region = j < num_input_tokens_this_request
        src_idx = query_start_loc + j
        tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region)

    # Store outputs
    tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds)
    tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds)
    tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds)
    tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds)
    tl.store(
        out_new_token_indices_ptr + new_token_out_idx,
        out_idx,
        mask=is_new_token_region & in_bounds,
    )
vllm/v1/worker/gpu_model_runner.py
View file @
af3162d3
...
@@ -4090,7 +4090,7 @@ class GPUModelRunner(
...
@@ -4090,7 +4090,7 @@ class GPUModelRunner(
target_positions
=
target_positions
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
next_token_ids
=
next_token_ids
,
last_
token_indices
=
token_indices_to_sample
,
token_indices
_to_sample
=
token_indices_to_sample
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
common_attn_metadata
=
common_attn_metadata
,
common_attn_metadata
=
common_attn_metadata
,
mm_embed_inputs
=
mm_embed_inputs
,
mm_embed_inputs
=
mm_embed_inputs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment