Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
myrfy001
vllm_dsv4
Commits
494636b2
Unverified
Commit
494636b2
authored
Mar 30, 2026
by
Benjamin Chislett
Committed by
GitHub
Mar 30, 2026
Browse files
[Feat][Spec Decode] DFlash (#36847)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
ab1a6a43
Changes
17
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1577 additions
and
107 deletions
+1577
-107
tests/models/registry.py
tests/models/registry.py
+8
-0
tests/v1/e2e/spec_decode/test_spec_decode.py
tests/v1/e2e/spec_decode/test_spec_decode.py
+166
-6
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+152
-3
vllm/config/speculative.py
vllm/config/speculative.py
+21
-6
vllm/config/vllm.py
vllm/config/vllm.py
+20
-0
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+1
-0
vllm/model_executor/models/qwen3_dflash.py
vllm/model_executor/models/qwen3_dflash.py
+619
-0
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+6
-8
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+12
-1
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+14
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-0
vllm/v1/attention/selector.py
vllm/v1/attention/selector.py
+9
-1
vllm/v1/spec_decode/dflash.py
vllm/v1/spec_decode/dflash.py
+282
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+119
-72
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+108
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+35
-10
No files found.
tests/models/registry.py
View file @
494636b2
...
...
@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
# "JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"
# ),
# [DFlash]
"DFlashDraftModel"
:
_HfExamplesInfo
(
"Qwen/Qwen3.5-4B"
,
speculative_model
=
"z-lab/Qwen3.5-4B-DFlash"
,
use_original_num_layers
=
True
,
# Need all layers since DFlash has >1 layer,
max_model_len
=
8192
,
# Reduce max len to ensure test runs in low-VRAM CI env
max_num_seqs
=
32
,
),
# [Eagle]
"EagleDeepSeekMTPModel"
:
_HfExamplesInfo
(
"eagle618/deepseek-v3-random"
,
...
...
tests/v1/e2e/spec_decode/test_spec_decode.py
View file @
494636b2
...
...
@@ -7,6 +7,7 @@ from typing import Any
import
pytest
import
torch
from
tqdm
import
tqdm
from
tests.evals.gsm8k.gsm8k_eval
import
_build_gsm8k_prompts
,
evaluate_gsm8k_offline
from
tests.utils
import
(
...
...
@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
}
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
])
->
float
:
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
],
prev_metrics
:
list
[
Metric
]
|
None
=
None
)
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
# type: ignore
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
if
n_draft_toks
==
0
:
return
float
(
"nan"
)
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
prev_metrics
is
not
None
:
prev_name2metric
=
{
metric
.
name
:
metric
for
metric
in
prev_metrics
}
n_draft_toks
-=
prev_name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
n_accepted_toks
-=
prev_name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
n_draft_toks
<=
0
:
return
float
(
"nan"
)
return
n_accepted_toks
/
n_draft_toks
def
compute_acceptance_len
(
metrics
:
list
[
Metric
])
->
float
:
def
compute_acceptance_len
(
metrics
:
list
[
Metric
],
prev_metrics
:
list
[
Metric
]
|
None
=
None
)
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_drafts
=
name2metric
[
"vllm:spec_decode_num_drafts"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
n_drafts
=
name2metric
[
"vllm:spec_decode_num_drafts"
].
value
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
n_drafts
==
0
:
return
1
if
prev_metrics
is
not
None
:
prev_name2metric
=
{
metric
.
name
:
metric
for
metric
in
prev_metrics
}
n_drafts
-=
prev_name2metric
[
"vllm:spec_decode_num_drafts"
].
value
n_accepted_toks
-=
prev_name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
n_drafts
<=
0
:
return
1
return
1
+
(
n_accepted_toks
/
n_drafts
)
# Datasets in the format used in DFlash validations
def
load_and_process_dataset
(
data_name
:
str
):
from
datasets
import
load_dataset
if
data_name
==
"gsm8k"
:
dataset
=
load_dataset
(
"openai/gsm8k"
,
"main"
,
split
=
"test"
)
prompt_fmt
=
(
"{question}
\n
Please reason step by step,"
" and put your final answer within
\\
boxed{{}}."
)
dataset
=
dataset
.
map
(
lambda
x
:
{
"turns"
:
[
prompt_fmt
.
format
(
**
x
)]})
elif
data_name
==
"mt-bench"
:
dataset
=
load_dataset
(
"HuggingFaceH4/mt_bench_prompts"
,
split
=
"train"
)
dataset
=
dataset
.
map
(
lambda
x
:
{
"turns"
:
x
[
"prompt"
]})
elif
data_name
==
"humaneval"
:
dataset
=
load_dataset
(
"openai/openai_humaneval"
,
split
=
"test"
)
prompt_fmt
=
(
"Write a solution to the following problem and make sure"
" that it passes the tests:
\n
```python
\n
{prompt}
\n
```"
)
dataset
=
dataset
.
map
(
lambda
x
:
{
"turns"
:
[
prompt_fmt
.
format
(
**
x
)]})
return
dataset
@
pytest
.
fixture
def
dflash_config
():
target_model
=
"Qwen/Qwen3-8B"
draft_model
=
"z-lab/Qwen3-8B-DFlash-b16"
return
dict
(
model
=
target_model
,
trust_remote_code
=
True
,
speculative_config
=
{
"method"
:
"dflash"
,
"model"
:
draft_model
,
"num_speculative_tokens"
:
16
,
"max_model_len"
:
32768
,
},
max_model_len
=
32768
,
max_num_seqs
=
128
,
gpu_memory_utilization
=
0.85
,
enforce_eager
=
False
,
disable_log_stats
=
False
,
)
def
test_dflash_acceptance_rates
(
dflash_config
):
"""
E2E test for DFlash (block diffusion) speculative decoding.
Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
comparing against baseline results from the paper (Table 1).
See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
"""
spec_llm
=
LLM
(
**
dflash_config
)
max_prompts_per_dataset
=
200
# mt-bench has 80, humaneval has 164, truncates gsm8k
# All scores from Table 1 in https://arxiv.org/pdf/2602.06036
expected_acceptance_lengths
=
{
"mt-bench"
:
4.24
,
"humaneval"
:
6.50
,
"gsm8k"
:
6.54
*
0.95
,
# runs with a subset of prompts so extra wide tol here
}
tokenizer
=
spec_llm
.
get_tokenizer
()
for
dataset_name
,
expected_len
in
expected_acceptance_lengths
.
items
():
dataset
=
load_and_process_dataset
(
dataset_name
)
prev_metrics
=
None
acceptance_lengths
=
[]
for
i
in
tqdm
(
range
(
min
(
max_prompts_per_dataset
,
len
(
dataset
))),
desc
=
f
"Processing
{
dataset_name
}
"
,
):
user_content
=
dataset
[
i
][
"turns"
][
0
]
prompt_text
=
tokenizer
.
apply_chat_template
(
[{
"role"
:
"user"
,
"content"
:
user_content
}],
tokenize
=
False
,
add_generation_prompt
=
True
,
enable_thinking
=
False
,
)
# Temp=0, MaxTokens=2048 from the paper
spec_llm
.
generate
(
[
prompt_text
],
SamplingParams
(
temperature
=
0
,
max_tokens
=
2048
),
use_tqdm
=
False
,
)
current_metrics
=
spec_llm
.
get_metrics
()
acceptance_len
=
compute_acceptance_len
(
current_metrics
,
prev_metrics
)
prev_metrics
=
current_metrics
acceptance_lengths
.
append
(
acceptance_len
)
mean_acceptance_length
=
sum
(
acceptance_lengths
)
/
len
(
acceptance_lengths
)
expected_len
=
expected_len
*
0.9
print
(
f
"DFlash acceptance_len for
{
dataset_name
}
:
{
mean_acceptance_length
:.
2
f
}
"
f
" (expected at least
{
expected_len
:.
2
f
}
)"
)
assert
mean_acceptance_length
>=
expected_len
,
(
f
"DFlash acceptance_len for
{
dataset_name
}
is below expected threshold:"
f
"
{
mean_acceptance_length
:.
2
f
}
<
{
expected_len
:.
2
f
}
"
)
del
spec_llm
torch
.
accelerator
.
empty_cache
()
cleanup_dist_env_and_memory
()
def
test_dflash_correctness
(
dflash_config
):
"""
E2E test for DFlash (block diffusion) speculative decoding.
Ensures output correctness on GSM8k, with cudagraphs and batching on.
"""
spec_llm
=
LLM
(
**
dflash_config
)
# Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
evaluate_llm_for_gsm8k
(
spec_llm
,
expected_accuracy_threshold
=
0.8
)
current_metrics
=
spec_llm
.
get_metrics
()
acceptance_len
=
compute_acceptance_len
(
current_metrics
)
# AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent
# with the DFlash paper. However, that test measures AL per-request and thus runs
# with a batch size of 1. To ensure that AL does not collapse with large batch sizes
# we enforce a baseline on the AL over the full lm-eval-style GSM8k test.
expected_len
=
3.5
# Measured is 3.9 to 4.0
print
(
f
"DFlash GSM8k correctness test got AL
{
acceptance_len
}
"
)
assert
acceptance_len
>=
expected_len
,
(
"DFlash correctness check failed with"
f
"
{
acceptance_len
=
}
, expected at least
{
expected_len
}
"
)
del
spec_llm
torch
.
accelerator
.
empty_cache
()
cleanup_dist_env_and_memory
()
tests/v1/spec_decode/test_eagle.py
View file @
494636b2
...
...
@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.spec_decode.dflash
import
DFlashProposer
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
...
@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir
=
"amd/PARD-Llama-3.2-1B"
# Compatible with parallel and AR drafting
dflash_target_dir
=
"Qwen/Qwen3-8B"
dflash_dir
=
"z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE
=
16
...
...
@@ -47,18 +50,29 @@ def _create_proposer(
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
parallel_drafting
:
bool
=
False
,
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
# Method-dependent setup
if
method
==
"eagle"
:
target_model_dir
=
model_dir
draft_model_dir
=
eagle_dir
elif
method
==
"eagle3"
:
target_model_dir
=
model_dir
draft_model_dir
=
eagle3_dir
elif
method
==
"draft_model"
:
target_model_dir
=
model_dir
draft_model_dir
=
ar_draft_model_dir
elif
method
==
"dflash"
:
target_model_dir
=
dflash_target_dir
draft_model_dir
=
dflash_dir
else
:
raise
ValueError
(
f
"Unknown method:
{
method
}
"
)
model_config
=
ModelConfig
(
model
=
target_model_dir
,
runner
=
"generate"
,
max_model_len
=
100
,
trust_remote_code
=
(
method
==
"dflash"
),
)
spec_token_tree_str
=
None
if
speculative_token_tree
is
not
None
:
assert
num_speculative_tokens
==
len
(
speculative_token_tree
)
...
...
@@ -92,7 +106,9 @@ def _create_proposer(
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
)
if
"eagle"
in
method
:
if
method
==
"dflash"
:
proposer
=
DFlashProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
elif
"eagle"
in
method
:
proposer
=
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
else
:
proposer
=
DraftModelProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
...
...
@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree):
# Verify that the draft tokens match our expectations.
assert
torch
.
equal
(
result
,
expected_tokens
)
def
test_set_inputs_first_pass_dflash
():
"""
Test for DFlash set_inputs_first_pass.
DFlash uses cross-attention: context tokens become K/V and only
query tokens (bonus + mask) are Q. This tests the DFlash-specific
input preparation where:
- Context hidden states are stored by reference (no copy)
- Query input_ids are [next_token, mask, mask, ...] per request
- Context and query positions are written to separate buffers
- token_indices_to_sample points to mask token positions only
- A new CommonAttentionMetadata is returned with causal=False
Setup:
- 3 requests with query_lens [3, 2, 4]
- num_speculative_tokens = 3
- num_query_per_req = 4 (1 bonus + 3 mask tokens)
- next_token_ids: [100, 200, 300]
Expected output layout (query tokens only, 12 total):
Request 0 (indices 0-3): [100, mask, mask, mask]
Request 1 (indices 4-7): [200, mask, mask, mask]
Request 2 (indices 8-11): [300, mask, mask, mask]
Expected positions layout (separate buffers):
Context (_context_positions_buffer, 9 tokens): copied from target_positions
Query (positions, 12 tokens):
Request 0: last_pos=9, query=[10, 11, 12, 13]
Request 1: last_pos=7, query=[8, 9, 10, 11]
Request 2: last_pos=11, query=[12, 13, 14, 15]
"""
device
=
torch
.
device
(
current_platform
.
device_type
)
num_speculative_tokens
=
3
proposer
=
_create_proposer
(
"dflash"
,
num_speculative_tokens
)
mask_token_id
=
proposer
.
parallel_drafting_token_id
# Setup batch with 3 requests
batch_spec
=
BatchSpec
(
seq_lens
=
[
10
,
8
,
12
],
query_lens
=
[
3
,
2
,
4
],
)
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
block_size
=
BLOCK_SIZE
,
device
=
device
,
arange_block_indices
=
True
,
)
# Input tensors
# Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
# Request 1: tokens [20, 21] at positions [6, 7]
# Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
target_token_ids
=
torch
.
tensor
(
[
10
,
11
,
12
,
20
,
21
,
30
,
31
,
32
,
33
],
dtype
=
torch
.
int32
,
device
=
device
)
target_positions
=
torch
.
tensor
(
[
7
,
8
,
9
,
6
,
7
,
8
,
9
,
10
,
11
],
dtype
=
torch
.
int64
,
device
=
device
)
target_hidden_states
=
torch
.
randn
(
9
,
proposer
.
hidden_size
,
dtype
=
proposer
.
dtype
,
device
=
device
)
next_token_ids
=
torch
.
tensor
([
100
,
200
,
300
],
dtype
=
torch
.
int32
,
device
=
device
)
num_tokens
,
token_indices_to_sample
,
output_cad
=
proposer
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
None
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
None
,
)
num_query_per_req
=
1
+
num_speculative_tokens
# 4
num_context
=
9
# num_tokens is the query-only count
assert
num_tokens
==
3
*
num_query_per_req
# 12
# Verify input_ids (query tokens only)
# Each request: [next_token, mask, mask, mask]
M
=
mask_token_id
expected_input_ids
=
torch
.
tensor
(
[
100
,
M
,
M
,
M
,
200
,
M
,
M
,
M
,
300
,
M
,
M
,
M
],
dtype
=
torch
.
int32
,
device
=
device
,
)
assert
torch
.
equal
(
proposer
.
input_ids
[:
num_tokens
],
expected_input_ids
)
# Verify context positions (separate buffer): copied from target_positions
assert
torch
.
equal
(
proposer
.
_context_positions_buffer
[:
num_context
],
target_positions
)
# Verify query positions (separate buffer, starts at index 0):
# req0: last_pos=9, query=[10, 11, 12, 13]
# req1: last_pos=7, query=[8, 9, 10, 11]
# req2: last_pos=11, query=[12, 13, 14, 15]
expected_query_positions
=
torch
.
tensor
(
[
10
,
11
,
12
,
13
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
],
dtype
=
torch
.
int64
,
device
=
device
,
)
assert
torch
.
equal
(
proposer
.
positions
[:
num_tokens
],
expected_query_positions
,
)
# Verify token_indices_to_sample (mask tokens only, skip bonus at offset 0)
# req0: query indices 0-3, mask at 1,2,3
# req1: query indices 4-7, mask at 5,6,7
# req2: query indices 8-11, mask at 9,10,11
expected_token_indices_to_sample
=
torch
.
tensor
(
[
1
,
2
,
3
,
5
,
6
,
7
,
9
,
10
,
11
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
# Verify the new CAD has DFlash-specific properties
assert
output_cad
.
causal
is
False
# DFlash requires non-causal attention
assert
output_cad
.
num_actual_tokens
==
num_tokens
# query-only count
assert
output_cad
.
max_query_len
==
num_query_per_req
expected_query_start_loc
=
torch
.
tensor
(
[
0
,
4
,
8
,
12
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
output_cad
.
query_start_loc
,
expected_query_start_loc
)
# Verify hidden states (stored by reference, not copied)
assert
proposer
.
_dflash_hidden_states
is
target_hidden_states
vllm/config/speculative.py
View file @
494636b2
...
...
@@ -47,8 +47,11 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp"
,
"step3p5_mtp"
,
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
NgramGPUTypes
=
Literal
[
"ngram_gpu"
]
DFlashModelTypes
=
Literal
[
"dflash"
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
,
DFlashModelTypes
]
SpeculativeMethod
=
Literal
[
"ngram"
,
"medusa"
,
...
...
@@ -206,7 +209,11 @@ class SpeculativeConfig:
factors
:
list
[
Any
]
=
[]
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states
=
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
)
uses_aux_hidden_states
=
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
,
"dflash"
,
)
factors
.
append
(
uses_aux_hidden_states
)
# The specific layers used also affect the computation graph
...
...
@@ -490,7 +497,7 @@ class SpeculativeConfig:
)
# Automatically detect the method
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
if
self
.
method
in
(
"eagle"
,
"eagle3"
,
"dflash"
):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
...
...
@@ -500,6 +507,8 @@ class SpeculativeConfig:
self
.
method
=
"eagle"
elif
"eagle3"
in
self
.
draft_model_config
.
model
.
lower
():
self
.
method
=
"eagle3"
elif
"dflash"
in
self
.
draft_model_config
.
model
.
lower
():
self
.
method
=
"dflash"
elif
self
.
draft_model_config
.
hf_config
.
model_type
==
"medusa"
:
self
.
method
=
"medusa"
elif
self
.
draft_model_config
.
hf_config
.
model_type
==
"mlp_speculator"
:
...
...
@@ -532,7 +541,7 @@ class SpeculativeConfig:
)
# Replace hf_config for EAGLE draft_model
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
if
self
.
method
in
(
"eagle"
,
"eagle3"
,
"dflash"
):
from
vllm.transformers_utils.configs.eagle
import
EAGLEConfig
from
vllm.transformers_utils.configs.speculators
import
(
SpeculatorsConfig
,
...
...
@@ -552,6 +561,9 @@ class SpeculativeConfig:
self
.
draft_model_config
.
hf_config
=
eagle_config
self
.
update_arch_
()
if
self
.
method
==
"dflash"
:
self
.
parallel_drafting
=
True
if
self
.
num_speculative_tokens
is
not
None
and
hasattr
(
self
.
draft_model_config
.
hf_config
,
"num_lookahead_tokens"
):
...
...
@@ -807,7 +819,7 @@ class SpeculativeConfig:
"kimi_k25"
,
]
if
(
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
)
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
,
"dflash"
)
and
self
.
target_model_config
and
not
any
(
supported_model
in
self
.
target_model_config
.
hf_text_config
.
model_type
...
...
@@ -855,7 +867,10 @@ class SpeculativeConfig:
return
slots_per_req
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
,
"dflash"
)
def
use_dflash
(
self
)
->
bool
:
return
self
.
method
==
"dflash"
def
uses_draft_model
(
self
)
->
bool
:
return
self
.
method
==
"draft_model"
...
...
vllm/config/vllm.py
View file @
494636b2
...
...
@@ -1327,6 +1327,26 @@ class VllmConfig:
max_num_batched_tokens
-
scheduled_token_delta
)
if
self
.
scheduler_config
.
max_num_scheduled_tokens
<=
0
:
raise
ValueError
(
"max_num_scheduled_tokens is set to"
f
"
{
self
.
scheduler_config
.
max_num_scheduled_tokens
}
based on"
" the speculative decoding settings, which does not allow"
" any tokens to be scheduled. Increase max_num_batched_tokens"
" to accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
)
if
self
.
scheduler_config
.
max_num_scheduled_tokens
<
8192
:
logger
.
warning_once
(
"max_num_scheduled_tokens is set to"
f
"
{
self
.
scheduler_config
.
max_num_scheduled_tokens
}
based on"
" the speculative decoding settings. This may lead to suboptimal"
" performance. Consider increasing max_num_batched_tokens to"
" accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
,
scope
=
"local"
,
)
max_num_scheduled_tokens
=
self
.
scheduler_config
.
max_num_scheduled_tokens
if
max_num_batched_tokens
<
max_num_scheduled_tokens
+
(
self
.
speculative_config
.
max_num_new_slots_for_drafting
...
...
vllm/model_executor/models/qwen3.py
View file @
494636b2
...
...
@@ -285,6 +285,7 @@ class Qwen3ForCausalLM(
self
.
config
=
config
self
.
vllm_config
=
vllm_config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
...
...
vllm/model_executor/models/qwen3_dflash.py
0 → 100644
View file @
494636b2
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/qwen3_next.py
View file @
494636b2
...
...
@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_next
import
Qwen3NextConfig
from
.interfaces
import
(
EagleModelMixin
,
HasInnerState
,
IsHybrid
,
MixtureOfExperts
,
...
...
@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module):
@
support_torch_compile
class
Qwen3NextModel
(
nn
.
Module
):
class
Qwen3NextModel
(
nn
.
Module
,
EagleModelMixin
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module):
else
:
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
,
...]
=
()
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
aux_hidden_states
=
[]
aux_hidden_states
=
self
.
_maybe_add_hidden_state
([],
0
,
hidden_states
,
residual
)
for
layer_idx
,
layer
in
enumerate
(
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
),
start
=
self
.
start_layer
,
):
if
layer_idx
in
self
.
aux_hidden_state_layers
:
aux_hidden_states
.
append
(
hidden_states
+
residual
if
residual
is
not
None
else
hidden_states
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
)
self
.
_maybe_add_hidden_state
(
aux_hidden_states
,
layer_idx
+
1
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
...
...
vllm/model_executor/models/registry.py
View file @
494636b2
...
...
@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlamaForCausalLM"
:
(
"llama_eagle"
,
"EagleLlamaForCausalLM"
),
"EagleLlama4ForCausalLM"
:
(
"llama4_eagle"
,
"EagleLlama4ForCausalLM"
),
"EagleMiniCPMForCausalLM"
:
(
"minicpm_eagle"
,
"EagleMiniCPMForCausalLM"
),
"DFlashDraftModel"
:
(
"qwen3_dflash"
,
"DFlashQwen3ForCausalLM"
),
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"LlamaForCausalLMEagle3"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"Eagle3Qwen2_5vlForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
...
...
vllm/transformers_utils/configs/eagle.py
View file @
494636b2
...
...
@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig):
else
f
"Eagle3
{
arch
}
"
for
arch
in
self
.
model
.
architectures
]
elif
method
==
"dflash"
:
assert
self
.
model
is
not
None
,
(
"model should not be None when method is dflash"
)
kwargs
[
"architectures"
]
=
[
arch
if
arch
.
startswith
(
"DFlash"
)
or
arch
.
endswith
(
"DFlash"
)
else
f
"DFlash
{
arch
}
"
for
arch
in
self
.
model
.
architectures
]
else
:
raise
ValueError
(
f
"Invalid method
{
method
}
. Supported methods are eagle and eagle3."
f
"Invalid method
{
method
}
. Supported methods are "
"eagle, eagle3, and dflash."
)
super
().
__init__
(
**
kwargs
)
...
...
vllm/v1/attention/backend.py
View file @
494636b2
...
...
@@ -220,6 +220,17 @@ class AttentionBackend(ABC):
def
supports_per_head_quant_scales
(
cls
)
->
bool
:
return
False
@
classmethod
def
supports_non_causal
(
cls
)
->
bool
:
"""Check if backend supports non-causal (bidirectional) attention
for decoder models.
Unlike ENCODER_ONLY attention type which implies a different
execution model, this refers to non-causal attention within the
standard paged-KV-cache decoder path.
"""
return
False
@
classmethod
def
supports_attn_type
(
cls
,
attn_type
:
str
)
->
bool
:
"""Check if backend supports a given attention type.
...
...
@@ -261,6 +272,7 @@ class AttentionBackend(ABC):
use_per_head_quant_scales
:
bool
,
device_capability
:
"DeviceCapability"
,
attn_type
:
str
,
use_non_causal
:
bool
=
False
,
)
->
list
[
str
]:
invalid_reasons
=
[]
if
not
cls
.
supports_head_size
(
head_size
):
...
...
@@ -293,6 +305,8 @@ class AttentionBackend(ABC):
invalid_reasons
.
append
(
"compute capability not supported"
)
if
not
cls
.
supports_attn_type
(
attn_type
):
invalid_reasons
.
append
(
f
"attention type
{
attn_type
}
not supported"
)
if
use_non_causal
and
not
cls
.
supports_non_causal
():
invalid_reasons
.
append
(
"non-causal attention not supported"
)
combination_reason
=
cls
.
supports_combination
(
head_size
,
dtype
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
494636b2
...
...
@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_name
()
->
str
:
return
"FLASH_ATTN"
@
classmethod
def
supports_non_causal
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_attn_type
(
cls
,
attn_type
:
str
)
->
bool
:
"""FlashAttention supports all attention types."""
...
...
vllm/v1/attention/selector.py
View file @
494636b2
...
...
@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple):
use_mm_prefix
:
bool
=
False
use_per_head_quant_scales
:
bool
=
False
attn_type
:
str
=
AttentionType
.
DECODER
use_non_causal
:
bool
=
False
def
__repr__
(
self
):
return
(
...
...
@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple):
f
"use_sparse=
{
self
.
use_sparse
}
, "
f
"use_mm_prefix=
{
self
.
use_mm_prefix
}
, "
f
"use_per_head_quant_scales=
{
self
.
use_per_head_quant_scales
}
, "
f
"attn_type=
{
self
.
attn_type
}
)"
f
"attn_type=
{
self
.
attn_type
}
, "
f
"use_non_causal=
{
self
.
use_non_causal
}
)"
)
...
...
@@ -76,6 +78,11 @@ def get_attn_backend(
else
:
block_size
=
None
speculative_config
=
vllm_config
.
speculative_config
use_non_causal
=
(
speculative_config
is
not
None
and
speculative_config
.
method
==
"dflash"
)
attn_selector_config
=
AttentionSelectorConfig
(
head_size
=
head_size
,
dtype
=
dtype
,
...
...
@@ -87,6 +94,7 @@ def get_attn_backend(
use_mm_prefix
=
use_mm_prefix
,
use_per_head_quant_scales
=
use_per_head_quant_scales
,
attn_type
=
attn_type
or
AttentionType
.
DECODER
,
use_non_causal
=
use_non_causal
,
)
return
_cached_get_attn_backend
(
...
...
vllm/v1/spec_decode/dflash.py
0 → 100644
View file @
494636b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
typing_extensions
import
override
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
triton
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
from
vllm.v1.spec_decode.eagle
import
SpecDecodeBaseProposer
from
vllm.v1.spec_decode.utils
import
copy_and_expand_dflash_inputs_kernel
logger
=
init_logger
(
__name__
)
class
DFlashProposer
(
SpecDecodeBaseProposer
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
,
):
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
method
==
"dflash"
super
().
__init__
(
vllm_config
=
vllm_config
,
device
=
device
,
pass_hidden_states_to_model
=
True
,
runner
=
runner
,
)
# Only next_token_ids and mask tokens are query tokens, all other context is K/V
self
.
max_query_tokens
=
self
.
max_batch_size
*
(
1
+
self
.
num_speculative_tokens
)
# Positions covers both context states + query states
self
.
max_positions
=
self
.
max_num_tokens
+
self
.
max_query_tokens
# Separate context buffers to keep query buffer addresses stable for CUDA graphs
self
.
_context_slot_mapping_buffer
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
_slot_mapping_buffer
=
torch
.
zeros
(
self
.
max_query_tokens
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
_context_positions_buffer
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
positions
=
torch
.
zeros
(
self
.
max_query_tokens
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
arange
=
torch
.
arange
(
self
.
max_positions
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
# For DFlash we use the input embeddings to embed the mask token
self
.
parallel_drafting_hidden_state_tensor
=
None
@
override
def
_raise_if_multimodal
(
self
):
# Override to allow multimodal inputs since DFlash supports Qwen3.5 models
# Support for multimodal inputs has not been tested.
pass
@
override
def
set_inputs_first_pass
(
self
,
target_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
|
None
,
cad
:
CommonAttentionMetadata
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
# DFlash cross-attention: context K/V from target hidden states,
# Q from query embeddings (bonus + mask tokens).
batch_size
=
cad
.
batch_size
()
num_context
=
target_token_ids
.
shape
[
0
]
num_query_per_req
=
1
+
self
.
num_speculative_tokens
num_query_total
=
batch_size
*
num_query_per_req
# Store for build_model_inputs_first_pass to use
self
.
_dflash_num_context
=
num_context
# We don't need to copy into a buffer here since the context preprocessing
# does not run in a CUDA graph
self
.
_dflash_hidden_states
=
target_hidden_states
token_indices_to_sample
=
torch
.
empty
(
batch_size
*
self
.
num_speculative_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# Launch fused triton kernel for input_ids, positions, slot_mapping,
# and token_indices_to_sample
max_ctx_per_req
=
cad
.
max_query_len
max_tokens_per_req
=
max_ctx_per_req
+
num_query_per_req
BLOCK_SIZE
=
min
(
256
,
triton
.
next_power_of_2
(
max_tokens_per_req
))
num_blocks
=
triton
.
cdiv
(
max_tokens_per_req
,
BLOCK_SIZE
)
grid
=
(
batch_size
,
num_blocks
)
has_num_rejected
=
num_rejected_tokens_gpu
is
not
None
copy_and_expand_dflash_inputs_kernel
[
grid
](
# Inputs
next_token_ids_ptr
=
next_token_ids
,
target_positions_ptr
=
target_positions
,
# Outputs
out_input_ids_ptr
=
self
.
input_ids
,
out_context_positions_ptr
=
self
.
_context_positions_buffer
,
out_query_positions_ptr
=
self
.
positions
,
out_context_slot_mapping_ptr
=
self
.
_context_slot_mapping_buffer
,
out_query_slot_mapping_ptr
=
self
.
_slot_mapping_buffer
,
out_token_indices_ptr
=
token_indices_to_sample
,
# Block table
block_table_ptr
=
cad
.
block_table_tensor
,
block_table_stride
=
cad
.
block_table_tensor
.
stride
(
0
),
# Metadata
query_start_loc_ptr
=
cad
.
query_start_loc
,
num_rejected_tokens_ptr
=
(
num_rejected_tokens_gpu
if
has_num_rejected
else
0
),
# Scalars
parallel_drafting_token_id
=
self
.
parallel_drafting_token_id
,
block_size
=
self
.
block_size
,
num_query_per_req
=
num_query_per_req
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
total_input_tokens
=
num_context
,
BLOCK_SIZE
=
BLOCK_SIZE
,
HAS_NUM_REJECTED
=
has_num_rejected
,
)
query_slot_mapping
=
self
.
_slot_mapping_buffer
[:
num_query_total
]
new_query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
*
num_query_per_req
# In padded mode, cad.seq_lens includes rejected tokens. Subtract
# them so attention only sees the valid prefix of context states.
effective_seq_lens
=
cad
.
seq_lens
if
has_num_rejected
:
effective_seq_lens
=
effective_seq_lens
-
num_rejected_tokens_gpu
new_cad
=
CommonAttentionMetadata
(
query_start_loc
=
new_query_start_loc
,
seq_lens
=
effective_seq_lens
+
num_query_per_req
,
query_start_loc_cpu
=
(
torch
.
from_numpy
(
self
.
token_arange_np
[:
batch_size
+
1
]).
clone
()
*
num_query_per_req
),
_seq_lens_cpu
=
None
,
_num_computed_tokens_cpu
=
None
,
num_reqs
=
cad
.
num_reqs
,
num_actual_tokens
=
num_query_total
,
max_query_len
=
num_query_per_req
,
max_seq_len
=
cad
.
max_seq_len
+
num_query_per_req
,
block_table_tensor
=
cad
.
block_table_tensor
,
slot_mapping
=
query_slot_mapping
,
causal
=
False
,
# Non-causal attention is required for DFlash
)
return
num_query_total
,
token_indices_to_sample
,
new_cad
@
override
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
is_graph_capturing
:
bool
=
False
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
)
->
None
:
"""
Key differences to default dummy_run:
- Only one forward pass due to parallel drafting
- DFlash uses context states as unpadded metadata, so hidden_states will
use the unpadded num_tokens instead of num_input_tokens
- max_query_tokens is quite small, DFlash only sees spec tokens as queries
- Multimodal inputs are not currently supported
"""
num_query_tokens
=
min
(
num_tokens
,
self
.
max_query_tokens
)
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_query_tokens
,
use_cudagraphs
=
use_cudagraphs
)
)
# Slot mapping sized to num_input_tokens (query only), matching
# the K/V tensor size from the model forward. Context KVs are
# pre-inserted separately and don't flow through the model.
if
(
self
.
_draft_attn_layer_names
and
slot_mappings
is
not
None
and
next
(
iter
(
self
.
_draft_attn_layer_names
))
in
slot_mappings
):
slot_mapping_dict
=
self
.
_get_slot_mapping
(
num_input_tokens
)
else
:
slot_mapping_dict
=
slot_mappings
or
{}
# Context and query positions use separate buffers; no copy needed.
context_positions
=
self
.
_context_positions_buffer
[:
num_tokens
]
# Context states will be passed directly to the precomputation without
# going through the buffer, since no CUDA graph is used for the precomputation.
# For the dummy run, we use the dummy buffer.
context_states
=
self
.
hidden_states
[:
num_tokens
]
# Run the KV projection (GEMM + norms + RoPE) for memory profiling,
self
.
model
.
precompute_and_store_context_kv
(
context_states
,
context_positions
)
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
slot_mapping_dict
,
):
self
.
model
(
input_ids
=
self
.
input_ids
[:
num_input_tokens
],
positions
=
self
.
_get_positions
(
num_input_tokens
),
inputs_embeds
=
None
,
)
@
override
def
build_model_inputs_first_pass
(
self
,
num_tokens
:
int
,
num_input_tokens
:
int
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
,
)
->
tuple
[
dict
[
str
,
Any
],
int
]:
# Context and query positions/slots were written to separate
# buffers by the kernel — no copy needed.
num_context
=
self
.
_dflash_num_context
# Pre-insert context KVs directly into cache
self
.
model
.
precompute_and_store_context_kv
(
self
.
_dflash_hidden_states
,
# Shape is already [num_context, hidden_size]
self
.
_context_positions_buffer
[:
num_context
],
self
.
_context_slot_mapping_buffer
[:
num_context
],
)
return
(
dict
(
input_ids
=
self
.
input_ids
[:
num_input_tokens
],
positions
=
self
.
_get_positions
(
num_input_tokens
),
inputs_embeds
=
None
,
),
num_input_tokens
,
)
@
override
def
build_per_layer_attn_metadata
(
self
,
cad
:
CommonAttentionMetadata
,
draft_index
:
int
=
0
)
->
dict
[
str
,
object
]:
per_layer_attention_metadata
=
super
().
build_per_layer_attn_metadata
(
cad
,
draft_index
)
for
layer_name
,
attn_metadata
in
per_layer_attention_metadata
.
items
():
assert
getattr
(
attn_metadata
,
"causal"
,
None
)
is
False
,
(
f
"Attention metadata for layer
{
layer_name
}
does not have"
" non-causal support, which is required for DFlash."
" Consider using a different attention backend, such as FlashAttention."
)
return
per_layer_attention_metadata
@
override
def
_get_eagle3_use_aux_hidden_state_from_config
(
self
):
use_aux_hidden_state
=
True
dflash_config
=
getattr
(
self
.
draft_model_config
.
hf_config
,
"dflash_config"
,
None
)
if
dflash_config
is
not
None
:
use_aux_hidden_state
=
dflash_config
.
get
(
"use_aux_hidden_state"
,
True
)
return
use_aux_hidden_state
vllm/v1/spec_decode/eagle.py
View file @
494636b2
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
from
importlib.util
import
find_spec
from
typing
import
cast
from
typing
import
Any
,
cast
import
numpy
as
np
import
torch
...
...
@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal
from
vllm.model_executor.models.deepseek_eagle3
import
Eagle3DeepseekV2ForCausalLM
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models.qwen3_dflash
import
DFlashQwen3ForCausalLM
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
...
...
@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer:
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
# Unifying eagle, draft model, and parallel drafting support
# Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE)
self
.
parallel_drafting
:
bool
=
self
.
speculative_config
.
parallel_drafting
self
.
extra_slots_per_request
=
(
1
if
not
self
.
parallel_drafting
else
self
.
num_speculative_tokens
)
self
.
net_num_new_slots_per_request
=
self
.
extra_slots_per_request
-
(
1
if
self
.
pass_hidden_states_to_model
else
0
1
if
(
self
.
pass_hidden_states_to_model
and
self
.
method
!=
"dflash"
)
else
0
)
self
.
needs_extra_input_slots
=
self
.
net_num_new_slots_per_request
>
0
...
...
@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer:
self
.
speculative_config
.
use_local_argmax_reduction
)
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# Can be specialized by methods like DFlash to reduce the limit
self
.
max_query_tokens
=
self
.
max_num_tokens
self
.
max_positions
=
self
.
max_num_tokens
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
...
...
@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer:
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self
.
mrope_positions
=
torch
.
zeros
(
(
3
,
self
.
max_
num_toke
ns
+
1
),
dtype
=
torch
.
int64
,
device
=
device
(
3
,
self
.
max_
positio
ns
+
1
),
dtype
=
torch
.
int64
,
device
=
device
)
elif
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
>
0
:
self
.
xdrope_positions
=
torch
.
zeros
(
(
self
.
uses_xdrope_dim
,
self
.
max_
num_toke
ns
+
1
),
(
self
.
uses_xdrope_dim
,
self
.
max_
positio
ns
+
1
),
dtype
=
torch
.
int64
,
device
=
device
,
)
else
:
# RoPE need (max_num_tokens,)
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
self
.
max_positions
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
hidden_states
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
...
...
@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer:
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
max_num_slots_for_arange
=
max
(
self
.
max_batch_size
+
1
,
self
.
max_num_tokens
)
self
.
arange
=
torch
.
arange
(
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
)
...
...
@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer:
)
self
.
backup_next_token_ids
=
CpuGpuBuffer
(
max_batch_size
,
self
.
max_batch_size
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
...
...
@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer:
)
self
.
_slot_mapping_buffer
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
self
.
max_positions
,
dtype
=
torch
.
int64
,
device
=
device
,
)
# Determine allowed attention backends once during initialization.
...
...
@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer:
# Precompute draft position offsets in flattened tree.
self
.
tree_draft_pos_offsets
=
torch
.
arange
(
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
).
repeat
(
max_batch_size
,
1
)
).
repeat
(
self
.
max_batch_size
,
1
)
def
_raise_if_padded_drafter_batch_disabled
(
self
):
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
...
...
@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer:
# for those masked slots.
model_hf_config
=
self
.
draft_model_config
.
hf_config
if
hasattr
(
model_hf_config
,
"pard_token"
):
# DFlash stores mask_token_id in dflash_config
dflash_config
=
getattr
(
model_hf_config
,
"dflash_config"
,
None
)
if
dflash_config
and
"mask_token_id"
in
dflash_config
:
self
.
parallel_drafting_token_id
=
dflash_config
[
"mask_token_id"
]
elif
hasattr
(
model_hf_config
,
"pard_token"
):
self
.
parallel_drafting_token_id
=
model_hf_config
.
pard_token
elif
hasattr
(
model_hf_config
,
"ptd_token_id"
):
self
.
parallel_drafting_token_id
=
model_hf_config
.
ptd_token_id
else
:
raise
ValueError
(
"For parallel drafting, the draft model config must have "
"`pard_token` or `ptd_token_id` specified in its config.json."
"`pard_token`, `ptd_token_id`, or "
"`dflash_config.mask_token_id` specified in its config.json."
)
if
self
.
pass_hidden_states_to_model
:
...
...
@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer:
)
->
torch
.
Tensor
:
batch_size
=
common_attn_metadata
.
batch_size
()
if
self
.
method
==
"eagle3"
:
if
self
.
method
in
(
"eagle3"
,
"dflash"
)
:
assert
isinstance
(
self
.
model
,
(
Eagle3LlamaForCausalLM
,
Eagle3DeepseekV2ForCausalLM
)
self
.
model
,
(
Eagle3LlamaForCausalLM
,
Eagle3DeepseekV2ForCausalLM
,
DFlashQwen3ForCausalLM
,
),
)
target_hidden_states
=
self
.
model
.
combine_hidden_states
(
target_hidden_states
...
...
@@ -423,40 +444,17 @@ class SpecDecodeBaseProposer:
)
)
per_layer_attn_metadata
:
dict
[
str
,
object
]
=
{}
for
attn_group
in
self
.
draft_attn_groups
:
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
0
)
for
layer_name
in
attn_group
.
layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
per_layer_attn_metadata
=
self
.
build_per_layer_attn_metadata
(
common_attn_metadata
)
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
)
)
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
self
.
inputs_embeds
[:
num_tokens
]
=
self
.
model
.
embed_input_ids
(
self
.
input_ids
[:
num_tokens
],
multimodal_embeddings
=
mm_embeds
,
is_multimodal
=
is_mm_embed
,
)
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
else
:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
model_kwargs
,
slot_mapping_size
=
self
.
build_model_inputs_first_pass
(
num_tokens
,
num_input_tokens
,
mm_embed_inputs
)
with
set_forward_context
(
per_layer_attn_metadata
,
...
...
@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer:
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
slot_mapping_size
,
common_attn_metadata
.
slot_mapping
),
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
...
...
@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer:
positions
=
self
.
positions
[
token_indices_to_sample
]
hidden_states
=
hidden_states
[
token_indices_to_sample
]
if
isinstance
(
attn_metadata
,
TreeAttentionMetadata
):
if
any
(
isinstance
(
attn_metadata
,
TreeAttentionMetadata
)
for
attn_metadata
in
per_layer_attn_metadata
.
values
()
):
# Draft using tree attention - requires full logits for top-k
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
draft_token_ids_list
=
self
.
propose_tree
(
...
...
@@ -504,15 +505,16 @@ class SpecDecodeBaseProposer:
draft_token_ids
=
self
.
_greedy_sample
(
sample_hidden_states
)
if
self
.
allowed_attn_types
is
not
None
and
not
isinstance
(
attn_metadata
,
self
.
allowed_attn_types
):
raise
ValueError
(
f
"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
f
"
{
type
(
attn_metadata
)
}
. Supported types are: "
f
"
{
self
.
allowed_attn_types
}
"
)
for
attn_metadata
in
per_layer_attn_metadata
.
values
():
if
self
.
allowed_attn_types
is
not
None
and
not
isinstance
(
attn_metadata
,
self
.
allowed_attn_types
):
raise
ValueError
(
f
"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
f
"
{
type
(
attn_metadata
)
}
. Supported types are: "
f
"
{
self
.
allowed_attn_types
}
"
)
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
...
...
@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata
.
_num_computed_tokens_cpu
+=
1
# Rebuild attention metadata
for
attn_group
in
self
.
draft_attn_groups
:
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
,
)
for
layer_name
in
attn_group
.
layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
per_layer_attn_metadata
=
self
.
build_per_layer_attn_metadata
(
common_attn_metadata
,
draft_index
=
token_index
+
1
)
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
...
...
@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer:
return
total_num_output_tokens
,
token_indices_to_sample
,
new_cad
def
build_model_inputs_first_pass
(
self
,
num_tokens
:
int
,
num_input_tokens
:
int
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
,
)
->
tuple
[
dict
[
str
,
Any
],
int
]:
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
self
.
inputs_embeds
[:
num_tokens
]
=
self
.
model
.
embed_input_ids
(
self
.
input_ids
[:
num_tokens
],
multimodal_embeddings
=
mm_embeds
,
is_multimodal
=
is_mm_embed
,
)
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
else
:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
return
model_kwargs
,
num_input_tokens
def
build_per_layer_attn_metadata
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
draft_index
:
int
=
0
)
->
dict
[
str
,
object
]:
per_layer_attn_metadata
:
dict
[
str
,
object
]
=
{}
for
attn_group
in
self
.
draft_attn_groups
:
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
draft_index
)
for
layer_name
in
attn_group
.
layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
return
per_layer_attn_metadata
def
model_returns_tuple
(
self
)
->
bool
:
return
self
.
method
not
in
(
"mtp"
,
"draft_model"
)
return
self
.
method
not
in
(
"mtp"
,
"draft_model"
,
"dflash"
)
def
prepare_next_token_ids_cpu
(
self
,
...
...
@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer:
self
.
_maybe_share_embeddings
(
target_language_model
)
self
.
_maybe_share_lm_head
(
target_language_model
)
if
self
.
parallel_drafting
and
self
.
pass_hidden_states_to_model
:
assert
self
.
parallel_drafting_hidden_state_tensor
is
not
None
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
self
.
model
.
combine_hidden_states
(
self
.
model
.
mask_hidden
.
view
(
3
*
self
.
hidden_size
)
if
(
self
.
parallel_drafting
and
self
.
pass_hidden_states_to_model
and
self
.
parallel_drafting_hidden_state_tensor
is
not
None
):
flat_mask
=
self
.
model
.
mask_hidden
.
view
(
-
1
)
if
self
.
eagle3_use_aux_hidden_state
:
# EAGLE3: mask_hidden stores all aux hidden states,
# project through combine_hidden_states
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
self
.
model
.
combine_hidden_states
(
flat_mask
)
)
if
self
.
eagle3_use_aux_hidden_state
else
self
.
model
.
mask_hidden
.
view
(
self
.
hidden_size
)
)
else
:
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
flat_mask
)
def
_maybe_share_embeddings
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
"""
...
...
@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer:
)
->
None
:
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
only_one_forward_pass
=
is_graph_capturing
or
self
.
parallel_drafting
for
fwd_idx
in
range
(
self
.
num_speculative_tokens
if
not
is_graph_capturing
else
1
1
if
only_one_forward_pass
else
self
.
num_speculative_tokens
):
if
fwd_idx
<=
1
:
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
...
...
vllm/v1/spec_decode/utils.py
View file @
494636b2
...
...
@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel(
)
@
triton
.
jit
def
copy_and_expand_dflash_inputs_kernel
(
# Inputs
next_token_ids_ptr
,
# [num_reqs]
target_positions_ptr
,
# [num_context]
# Outputs
out_input_ids_ptr
,
# [num_query_total] (output)
out_context_positions_ptr
,
# [num_context] (output)
out_query_positions_ptr
,
# [num_query_total] (output)
out_context_slot_mapping_ptr
,
# [num_context] (output)
out_query_slot_mapping_ptr
,
# [num_query_total] (output)
out_token_indices_ptr
,
# [num_reqs * num_speculative_tokens] (output)
# Block table
block_table_ptr
,
# [max_reqs, max_blocks]
block_table_stride
,
# stride of block_table dim 0 (in elements)
# Metadata
query_start_loc_ptr
,
# [num_reqs + 1]
num_rejected_tokens_ptr
,
# [num_reqs] or null (0) when not padded
# Scalars
parallel_drafting_token_id
,
# tl.int32
block_size
,
# tl.int32
num_query_per_req
,
# tl.int32
num_speculative_tokens
,
# tl.int32
total_input_tokens
,
# tl.int32
BLOCK_SIZE
:
tl
.
constexpr
,
HAS_NUM_REJECTED
:
tl
.
constexpr
=
False
,
):
"""
Fused kernel for DFlash first-pass input setup.
Per request, this kernel:
1. Copies context positions from target_positions to
out_context_positions.
2. Computes query positions (last_target_pos + 1 + offset) and writes
them to out_query_positions.
3. Writes input_ids for query tokens: [next_token, mask, mask, ...].
4. Computes slot_mapping for context and query positions into separate
buffers via block_table lookup.
5. Writes token_indices_to_sample for the mask (speculative) tokens.
"""
req_idx
=
tl
.
program_id
(
axis
=
0
)
block_idx
=
tl
.
program_id
(
axis
=
1
)
# Load context token range for this request
ctx_start
=
tl
.
load
(
query_start_loc_ptr
+
req_idx
)
ctx_end
=
tl
.
load
(
query_start_loc_ptr
+
req_idx
+
1
)
num_ctx
=
ctx_end
-
ctx_start
total_tokens
=
num_ctx
+
num_query_per_req
j
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
in_bounds
=
j
<
total_tokens
is_ctx
=
j
<
num_ctx
is_query
=
(
~
is_ctx
)
&
in_bounds
query_off
=
j
-
num_ctx
# offset within query portion (0-indexed)
# --- Positions ---
# Context: load from target_positions
ctx_pos_idx
=
tl
.
minimum
(
ctx_start
+
j
,
total_input_tokens
-
1
)
ctx_pos
=
tl
.
load
(
target_positions_ptr
+
ctx_pos_idx
,
mask
=
is_ctx
,
other
=
0
)
# Query: last_valid_pos + 1 + query_off
# In padded mode, ctx_end includes rejected tokens; use valid_ctx_end
# to find the last accepted context position.
if
HAS_NUM_REJECTED
:
num_rejected
=
tl
.
load
(
num_rejected_tokens_ptr
+
req_idx
)
valid_ctx_end
=
ctx_end
-
num_rejected
else
:
valid_ctx_end
=
ctx_end
last_pos
=
tl
.
load
(
target_positions_ptr
+
valid_ctx_end
-
1
)
query_pos
=
last_pos
+
1
+
query_off
positions
=
tl
.
where
(
is_ctx
,
ctx_pos
,
query_pos
)
# Context and query positions go to separate buffers.
ctx_pos_out
=
ctx_start
+
j
tl
.
store
(
out_context_positions_ptr
+
ctx_pos_out
,
ctx_pos
,
mask
=
is_ctx
)
query_out
=
req_idx
*
num_query_per_req
+
query_off
tl
.
store
(
out_query_positions_ptr
+
query_out
,
query_pos
,
mask
=
is_query
)
# --- Slot mapping (block_table lookup for all positions) ---
block_num
=
positions
//
block_size
# # Clamp block_number to avoid OOB when position is at max
block_num
=
tl
.
minimum
(
block_num
,
block_table_stride
-
1
)
block_id
=
tl
.
load
(
block_table_ptr
+
req_idx
*
block_table_stride
+
block_num
,
mask
=
in_bounds
,
other
=
0
,
).
to
(
tl
.
int64
)
slot
=
block_id
*
block_size
+
(
positions
%
block_size
)
tl
.
store
(
out_context_slot_mapping_ptr
+
ctx_pos_out
,
slot
,
mask
=
is_ctx
)
tl
.
store
(
out_query_slot_mapping_ptr
+
query_out
,
slot
,
mask
=
is_query
)
# --- Input IDs (query tokens only) ---
bonus_token
=
tl
.
load
(
next_token_ids_ptr
+
req_idx
)
is_bonus
=
is_query
&
(
query_off
==
0
)
input_id
=
tl
.
where
(
is_bonus
,
bonus_token
,
parallel_drafting_token_id
)
tl
.
store
(
out_input_ids_ptr
+
query_out
,
input_id
,
mask
=
is_query
)
# --- Token indices to sample (mask tokens, skip the bonus token) ---
is_sample
=
is_query
&
(
query_off
>
0
)
sample_out_idx
=
req_idx
*
num_speculative_tokens
+
(
query_off
-
1
)
tl
.
store
(
out_token_indices_ptr
+
sample_out_idx
,
query_out
,
mask
=
is_sample
,
)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
update_num_computed_tokens_for_batch_change
(
num_computed_tokens
:
torch
.
Tensor
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
494636b2
...
...
@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.dflash
import
DFlashProposer
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
...
...
@@ -515,6 +516,7 @@ class GPUModelRunner(
|
NgramProposerGPU
|
SuffixDecodingProposer
|
EagleProposer
|
DFlashProposer
|
DraftModelProposer
|
MedusaProposer
|
ExtractHiddenStatesProposer
...
...
@@ -546,6 +548,9 @@ class GPUModelRunner(
self
.
_ngram_pinned_val_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
elif
self
.
speculative_config
.
use_dflash
():
self
.
drafter
=
DFlashProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
use_aux_hidden_state_outputs
=
True
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
...
...
@@ -2289,7 +2294,7 @@ class GPUModelRunner(
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
isinstance
(
self
.
drafter
,
(
EagleProposer
,
DFlashProposer
)
):
if
self
.
drafter
.
kv_cache_gid
==
kv_cache_gid
:
spec_decode_common_attn_metadata
=
cm
else
:
...
...
@@ -4202,7 +4207,10 @@ class GPUModelRunner(
# as inputs, and does not need to wait for bookkeeping to finish.
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
...
...
@@ -4589,8 +4597,14 @@ class GPUModelRunner(
next_token_ids
,
valid_sampled_tokens_count
)
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
elif
(
spec_config
.
use_eagle
()
or
spec_config
.
use_dflash
()
or
spec_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
)
if
spec_config
.
disable_padded_drafter_batch
:
# When padded-batch is disabled, the sampled_token_ids should be
...
...
@@ -4889,10 +4903,13 @@ class GPUModelRunner(
return
None
hf_config
=
self
.
speculative_config
.
draft_model_config
.
hf_config
if
not
hasattr
(
hf_config
,
"eagle_aux_hidden_state_layer_ids"
):
return
None
layer_ids
=
hf_config
.
eagle_aux_hidden_state_layer_ids
layer_ids
=
getattr
(
hf_config
,
"eagle_aux_hidden_state_layer_ids"
,
None
)
if
not
layer_ids
:
dflash_config
=
getattr
(
hf_config
,
"dflash_config"
,
None
)
if
dflash_config
and
isinstance
(
dflash_config
,
dict
):
layer_ids
=
dflash_config
.
get
(
"target_layer_ids"
)
if
layer_ids
and
isinstance
(
layer_ids
,
(
list
,
tuple
)):
return
tuple
(
layer_ids
)
...
...
@@ -5479,7 +5496,10 @@ class GPUModelRunner(
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
assert
self
.
speculative_config
is
not
None
# Eagle currently only supports PIECEWISE cudagraphs.
...
...
@@ -6236,7 +6256,9 @@ class GPUModelRunner(
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
)
self
.
drafter
.
initialize_attn_backend
(
kv_cache_config
,
kernel_block_sizes
)
def
_check_and_update_cudagraph_mode
(
...
...
@@ -6420,7 +6442,10 @@ class GPUModelRunner(
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_extract_hidden_states
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
ExtractHiddenStatesProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DFlashProposer
|
ExtractHiddenStatesProposer
,
)
self
.
drafter
.
initialize_cudagraph_keys
(
cudagraph_mode
)
def
calculate_reorder_batch_threshold
(
self
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment