Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
494636b2
Unverified
Commit
494636b2
authored
Mar 30, 2026
by
Benjamin Chislett
Committed by
GitHub
Mar 30, 2026
Browse files
[Feat][Spec Decode] DFlash (#36847)
Signed-off-by:
Benjamin Chislett
<
bchislett@nvidia.com
>
parent
ab1a6a43
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1577 additions
and
107 deletions
+1577
-107
tests/models/registry.py
tests/models/registry.py
+8
-0
tests/v1/e2e/spec_decode/test_spec_decode.py
tests/v1/e2e/spec_decode/test_spec_decode.py
+166
-6
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+152
-3
vllm/config/speculative.py
vllm/config/speculative.py
+21
-6
vllm/config/vllm.py
vllm/config/vllm.py
+20
-0
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+1
-0
vllm/model_executor/models/qwen3_dflash.py
vllm/model_executor/models/qwen3_dflash.py
+619
-0
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+6
-8
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+12
-1
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+14
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-0
vllm/v1/attention/selector.py
vllm/v1/attention/selector.py
+9
-1
vllm/v1/spec_decode/dflash.py
vllm/v1/spec_decode/dflash.py
+282
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+119
-72
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+108
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+35
-10
No files found.
tests/models/registry.py
View file @
494636b2
...
...
@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
# "JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"
# ),
# [DFlash]
"DFlashDraftModel"
:
_HfExamplesInfo
(
"Qwen/Qwen3.5-4B"
,
speculative_model
=
"z-lab/Qwen3.5-4B-DFlash"
,
use_original_num_layers
=
True
,
# Need all layers since DFlash has >1 layer,
max_model_len
=
8192
,
# Reduce max len to ensure test runs in low-VRAM CI env
max_num_seqs
=
32
,
),
# [Eagle]
"EagleDeepSeekMTPModel"
:
_HfExamplesInfo
(
"eagle618/deepseek-v3-random"
,
...
...
tests/v1/e2e/spec_decode/test_spec_decode.py
View file @
494636b2
...
...
@@ -7,6 +7,7 @@ from typing import Any
import
pytest
import
torch
from
tqdm
import
tqdm
from
tests.evals.gsm8k.gsm8k_eval
import
_build_gsm8k_prompts
,
evaluate_gsm8k_offline
from
tests.utils
import
(
...
...
@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
}
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
])
->
float
:
def
compute_acceptance_rate
(
metrics
:
list
[
Metric
],
prev_metrics
:
list
[
Metric
]
|
None
=
None
)
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
# type: ignore
n_draft_toks
=
name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
if
n_draft_toks
==
0
:
return
float
(
"nan"
)
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
prev_metrics
is
not
None
:
prev_name2metric
=
{
metric
.
name
:
metric
for
metric
in
prev_metrics
}
n_draft_toks
-=
prev_name2metric
[
"vllm:spec_decode_num_draft_tokens"
].
value
n_accepted_toks
-=
prev_name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
n_draft_toks
<=
0
:
return
float
(
"nan"
)
return
n_accepted_toks
/
n_draft_toks
def
compute_acceptance_len
(
metrics
:
list
[
Metric
])
->
float
:
def
compute_acceptance_len
(
metrics
:
list
[
Metric
],
prev_metrics
:
list
[
Metric
]
|
None
=
None
)
->
float
:
name2metric
=
{
metric
.
name
:
metric
for
metric
in
metrics
}
n_drafts
=
name2metric
[
"vllm:spec_decode_num_drafts"
].
value
# type: ignore
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
# type: ignore
n_drafts
=
name2metric
[
"vllm:spec_decode_num_drafts"
].
value
n_accepted_toks
=
name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
n_drafts
==
0
:
return
1
if
prev_metrics
is
not
None
:
prev_name2metric
=
{
metric
.
name
:
metric
for
metric
in
prev_metrics
}
n_drafts
-=
prev_name2metric
[
"vllm:spec_decode_num_drafts"
].
value
n_accepted_toks
-=
prev_name2metric
[
"vllm:spec_decode_num_accepted_tokens"
].
value
if
n_drafts
<=
0
:
return
1
return
1
+
(
n_accepted_toks
/
n_drafts
)
# Datasets in the format used in DFlash validations
def load_and_process_dataset(data_name: str):
    """Load a benchmark dataset and add a ``"turns"`` column of prompts.

    Args:
        data_name: One of ``"gsm8k"``, ``"mt-bench"`` or ``"humaneval"``.

    Returns:
        The Hugging Face dataset returned by ``load_dataset`` after a
        ``map`` that adds a ``"turns"`` entry to every row. For gsm8k and
        humaneval this is a one-element list holding the formatted prompt;
        for mt-bench it is the row's existing ``"prompt"`` value.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    supported = ("gsm8k", "mt-bench", "humaneval")
    # Fail fast with a clear message; previously an unknown name fell through
    # every branch and crashed on `return dataset` with UnboundLocalError.
    if data_name not in supported:
        raise ValueError(
            f"Unknown dataset {data_name!r}; expected one of {supported}"
        )

    # Kept function-local (as before) so importing this module does not
    # itself require the `datasets` package.
    from datasets import load_dataset

    if data_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        prompt_fmt = (
            "{question}\nPlease reason step by step,"
            " and put your final answer within \\boxed{{}}."
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
    elif data_name == "mt-bench":
        dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        dataset = dataset.map(lambda x: {"turns": x["prompt"]})
    elif data_name == "humaneval":
        dataset = load_dataset("openai/openai_humaneval", split="test")
        prompt_fmt = (
            "Write a solution to the following problem and make sure"
            " that it passes the tests:\n```python\n{prompt}\n```"
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})

    return dataset
@pytest.fixture
def dflash_config():
    """Keyword arguments used by the DFlash e2e tests to construct an ``LLM``.

    Pairs the Qwen3-8B target with the ``z-lab/Qwen3-8B-DFlash-b16`` drafter
    and requests 16 speculative tokens. Stats logging is left enabled
    (``disable_log_stats=False``); the tests read metrics from the engine.
    """
    context_len = 32768
    speculative_config = {
        "method": "dflash",
        "model": "z-lab/Qwen3-8B-DFlash-b16",
        "num_speculative_tokens": 16,
        "max_model_len": context_len,
    }
    return {
        "model": "Qwen/Qwen3-8B",
        "trust_remote_code": True,
        "speculative_config": speculative_config,
        "max_model_len": context_len,
        "max_num_seqs": 128,
        "gpu_memory_utilization": 0.85,
        "enforce_eager": False,
        "disable_log_stats": False,
    }
def test_dflash_acceptance_rates(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.
    Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
    comparing against baseline results from the paper (Table 1).
    See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.

    Each prompt is generated on its own and its acceptance length is computed
    from the difference between the metrics snapshot taken after the request
    and the snapshot taken after the previous request.
    """
    spec_llm = LLM(**dflash_config)

    # mt-bench has 80, humaneval has 164, truncates gsm8k
    max_prompts_per_dataset = 200

    # All scores from Table 1 in https://arxiv.org/pdf/2602.06036
    expected_acceptance_lengths = {
        "mt-bench": 4.24,
        "humaneval": 6.50,
        # runs with a subset of prompts so extra wide tol here
        "gsm8k": 6.54 * 0.95,
    }

    tokenizer = spec_llm.get_tokenizer()

    # compute_acceptance_len() subtracts the previous snapshot from the
    # current one, i.e. the counters accumulate over the life of `spec_llm`.
    # The previous snapshot therefore has to carry over from one dataset to
    # the next: resetting it per dataset would make the first sample of every
    # later dataset reflect all earlier requests instead of a single request.
    prev_metrics = None

    for dataset_name, expected_len in expected_acceptance_lengths.items():
        dataset = load_and_process_dataset(dataset_name)
        acceptance_lengths = []

        for i in tqdm(
            range(min(max_prompts_per_dataset, len(dataset))),
            desc=f"Processing {dataset_name}",
        ):
            user_content = dataset[i]["turns"][0]
            prompt_text = tokenizer.apply_chat_template(
                [{"role": "user", "content": user_content}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )

            # Temp=0, MaxTokens=2048 from the paper
            spec_llm.generate(
                [prompt_text],
                SamplingParams(temperature=0, max_tokens=2048),
                use_tqdm=False,
            )

            current_metrics = spec_llm.get_metrics()
            acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
            prev_metrics = current_metrics
            acceptance_lengths.append(acceptance_len)

        mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)

        # Allow a 10% shortfall relative to the reference value.
        expected_len = expected_len * 0.9

        print(
            f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
            f" (expected at least {expected_len:.2f})"
        )

        assert mean_acceptance_length >= expected_len, (
            f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
            f" {mean_acceptance_length:.2f} < {expected_len:.2f}"
        )

    del spec_llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()
def test_dflash_correctness(dflash_config):
    """
    E2E correctness check for DFlash (block diffusion) speculative decoding.

    Runs the GSM8k evaluation with cudagraphs and batching enabled, then
    checks that the acceptance length over the whole run stays above a floor.
    """
    # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner
    # consistent with the DFlash paper. That test measures AL per request and
    # so runs with a batch size of 1. To make sure AL does not collapse with
    # large batch sizes, a baseline is enforced on the AL over the full
    # lm-eval-style GSM8k test.
    expected_len = 3.5  # Measured is 3.9 to 4.0

    llm = LLM(**dflash_config)

    # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
    evaluate_llm_for_gsm8k(llm, expected_accuracy_threshold=0.8)

    # No previous snapshot is passed, so this covers everything run so far.
    acceptance_len = compute_acceptance_len(llm.get_metrics())

    print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
    assert acceptance_len >= expected_len, (
        "DFlash correctness check failed with"
        f" {acceptance_len=}, expected at least {expected_len}"
    )

    del llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()
tests/v1/spec_decode/test_eagle.py
View file @
494636b2
...
...
@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.spec_decode.dflash
import
DFlashProposer
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
...
@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir
=
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir
=
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir
=
"amd/PARD-Llama-3.2-1B"
# Compatible with parallel and AR drafting
dflash_target_dir
=
"Qwen/Qwen3-8B"
dflash_dir
=
"z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE
=
16
...
...
@@ -47,18 +50,29 @@ def _create_proposer(
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
parallel_drafting
:
bool
=
False
,
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
# Method-dependent setup
if
method
==
"eagle"
:
target_model_dir
=
model_dir
draft_model_dir
=
eagle_dir
elif
method
==
"eagle3"
:
target_model_dir
=
model_dir
draft_model_dir
=
eagle3_dir
elif
method
==
"draft_model"
:
target_model_dir
=
model_dir
draft_model_dir
=
ar_draft_model_dir
elif
method
==
"dflash"
:
target_model_dir
=
dflash_target_dir
draft_model_dir
=
dflash_dir
else
:
raise
ValueError
(
f
"Unknown method:
{
method
}
"
)
model_config
=
ModelConfig
(
model
=
target_model_dir
,
runner
=
"generate"
,
max_model_len
=
100
,
trust_remote_code
=
(
method
==
"dflash"
),
)
spec_token_tree_str
=
None
if
speculative_token_tree
is
not
None
:
assert
num_speculative_tokens
==
len
(
speculative_token_tree
)
...
...
@@ -92,7 +106,9 @@ def _create_proposer(
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
)
if
"eagle"
in
method
:
if
method
==
"dflash"
:
proposer
=
DFlashProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
elif
"eagle"
in
method
:
proposer
=
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
else
:
proposer
=
DraftModelProposer
(
vllm_config
=
vllm_config
,
device
=
device
)
...
...
@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree):
# Verify that the draft tokens match our expectations.
assert
torch
.
equal
(
result
,
expected_tokens
)
def test_set_inputs_first_pass_dflash():
    """
    Check DFlashProposer.set_inputs_first_pass on a small three-request batch.

    DFlash uses cross-attention: context tokens become K/V and only the
    query tokens (bonus + mask) are Q. The properties verified here are:
        - Context hidden states are stored by reference (no copy)
        - Query input_ids are [next_token, mask, mask, ...] per request
        - Context and query positions are written to separate buffers
        - token_indices_to_sample points to mask token positions only
        - A new CommonAttentionMetadata is returned with causal=False

    Setup:
        - 3 requests with query_lens [3, 2, 4]
        - num_speculative_tokens = 3
        - num_query_per_req = 4 (1 bonus + 3 mask tokens)
        - next_token_ids: [100, 200, 300]

    Expected query layout (12 query tokens in total):
        Request 0 (indices 0-3):  [100, mask, mask, mask]
        Request 1 (indices 4-7):  [200, mask, mask, mask]
        Request 2 (indices 8-11): [300, mask, mask, mask]

    Expected positions (separate buffers):
        Context (_context_positions_buffer, 9 tokens): copy of target_positions
        Query (positions, 12 tokens):
            Request 0: last_pos=9,  query=[10, 11, 12, 13]
            Request 1: last_pos=7,  query=[8, 9, 10, 11]
            Request 2: last_pos=11, query=[12, 13, 14, 15]
    """
    device = torch.device(current_platform.device_type)

    num_speculative_tokens = 3
    proposer = _create_proposer("dflash", num_speculative_tokens)
    mask_token_id = proposer.parallel_drafting_token_id

    # Batch of 3 requests
    batch_spec = BatchSpec(seq_lens=[10, 8, 12], query_lens=[3, 2, 4])
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=BLOCK_SIZE,
        device=device,
        arange_block_indices=True,
    )

    # Inputs
    # Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
    # Request 1: tokens [20, 21] at positions [6, 7]
    # Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
    context_token_values = [10, 11, 12, 20, 21, 30, 31, 32, 33]
    context_position_values = [7, 8, 9, 6, 7, 8, 9, 10, 11]
    target_token_ids = torch.tensor(
        context_token_values, dtype=torch.int32, device=device
    )
    target_positions = torch.tensor(
        context_position_values, dtype=torch.int64, device=device
    )
    target_hidden_states = torch.randn(
        9, proposer.hidden_size, dtype=proposer.dtype, device=device
    )
    next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)

    num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
        target_token_ids=target_token_ids,
        next_token_ids=next_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        token_indices_to_sample=None,
        cad=common_attn_metadata,
        num_rejected_tokens_gpu=None,
    )

    num_query_per_req = 1 + num_speculative_tokens  # 4
    num_context = 9

    # num_tokens counts query tokens only
    assert num_tokens == 3 * num_query_per_req  # 12

    # input_ids (query tokens only): [next_token, mask, mask, mask] per request
    M = mask_token_id
    expected_input_ids = torch.tensor(
        [100, M, M, M, 200, M, M, M, 300, M, M, M],
        dtype=torch.int32,
        device=device,
    )
    assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)

    # Context positions live in their own buffer and mirror target_positions
    assert torch.equal(
        proposer._context_positions_buffer[:num_context], target_positions
    )

    # Query positions (separate buffer, starting at index 0):
    #   req0: last_pos=9,  query=[10, 11, 12, 13]
    #   req1: last_pos=7,  query=[8, 9, 10, 11]
    #   req2: last_pos=11, query=[12, 13, 14, 15]
    expected_query_positions = torch.tensor(
        [10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15],
        dtype=torch.int64,
        device=device,
    )
    assert torch.equal(
        proposer.positions[:num_tokens],
        expected_query_positions,
    )

    # token_indices_to_sample covers mask tokens only (bonus at offset 0 skipped)
    #   req0: query indices 0-3,  mask at 1,2,3
    #   req1: query indices 4-7,  mask at 5,6,7
    #   req2: query indices 8-11, mask at 9,10,11
    expected_token_indices_to_sample = torch.tensor(
        [1, 2, 3, 5, 6, 7, 9, 10, 11], dtype=torch.int32, device=device
    )
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

    # DFlash-specific properties of the returned CommonAttentionMetadata
    assert output_cad.causal is False  # DFlash requires non-causal attention
    assert output_cad.num_actual_tokens == num_tokens  # query-only count
    assert output_cad.max_query_len == num_query_per_req
    expected_query_start_loc = torch.tensor(
        [0, 4, 8, 12], dtype=torch.int32, device=device
    )
    assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)

    # Hidden states are kept by reference, not copied
    assert proposer._dflash_hidden_states is target_hidden_states
vllm/config/speculative.py
View file @
494636b2
...
...
@@ -47,8 +47,11 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp"
,
"step3p5_mtp"
,
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
]
NgramGPUTypes
=
Literal
[
"ngram_gpu"
]
DFlashModelTypes
=
Literal
[
"dflash"
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
"extract_hidden_states"
,
MTPModelTypes
,
DFlashModelTypes
]
SpeculativeMethod
=
Literal
[
"ngram"
,
"medusa"
,
...
...
@@ -206,7 +209,11 @@ class SpeculativeConfig:
factors
:
list
[
Any
]
=
[]
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states
=
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
)
uses_aux_hidden_states
=
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
,
"dflash"
,
)
factors
.
append
(
uses_aux_hidden_states
)
# The specific layers used also affect the computation graph
...
...
@@ -490,7 +497,7 @@ class SpeculativeConfig:
)
# Automatically detect the method
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
if
self
.
method
in
(
"eagle"
,
"eagle3"
,
"dflash"
):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
...
...
@@ -500,6 +507,8 @@ class SpeculativeConfig:
self
.
method
=
"eagle"
elif
"eagle3"
in
self
.
draft_model_config
.
model
.
lower
():
self
.
method
=
"eagle3"
elif
"dflash"
in
self
.
draft_model_config
.
model
.
lower
():
self
.
method
=
"dflash"
elif
self
.
draft_model_config
.
hf_config
.
model_type
==
"medusa"
:
self
.
method
=
"medusa"
elif
self
.
draft_model_config
.
hf_config
.
model_type
==
"mlp_speculator"
:
...
...
@@ -532,7 +541,7 @@ class SpeculativeConfig:
)
# Replace hf_config for EAGLE draft_model
if
self
.
method
in
(
"eagle"
,
"eagle3"
):
if
self
.
method
in
(
"eagle"
,
"eagle3"
,
"dflash"
):
from
vllm.transformers_utils.configs.eagle
import
EAGLEConfig
from
vllm.transformers_utils.configs.speculators
import
(
SpeculatorsConfig
,
...
...
@@ -552,6 +561,9 @@ class SpeculativeConfig:
self
.
draft_model_config
.
hf_config
=
eagle_config
self
.
update_arch_
()
if
self
.
method
==
"dflash"
:
self
.
parallel_drafting
=
True
if
self
.
num_speculative_tokens
is
not
None
and
hasattr
(
self
.
draft_model_config
.
hf_config
,
"num_lookahead_tokens"
):
...
...
@@ -807,7 +819,7 @@ class SpeculativeConfig:
"kimi_k25"
,
]
if
(
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
)
self
.
method
in
(
"eagle3"
,
"extract_hidden_states"
,
"dflash"
)
and
self
.
target_model_config
and
not
any
(
supported_model
in
self
.
target_model_config
.
hf_text_config
.
model_type
...
...
@@ -855,7 +867,10 @@ class SpeculativeConfig:
return
slots_per_req
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
,
"dflash"
)
def
use_dflash
(
self
)
->
bool
:
return
self
.
method
==
"dflash"
def
uses_draft_model
(
self
)
->
bool
:
return
self
.
method
==
"draft_model"
...
...
vllm/config/vllm.py
View file @
494636b2
...
...
@@ -1327,6 +1327,26 @@ class VllmConfig:
max_num_batched_tokens
-
scheduled_token_delta
)
if
self
.
scheduler_config
.
max_num_scheduled_tokens
<=
0
:
raise
ValueError
(
"max_num_scheduled_tokens is set to"
f
"
{
self
.
scheduler_config
.
max_num_scheduled_tokens
}
based on"
" the speculative decoding settings, which does not allow"
" any tokens to be scheduled. Increase max_num_batched_tokens"
" to accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
)
if
self
.
scheduler_config
.
max_num_scheduled_tokens
<
8192
:
logger
.
warning_once
(
"max_num_scheduled_tokens is set to"
f
"
{
self
.
scheduler_config
.
max_num_scheduled_tokens
}
based on"
" the speculative decoding settings. This may lead to suboptimal"
" performance. Consider increasing max_num_batched_tokens to"
" accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
,
scope
=
"local"
,
)
max_num_scheduled_tokens
=
self
.
scheduler_config
.
max_num_scheduled_tokens
if
max_num_batched_tokens
<
max_num_scheduled_tokens
+
(
self
.
speculative_config
.
max_num_new_slots_for_drafting
...
...
vllm/model_executor/models/qwen3.py
View file @
494636b2
...
...
@@ -285,6 +285,7 @@ class Qwen3ForCausalLM(
self
.
config
=
config
self
.
vllm_config
=
vllm_config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen3Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
...
...
vllm/model_executor/models/qwen3_dflash.py
0 → 100644
View file @
494636b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
transformers
import
Qwen3Config
from
vllm
import
_custom_ops
as
ops
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.transformers_utils.config
import
set_default_rope_theta
from
vllm.v1.attention.backend
import
AttentionType
from
.qwen2
import
Qwen2MLP
as
Qwen3MLP
from
.qwen3
import
Qwen3ForCausalLM
from
.utils
import
(
AutoWeightsLoader
,
get_draft_quant_config
,
maybe_prefix
,
process_eagle_weight
,
)
logger
=
init_logger
(
__name__
)
class DFlashQwen3Attention(nn.Module):
    """Attention layer used by the DFlash drafter (adapted from Qwen3Attention).

    The context K/V entries are expected to be in the KV cache before this
    layer runs; the layer itself only processes the query tokens through the
    regular attention path.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        attention_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the projections, rotary embedding, attention op and Q/K norms.

        ``num_heads`` and ``num_kv_heads`` are totals across tensor-parallel
        ranks; the per-rank counts are derived from the TP world size.
        ``head_dim`` falls back to ``hidden_size // num_heads`` when it is
        ``None`` (or otherwise falsy).
        """
        super().__init__()
        self.layer_name = prefix
        self.hidden_size = hidden_size

        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # KV heads are either partitioned across ranks, or (when there are
        # fewer KV heads than ranks) each rank keeps at least one.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=attention_bias,  # DFlash has o_proj bias when using attention bias
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=attn_type,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for the query tokens only.

        Assumes the KV cache already holds the context K/V derived from the
        target model's hidden states.
        See also: precompute_and_store_context_kv
        """
        # The projection is applied with F.linear on the qkv_proj weight and
        # bias directly, rather than by calling the qkv_proj module.
        fused_qkv = F.linear(hidden_states, self.qkv_proj.weight, self.qkv_proj.bias)
        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Per-head RMSNorm: expose a trailing (heads, head_dim) layout, apply
        # the norm, then restore the original flat shape.
        q_shape = q.shape
        k_shape = k.shape
        q_per_head = q.view(*q_shape[:-1], q_shape[-1] // self.head_dim, self.head_dim)
        q = self.q_norm(q_per_head).view(q_shape)
        k_per_head = k.view(*k_shape[:-1], k_shape[-1] // self.head_dim, self.head_dim)
        k = self.k_norm(k_per_head).view(k_shape)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        # o_proj returns a pair; only its first element is used here.
        output, _ = self.o_proj(attn_output)
        return output
class DFlashQwen3DecoderLayer(nn.Module):
    """One DFlash drafter block: pre-norm attention followed by a pre-norm MLP."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        config: Qwen3Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the attention, MLP and the two RMSNorm sub-modules.

        ``vllm_config`` is accepted but not read in this constructor; all
        sizes and hyper-parameters come from ``config``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # Called before the attention layer reads config.rope_parameters.
        set_default_rope_theta(config, default_theta=1000000)

        self.self_attn = DFlashQwen3Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rms_norm_eps=config.rms_norm_eps,
            attention_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_parameters=config.rope_parameters,
            prefix=f"{prefix}.self_attn",
            attn_type=AttentionType.DECODER,
        )
        self.mlp = Qwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual and are normalised on their own; otherwise the input norm is
        called with both tensors and returns the updated pair.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
@support_torch_compile
class DFlashQwen3Model(nn.Module):
    """Draft-model trunk for DFlash: embeddings, decoder layers and norms.

    Besides the regular ``forward`` over query tokens, the class can project
    context hidden states to K/V for all layers at once and write them into
    each layer's KV cache (``precompute_and_store_context_kv``).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        """Create the sub-modules from the draft model's HF config.

        Args:
            vllm_config: Its ``speculative_config.draft_model_config.hf_config``
                is used as ``self.config``.
            start_layer_id: Offset added to each layer index when building the
                layer prefix.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size
        self.quant_config = get_draft_quant_config(vllm_config)

        # Merge into a *fresh* dict: calling .update() on the object returned
        # by getattr would write dflash_config keys into the shared HF
        # config's eagle_config. ``or {}`` also tolerates attributes that are
        # present but set to None.
        drafter_config = dict(getattr(self.config, "eagle_config", None) or {})
        drafter_config.update(getattr(self.config, "dflash_config", None) or {})

        self.use_aux_hidden_state = drafter_config.get("use_aux_hidden_state", True)

        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList(
            [
                DFlashQwen3DecoderLayer(
                    current_vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                    config=self.config,
                )
                for layer_idx in range(self.config.num_hidden_layers)
            ]
        )
        if self.use_aux_hidden_state:
            # Number of concatenated feature vectors fed to ``fc``: taken from
            # the configured layer-id list when present, otherwise from the
            # draft model's layer count.
            num_features_to_use = self.config.num_hidden_layers
            if "target_layer_ids" in drafter_config:
                num_features_to_use = len(drafter_config["target_layer_ids"])
            elif "layer_ids" in drafter_config:
                num_features_to_use = len(drafter_config["layer_ids"])
            if hasattr(self.config, "target_hidden_size"):
                fc_input_size = self.config.target_hidden_size * num_features_to_use
            else:
                fc_input_size = self.config.hidden_size * num_features_to_use
            self.fc = ReplicatedLinear(
                input_size=fc_input_size,
                output_size=self.config.hidden_size,
                bias=False,
                params_dtype=vllm_config.model_config.dtype,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "fc"),
                return_bias=False,
            )
        self.hidden_norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def _build_fused_kv_buffers(self) -> None:
        """Build fused weight buffers for precompute_and_store_context_kv.

        Must be called after weights are loaded. Stacks the KV-projection
        weights, K-norm weights, and RoPE parameters from every attention
        layer so that precompute_and_store_context_kv can run one fused
        GEMM for all layers at once. Also aliases the weight of the hidden_norm.

        Raises:
            AssertionError: If layers disagree on the RoPE or attention
                parameters compared below.
        """
        layers_attn = [layer.self_attn for layer in self.layers]
        attn0 = layers_attn[0]
        has_bias = attn0.qkv_proj.bias is not None

        self._hidden_norm_weight = self.hidden_norm.weight.data

        # KV projection weights: [num_layers * 2 * kv_size, hidden_size]
        # (rows after the first q_size rows of each layer's qkv_proj weight).
        kv_weights = [a.qkv_proj.weight[a.q_size :] for a in layers_attn]
        self._fused_kv_weight = torch.cat(kv_weights, dim=0)
        if has_bias:
            kv_biases = [a.qkv_proj.bias[a.q_size :] for a in layers_attn]
            self._fused_kv_bias: torch.Tensor | None = torch.cat(kv_biases, dim=0)
        else:
            self._fused_kv_bias = None

        # K-norm weights: list of [head_dim] tensors, one per layer.
        self._k_norm_weights = [a.k_norm.weight.data for a in layers_attn]

        # RoPE parameters (taken from the first layer)
        self._rope_head_size = attn0.rotary_emb.head_size
        self._rope_cos_sin_cache = attn0.rotary_emb.cos_sin_cache
        self._rope_is_neox = attn0.rotary_emb.is_neox_style
        # Validation that RoPE params are the same across all layers
        for attn in layers_attn[1:]:
            assert (
                attn.rotary_emb.head_size == self._rope_head_size
                and attn.rotary_emb.is_neox_style == self._rope_is_neox
            ), "All layers must have the same RoPE parameters for DFlash precomputation"

        # Layer metadata. ``_num_attn_layers`` doubles as the "buffers built"
        # marker checked by precompute_and_store_context_kv.
        self._num_attn_layers = len(layers_attn)
        self._kv_size = attn0.kv_size
        self._head_dim = attn0.head_dim
        self._num_kv_heads = attn0.num_kv_heads
        self._rms_norm_eps = attn0.q_norm.variance_epsilon
        # Validation that all layers have the same attention config
        for attn in layers_attn[1:]:
            assert (
                attn.kv_size == self._kv_size
                and attn.head_dim == self._head_dim
                and attn.num_kv_heads == self._num_kv_heads
                and attn.q_norm.variance_epsilon == self._rms_norm_eps
            ), "All layers must have the same attn config for DFlash precomputation"

        # References to inner Attention layers for direct cache writes
        self._attn_layers = [layer.self_attn.attn for layer in self.layers]

    def precompute_and_store_context_kv(
        self,
        context_states: torch.Tensor,
        context_positions: torch.Tensor,
        context_slot_mapping: torch.Tensor | None = None,
    ) -> None:
        """Precompute K/V for context states and write them into each layer's KV cache.

        Input context states are projected to K/V, normed, and have RoPE applied.
        Since the context shape is different than the query shape, we can't rely on the
        regular forward pass to apply torch.compile and CUDA graphs to this section.
        As such, this function is optimized to minimize the number of torch ops present:
        we use fused vLLM kernels for RMSNorm and RoPE, fuse the GEMM into one
        large projection, and avoid cloning buffers (with .contiguous()) where possible.

        When context_slot_mapping is None (e.g. during dummy_run) only
        the computation runs, and no K/V is written to cache.

        Args:
            context_states: Context hidden states; dim 0 is the context length.
            context_positions: Positions for the context states, repeated once
                per layer before RoPE.
            context_slot_mapping: Slot mapping passed to each layer's cache
                update, or None to skip the cache write.
        """
        if not hasattr(self, "_num_attn_layers"):
            logger.warning_once(
                "DFlash buffer initialization was skipped. If dummy weights are not "
                "in use, this may indicate an error in weight loading."
            )
            self._build_fused_kv_buffers()

        num_ctx = context_states.shape[0]
        L = self._num_attn_layers
        kv = self._kv_size
        hd = self._head_dim
        nkv = self._num_kv_heads

        # --- Fused KV projection (one GEMM for all layers) ---
        normed_context_states = torch.empty_like(context_states)
        ops.rms_norm(
            normed_context_states,
            context_states,
            self._hidden_norm_weight,
            self._rms_norm_eps,
        )
        all_kv_flat = F.linear(
            normed_context_states, self._fused_kv_weight, self._fused_kv_bias
        )

        # Single contiguous copy that separates K/V and transposes to
        # layer-major layout. Result: [2, L, num_ctx, nkv, hd] contiguous.
        # Indexing dim-0 gives contiguous [L, num_ctx, nkv, hd] for K and V.
        all_kv = (
            all_kv_flat.view(num_ctx, L, 2, nkv, hd).permute(2, 1, 0, 3, 4).contiguous()
        )
        all_k = all_kv[0]  # [L, num_ctx, nkv, hd], contiguous
        all_v = all_kv[1]  # [L, num_ctx, nkv, hd], contiguous

        # --- Per-layer RMSNorm K (3D: [num_ctx, nkv, hd] per layer) ---
        all_k_normed = torch.empty_like(all_k)
        for i in range(L):
            ops.rms_norm(
                all_k_normed[i],
                all_k[i],
                self._k_norm_weights[i],
                self._rms_norm_eps,
            )

        # --- Fused RoPE across all layers ---
        # View as [L * num_ctx, kv] so RoPE sees one big batch (no copy).
        # In-place RoPE: pass K as the "query" arg with key=None.
        all_k_flat = all_k_normed.view(L * num_ctx, kv)
        positions_repeated = context_positions.repeat(L)
        cos_sin_cache = self._rope_cos_sin_cache
        if cos_sin_cache.dtype != all_k_flat.dtype:
            cos_sin_cache = cos_sin_cache.to(dtype=all_k_flat.dtype)
        ops.rotary_embedding(
            positions_repeated,
            all_k_flat,
            None,
            self._rope_head_size,
            cos_sin_cache,
            self._rope_is_neox,
        )

        if context_slot_mapping is None:
            return

        # --- Per-layer cache insert ---
        all_k_final = all_k_flat.view(L, num_ctx, nkv, hd)
        for i in range(L):
            attn = self._attn_layers[i]
            kv_cache = attn.kv_cache
            attn.impl.do_kv_cache_update(
                attn,
                all_k_final[i],
                all_v[i],
                kv_cache,
                context_slot_mapping,
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run all decoder layers and the final norm.

        ``input_ids`` are embedded only when ``input_embeds`` is None.
        """
        if input_embeds is None:
            input_embeds = self.embed_input_ids(input_ids)

        hidden_states = input_embeds

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        stacked ``qkv_proj`` / ``gate_up_proj`` parameters; ``midlayer.`` names
        are remapped to ``layers.0.``.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "midlayer." in name:
                name = name.replace("midlayer.", "layers.0.")
            # KV-cache scales reported by the quant config are loaded directly.
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load under its own name.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class DFlashQwen3ForCausalLM(Qwen3ForCausalLM):
    """DFlash draft model: ``DFlashQwen3Model`` trunk plus LM head and logits."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build trunk, LM head and logits processor from the draft HF config."""
        # Deliberately skips Qwen3ForCausalLM.__init__ and only initializes
        # the nn.Module machinery.
        nn.Module.__init__(self)
        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        # Fall back to the regular vocab size when no draft vocab is configured.
        if getattr(hf_config, "draft_vocab_size", None) is None:
            hf_config.draft_vocab_size = getattr(hf_config, "vocab_size", None)

        num_target_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        hf_config.target_layer_count = num_target_layers

        self.model = DFlashQwen3Model(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=num_target_layers,
        )

        self.lm_head = ParallelLMHead(
            hf_config.draft_vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            hf_config.draft_vocab_size,
            scale=getattr(hf_config, "logit_scale", 1.0),
        )
        self.draft_id_to_target_id = None

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` via the trunk; the multimodal arguments are unused."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate to the trunk, passing the three arguments positionally."""
        return self.model(input_ids, positions, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits, scattering them to target ids when a mapping is set.

        With ``draft_id_to_target_id`` set, the result has
        ``config.vocab_size`` columns filled with ``-inf`` except at
        ``arange(draft_vocab_size) + draft_id_to_target_id``.
        """
        draft_logits = self.logits_processor(self.lm_head, hidden_states)
        offsets = self.draft_id_to_target_id
        if offsets is not None:
            draft_ids = torch.arange(
                self.config.draft_vocab_size, device=draft_logits.device
            )
            target_ids = draft_ids + offsets
            full_logits = draft_logits.new_full(
                (draft_logits.shape[0], self.config.vocab_size),
                float("-inf"),
            )
            full_logits[:, target_ids] = draft_logits
            return full_logits
        return draft_logits

    def precompute_and_store_context_kv(
        self,
        context_states: torch.Tensor,
        context_positions: torch.Tensor,
        context_slot_mapping: torch.Tensor | None = None,
    ) -> None:
        """Forward to the trunk's context K/V precomputation and cache write."""
        self.model.precompute_and_store_context_kv(
            context_states, context_positions, context_slot_mapping
        )

    def combine_hidden_states(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the trunk's ``fc`` projection when aux hidden states are used.

        A 1-D input is given a leading dimension for the projection and has
        it removed again afterwards.
        """
        if not self.model.use_aux_hidden_state:
            return hidden_states

        if hidden_states.dim() == 1:
            return self.model.fc(hidden_states.unsqueeze(0)).squeeze(0)
        return self.model.fc(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Rename checkpoint entries, load them, then build the fused KV buffers.

        ``t2d`` entries are dropped, ``d2t`` becomes ``draft_id_to_target_id``,
        and every other name without ``lm_head`` gets a ``model.`` prefix.
        Parameter groups absent from the checkpoint are added to the loader's
        skip list.
        """
        renamed: dict[str, torch.Tensor] = {}
        saw_d2t = False
        saw_embed_tokens = False
        for raw_name, tensor in weights:
            assert "mask_hidden" not in raw_name, (
                "DFlash should use mask_token_id to embed the padding hidden state"
            )
            if "t2d" in raw_name:
                continue
            if "d2t" in raw_name:
                key = raw_name.replace("d2t", "draft_id_to_target_id")
                saw_d2t = True
            elif "lm_head" in raw_name:
                key = raw_name
            else:
                key = "model." + raw_name
            if "embed_tokens" in key:
                saw_embed_tokens = True
            renamed[key] = tensor
            process_eagle_weight(self, key)

        skip_candidates = (
            (not saw_d2t, "draft_id_to_target_id"),
            (not saw_embed_tokens, "embed_tokens"),
            (not self.model.use_aux_hidden_state, "fc."),
        )
        skip_substrs = [substr for skip, substr in skip_candidates if skip]

        AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        ).load_weights(renamed.items())
        self.model._build_fused_kv_buffers()
vllm/model_executor/models/qwen3_next.py
View file @
494636b2
...
...
@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.qwen3_next
import
Qwen3NextConfig
from
.interfaces
import
(
EagleModelMixin
,
HasInnerState
,
IsHybrid
,
MixtureOfExperts
,
...
...
@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module):
@
support_torch_compile
class
Qwen3NextModel
(
nn
.
Module
):
class
Qwen3NextModel
(
nn
.
Module
,
EagleModelMixin
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module):
else
:
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
,
...]
=
()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return ``self.embed_tokens`` applied to ``input_ids``."""
    return self.embed_tokens(input_ids)
...
...
@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
aux_hidden_states
=
[]
aux_hidden_states
=
self
.
_maybe_add_hidden_state
([],
0
,
hidden_states
,
residual
)
for
layer_idx
,
layer
in
enumerate
(
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
),
start
=
self
.
start_layer
,
):
if
layer_idx
in
self
.
aux_hidden_state_layers
:
aux_hidden_states
.
append
(
hidden_states
+
residual
if
residual
is
not
None
else
hidden_states
)
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
)
self
.
_maybe_add_hidden_state
(
aux_hidden_states
,
layer_idx
+
1
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
...
...
vllm/model_executor/models/registry.py
View file @
494636b2
...
...
@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlamaForCausalLM"
:
(
"llama_eagle"
,
"EagleLlamaForCausalLM"
),
"EagleLlama4ForCausalLM"
:
(
"llama4_eagle"
,
"EagleLlama4ForCausalLM"
),
"EagleMiniCPMForCausalLM"
:
(
"minicpm_eagle"
,
"EagleMiniCPMForCausalLM"
),
"DFlashDraftModel"
:
(
"qwen3_dflash"
,
"DFlashQwen3ForCausalLM"
),
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"LlamaForCausalLMEagle3"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"Eagle3Qwen2_5vlForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
...
...
vllm/transformers_utils/configs/eagle.py
View file @
494636b2
...
...
@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig):
else
f
"Eagle3
{
arch
}
"
for
arch
in
self
.
model
.
architectures
]
elif
method
==
"dflash"
:
assert
self
.
model
is
not
None
,
(
"model should not be None when method is dflash"
)
kwargs
[
"architectures"
]
=
[
arch
if
arch
.
startswith
(
"DFlash"
)
or
arch
.
endswith
(
"DFlash"
)
else
f
"DFlash
{
arch
}
"
for
arch
in
self
.
model
.
architectures
]
else
:
raise
ValueError
(
f
"Invalid method
{
method
}
. Supported methods are eagle and eagle3."
f
"Invalid method
{
method
}
. Supported methods are "
"eagle, eagle3, and dflash."
)
super
().
__init__
(
**
kwargs
)
...
...
vllm/v1/attention/backend.py
View file @
494636b2
...
...
@@ -220,6 +220,17 @@ class AttentionBackend(ABC):
def
supports_per_head_quant_scales
(
cls
)
->
bool
:
return
False
@classmethod
def supports_non_causal(cls) -> bool:
    """Check if backend supports non-causal (bidirectional) attention
    for decoder models.

    Unlike ENCODER_ONLY attention type which implies a different
    execution model, this refers to non-causal attention within the
    standard paged-KV-cache decoder path.

    Returns:
        False in this base implementation; a backend opts in by
        overriding the method to return True.
    """
    return False
@
classmethod
def
supports_attn_type
(
cls
,
attn_type
:
str
)
->
bool
:
"""Check if backend supports a given attention type.
...
...
@@ -261,6 +272,7 @@ class AttentionBackend(ABC):
use_per_head_quant_scales
:
bool
,
device_capability
:
"DeviceCapability"
,
attn_type
:
str
,
use_non_causal
:
bool
=
False
,
)
->
list
[
str
]:
invalid_reasons
=
[]
if
not
cls
.
supports_head_size
(
head_size
):
...
...
@@ -293,6 +305,8 @@ class AttentionBackend(ABC):
invalid_reasons
.
append
(
"compute capability not supported"
)
if
not
cls
.
supports_attn_type
(
attn_type
):
invalid_reasons
.
append
(
f
"attention type
{
attn_type
}
not supported"
)
if
use_non_causal
and
not
cls
.
supports_non_causal
():
invalid_reasons
.
append
(
"non-causal attention not supported"
)
combination_reason
=
cls
.
supports_combination
(
head_size
,
dtype
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
494636b2
...
...
@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_name
()
->
str
:
return
"FLASH_ATTN"
@classmethod
def supports_non_causal(cls) -> bool:
    """Report that this backend supports non-causal attention (always True)."""
    return True
@
classmethod
def
supports_attn_type
(
cls
,
attn_type
:
str
)
->
bool
:
"""FlashAttention supports all attention types."""
...
...
vllm/v1/attention/selector.py
View file @
494636b2
...
...
@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple):
use_mm_prefix
:
bool
=
False
use_per_head_quant_scales
:
bool
=
False
attn_type
:
str
=
AttentionType
.
DECODER
use_non_causal
:
bool
=
False
def
__repr__
(
self
):
return
(
...
...
@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple):
f
"use_sparse=
{
self
.
use_sparse
}
, "
f
"use_mm_prefix=
{
self
.
use_mm_prefix
}
, "
f
"use_per_head_quant_scales=
{
self
.
use_per_head_quant_scales
}
, "
f
"attn_type=
{
self
.
attn_type
}
)"
f
"attn_type=
{
self
.
attn_type
}
, "
f
"use_non_causal=
{
self
.
use_non_causal
}
)"
)
...
...
@@ -76,6 +78,11 @@ def get_attn_backend(
else
:
block_size
=
None
speculative_config
=
vllm_config
.
speculative_config
use_non_causal
=
(
speculative_config
is
not
None
and
speculative_config
.
method
==
"dflash"
)
attn_selector_config
=
AttentionSelectorConfig
(
head_size
=
head_size
,
dtype
=
dtype
,
...
...
@@ -87,6 +94,7 @@ def get_attn_backend(
use_mm_prefix
=
use_mm_prefix
,
use_per_head_quant_scales
=
use_per_head_quant_scales
,
attn_type
=
attn_type
or
AttentionType
.
DECODER
,
use_non_causal
=
use_non_causal
,
)
return
_cached_get_attn_backend
(
...
...
vllm/v1/spec_decode/dflash.py
0 → 100644
View file @
494636b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
typing_extensions
import
override
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
triton
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
from
vllm.v1.spec_decode.eagle
import
SpecDecodeBaseProposer
from
vllm.v1.spec_decode.utils
import
copy_and_expand_dflash_inputs_kernel
logger
=
init_logger
(
__name__
)
class DFlashProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer for the "dflash" method.

    Context (target) hidden states are turned into K/V and written to the
    draft KV cache outside the model forward, while the forward itself only
    processes the query tokens (one next-token slot plus
    ``num_speculative_tokens`` mask-token slots per request).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Initialize the base proposer, then replace its buffers with DFlash-sized ones.

        Args:
            vllm_config: Must carry a speculative config whose method is "dflash".
            device: Device on which the buffers are allocated.
            runner: Forwarded unchanged to the base class.
        """
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

        # Only next_token_ids and mask tokens are query tokens, all other context is K/V
        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        # Positions covers both context states + query states
        self.max_positions = self.max_num_tokens + self.max_query_tokens

        # Separate context buffers to keep query buffer addresses stable for CUDA graphs
        self._context_slot_mapping_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        # Query-side slot mapping, sized for query tokens only.
        self._slot_mapping_buffer = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._context_positions_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        # Query-side positions, sized for query tokens only.
        self.positions = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )

        # One extra element because slices of this arange are used as
        # query_start_loc, which has batch_size + 1 entries.
        self.arange = torch.arange(
            self.max_positions + 1, device=device, dtype=torch.int32
        )

        # For DFlash we use the input embeddings to embed the mask token
        self.parallel_drafting_hidden_state_tensor = None

    @override
    def _raise_if_multimodal(self):
        """No-op override so that multimodal inputs are not rejected."""
        # Override to allow multimodal inputs since DFlash supports Qwen3.5 models
        # Support for multimodal inputs has not been tested.
        pass

    @override
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill the query/context buffers and build query-only attention metadata.

        Args:
            target_token_ids: Only its first dimension is read (context length).
            next_token_ids: Passed to the input-expansion kernel.
            target_positions: Passed to the input-expansion kernel.
            target_hidden_states: Stored for the later context K/V precompute.
            token_indices_to_sample: Ignored; a fresh tensor is allocated and
                filled by the kernel.
            cad: Attention metadata of the target pass.
            num_rejected_tokens_gpu: Per-request rejected-token counts, or None.

        Returns:
            ``(num_query_total, token_indices_to_sample, new_cad)`` where
            ``new_cad`` describes the query tokens with ``causal=False``.
        """
        # DFlash cross-attention: context K/V from target hidden states,
        # Q from query embeddings (bonus + mask tokens).
        batch_size = cad.batch_size()
        num_context = target_token_ids.shape[0]
        num_query_per_req = 1 + self.num_speculative_tokens
        num_query_total = batch_size * num_query_per_req

        # Store for build_model_inputs_first_pass to use
        self._dflash_num_context = num_context

        # We don't need to copy into a buffer here since the context preprocessing
        # does not run in a CUDA graph
        self._dflash_hidden_states = target_hidden_states

        token_indices_to_sample = torch.empty(
            batch_size * self.num_speculative_tokens,
            dtype=torch.int32,
            device=self.device,
        )

        # Launch fused triton kernel for input_ids, positions, slot_mapping,
        # and token_indices_to_sample
        max_ctx_per_req = cad.max_query_len
        max_tokens_per_req = max_ctx_per_req + num_query_per_req
        BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req))
        num_blocks = triton.cdiv(max_tokens_per_req, BLOCK_SIZE)
        # One program per (request, block of that request's tokens).
        grid = (batch_size, num_blocks)

        has_num_rejected = num_rejected_tokens_gpu is not None
        copy_and_expand_dflash_inputs_kernel[grid](
            # Inputs
            next_token_ids_ptr=next_token_ids,
            target_positions_ptr=target_positions,
            # Outputs
            out_input_ids_ptr=self.input_ids,
            out_context_positions_ptr=self._context_positions_buffer,
            out_query_positions_ptr=self.positions,
            out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
            out_query_slot_mapping_ptr=self._slot_mapping_buffer,
            out_token_indices_ptr=token_indices_to_sample,
            # Block table
            block_table_ptr=cad.block_table_tensor,
            block_table_stride=cad.block_table_tensor.stride(0),
            # Metadata
            query_start_loc_ptr=cad.query_start_loc,
            # A placeholder 0 is passed when there is no rejected-token tensor;
            # HAS_NUM_REJECTED tells the kernel which case applies.
            num_rejected_tokens_ptr=(
                num_rejected_tokens_gpu if has_num_rejected else 0
            ),
            # Scalars
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            block_size=self.block_size,
            num_query_per_req=num_query_per_req,
            num_speculative_tokens=self.num_speculative_tokens,
            total_input_tokens=num_context,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_NUM_REJECTED=has_num_rejected,
        )

        query_slot_mapping = self._slot_mapping_buffer[:num_query_total]
        # Every request contributes exactly num_query_per_req query tokens.
        new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req

        # In padded mode, cad.seq_lens includes rejected tokens. Subtract
        # them so attention only sees the valid prefix of context states.
        effective_seq_lens = cad.seq_lens
        if has_num_rejected:
            effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu

        new_cad = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc,
            seq_lens=effective_seq_lens + num_query_per_req,
            query_start_loc_cpu=(
                torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone()
                * num_query_per_req
            ),
            _seq_lens_cpu=None,
            _num_computed_tokens_cpu=None,
            num_reqs=cad.num_reqs,
            num_actual_tokens=num_query_total,
            max_query_len=num_query_per_req,
            max_seq_len=cad.max_seq_len + num_query_per_req,
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=query_slot_mapping,
            causal=False,  # Non-causal attention is required for DFlash
        )

        return num_query_total, token_indices_to_sample, new_cad

    @override
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """
        Key differences to default dummy_run:
        - Only one forward pass due to parallel drafting
        - DFlash uses context states as unpadded metadata, so hidden_states will
        use the unpadded num_tokens instead of num_input_tokens
        - max_query_tokens is quite small, DFlash only sees spec tokens as queries
        - Multimodal inputs are not currently supported

        Note: ``is_graph_capturing`` is accepted but not read in this body.
        """
        # The model forward only ever sees query tokens, so cap accordingly.
        num_query_tokens = min(num_tokens, self.max_query_tokens)
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_query_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        # Slot mapping sized to num_input_tokens (query only), matching
        # the K/V tensor size from the model forward. Context KVs are
        # pre-inserted separately and don't flow through the model.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        # Context and query positions use separate buffers; no copy needed.
        context_positions = self._context_positions_buffer[:num_tokens]
        # Context states will be passed directly to the precomputation without
        # going through the buffer, since no CUDA graph is used for the precomputation.
        # For the dummy run, we use the dummy buffer.
        context_states = self.hidden_states[:num_tokens]

        # Run the KV projection (GEMM + norms + RoPE) for memory profiling.
        # No slot mapping is passed, so nothing is written to the KV cache.
        self.model.precompute_and_store_context_kv(context_states, context_positions)
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            )

    @override
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Write context K/V to the cache and return the query-only model inputs.

        Relies on ``_dflash_num_context`` and ``_dflash_hidden_states`` having
        been stored by ``set_inputs_first_pass``. ``num_tokens`` and
        ``mm_embed_inputs`` are not read in this body.

        Returns:
            ``(model_kwargs, num_input_tokens)``.
        """
        # Context and query positions/slots were written to separate
        # buffers by the kernel — no copy needed.
        num_context = self._dflash_num_context

        # Pre-insert context KVs directly into cache
        self.model.precompute_and_store_context_kv(
            self._dflash_hidden_states,  # Shape is already [num_context, hidden_size]
            self._context_positions_buffer[:num_context],
            self._context_slot_mapping_buffer[:num_context],
        )
        return (
            dict(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            ),
            num_input_tokens,
        )

    @override
    def build_per_layer_attn_metadata(
        self, cad: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Build per-layer metadata via the base class and require ``causal is False``.

        Raises:
            AssertionError: If any layer's metadata lacks a ``causal``
                attribute that is exactly False.
        """
        per_layer_attention_metadata = super().build_per_layer_attn_metadata(
            cad, draft_index
        )
        for layer_name, attn_metadata in per_layer_attention_metadata.items():
            assert getattr(attn_metadata, "causal", None) is False, (
                f"Attention metadata for layer {layer_name} does not have"
                " non-causal support, which is required for DFlash."
                " Consider using a different attention backend, such as FlashAttention."
            )
        return per_layer_attention_metadata

    @override
    def _get_eagle3_use_aux_hidden_state_from_config(self):
        """Return ``dflash_config["use_aux_hidden_state"]``, defaulting to True.

        True is also returned when the draft HF config has no ``dflash_config``.
        """
        use_aux_hidden_state = True
        dflash_config = getattr(
            self.draft_model_config.hf_config, "dflash_config", None
        )
        if dflash_config is not None:
            use_aux_hidden_state = dflash_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state
vllm/v1/spec_decode/eagle.py
View file @
494636b2
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
from
importlib.util
import
find_spec
from
typing
import
cast
from
typing
import
Any
,
cast
import
numpy
as
np
import
torch
...
...
@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal
from
vllm.model_executor.models.deepseek_eagle3
import
Eagle3DeepseekV2ForCausalLM
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models.qwen3_dflash
import
DFlashQwen3ForCausalLM
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
...
...
@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer:
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
# Unifying eagle, draft model, and parallel drafting support
# Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE)
self
.
parallel_drafting
:
bool
=
self
.
speculative_config
.
parallel_drafting
self
.
extra_slots_per_request
=
(
1
if
not
self
.
parallel_drafting
else
self
.
num_speculative_tokens
)
self
.
net_num_new_slots_per_request
=
self
.
extra_slots_per_request
-
(
1
if
self
.
pass_hidden_states_to_model
else
0
1
if
(
self
.
pass_hidden_states_to_model
and
self
.
method
!=
"dflash"
)
else
0
)
self
.
needs_extra_input_slots
=
self
.
net_num_new_slots_per_request
>
0
...
...
@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer:
self
.
speculative_config
.
use_local_argmax_reduction
)
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# Can be specialized by methods like DFlash to reduce the limit
self
.
max_query_tokens
=
self
.
max_num_tokens
self
.
max_positions
=
self
.
max_num_tokens
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
supports_mm_inputs
=
self
.
mm_registry
.
supports_multimodal_inputs
(
...
...
@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer:
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self
.
mrope_positions
=
torch
.
zeros
(
(
3
,
self
.
max_
num_toke
ns
+
1
),
dtype
=
torch
.
int64
,
device
=
device
(
3
,
self
.
max_
positio
ns
+
1
),
dtype
=
torch
.
int64
,
device
=
device
)
elif
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
>
0
:
self
.
xdrope_positions
=
torch
.
zeros
(
(
self
.
uses_xdrope_dim
,
self
.
max_
num_toke
ns
+
1
),
(
self
.
uses_xdrope_dim
,
self
.
max_
positio
ns
+
1
),
dtype
=
torch
.
int64
,
device
=
device
,
)
else
:
# RoPE need (max_num_tokens,)
self
.
positions
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
self
.
max_positions
,
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
hidden_states
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
...
...
@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer:
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
max_num_slots_for_arange
=
max
(
self
.
max_batch_size
+
1
,
self
.
max_num_tokens
)
self
.
arange
=
torch
.
arange
(
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
)
...
...
@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer:
)
self
.
backup_next_token_ids
=
CpuGpuBuffer
(
max_batch_size
,
self
.
max_batch_size
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
...
...
@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer:
)
self
.
_slot_mapping_buffer
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
self
.
max_positions
,
dtype
=
torch
.
int64
,
device
=
device
,
)
# Determine allowed attention backends once during initialization.
...
...
@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer:
# Precompute draft position offsets in flattened tree.
self
.
tree_draft_pos_offsets
=
torch
.
arange
(
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
).
repeat
(
max_batch_size
,
1
)
).
repeat
(
self
.
max_batch_size
,
1
)
def
_raise_if_padded_drafter_batch_disabled
(
self
):
if
self
.
speculative_config
.
disable_padded_drafter_batch
:
...
...
@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer:
# for those masked slots.
model_hf_config
=
self
.
draft_model_config
.
hf_config
if
hasattr
(
model_hf_config
,
"pard_token"
):
# DFlash stores mask_token_id in dflash_config
dflash_config
=
getattr
(
model_hf_config
,
"dflash_config"
,
None
)
if
dflash_config
and
"mask_token_id"
in
dflash_config
:
self
.
parallel_drafting_token_id
=
dflash_config
[
"mask_token_id"
]
elif
hasattr
(
model_hf_config
,
"pard_token"
):
self
.
parallel_drafting_token_id
=
model_hf_config
.
pard_token
elif
hasattr
(
model_hf_config
,
"ptd_token_id"
):
self
.
parallel_drafting_token_id
=
model_hf_config
.
ptd_token_id
else
:
raise
ValueError
(
"For parallel drafting, the draft model config must have "
"`pard_token` or `ptd_token_id` specified in its config.json."
"`pard_token`, `ptd_token_id`, or "
"`dflash_config.mask_token_id` specified in its config.json."
)
if
self
.
pass_hidden_states_to_model
:
...
...
@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer:
)
->
torch
.
Tensor
:
batch_size
=
common_attn_metadata
.
batch_size
()
if
self
.
method
==
"eagle3"
:
if
self
.
method
in
(
"eagle3"
,
"dflash"
)
:
assert
isinstance
(
self
.
model
,
(
Eagle3LlamaForCausalLM
,
Eagle3DeepseekV2ForCausalLM
)
self
.
model
,
(
Eagle3LlamaForCausalLM
,
Eagle3DeepseekV2ForCausalLM
,
DFlashQwen3ForCausalLM
,
),
)
target_hidden_states
=
self
.
model
.
combine_hidden_states
(
target_hidden_states
...
...
@@ -423,41 +444,18 @@ class SpecDecodeBaseProposer:
)
)
per_layer_attn_metadata
:
dict
[
str
,
object
]
=
{}
for
attn_group
in
self
.
draft_attn_groups
:
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
0
per_layer_attn_metadata
=
self
.
build_per_layer_attn_metadata
(
common_attn_metadata
)
for
layer_name
in
attn_group
.
layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
)
)
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
self
.
inputs_embeds
[:
num_tokens
]
=
self
.
model
.
embed_input_ids
(
self
.
input_ids
[:
num_tokens
],
multimodal_embeddings
=
mm_embeds
,
is_multimodal
=
is_mm_embed
,
model_kwargs
,
slot_mapping_size
=
self
.
build_model_inputs_first_pass
(
num_tokens
,
num_input_tokens
,
mm_embed_inputs
)
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
else
:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer:
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
slot_mapping_size
,
common_attn_metadata
.
slot_mapping
),
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
...
...
@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer:
positions
=
self
.
positions
[
token_indices_to_sample
]
hidden_states
=
hidden_states
[
token_indices_to_sample
]
if
isinstance
(
attn_metadata
,
TreeAttentionMetadata
):
if
any
(
isinstance
(
attn_metadata
,
TreeAttentionMetadata
)
for
attn_metadata
in
per_layer_attn_metadata
.
values
()
):
# Draft using tree attention - requires full logits for top-k
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
draft_token_ids_list
=
self
.
propose_tree
(
...
...
@@ -504,6 +505,7 @@ class SpecDecodeBaseProposer:
draft_token_ids
=
self
.
_greedy_sample
(
sample_hidden_states
)
for
attn_metadata
in
per_layer_attn_metadata
.
values
():
if
self
.
allowed_attn_types
is
not
None
and
not
isinstance
(
attn_metadata
,
self
.
allowed_attn_types
):
...
...
@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata
.
_num_computed_tokens_cpu
+=
1
# Rebuild attention metadata
for
attn_group
in
self
.
draft_attn_groups
:
attn_metadata
=
attn_group
.
get_metadata_builder
().
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
,
per_layer_attn_metadata
=
self
.
build_per_layer_attn_metadata
(
common_attn_metadata
,
draft_index
=
token_index
+
1
)
for
layer_name
in
attn_group
.
layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
...
...
@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer:
return
total_num_output_tokens
,
token_indices_to_sample
,
new_cad
def build_model_inputs_first_pass(
    self,
    num_tokens: int,
    num_input_tokens: int,
    mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
) -> tuple[dict[str, Any], int]:
    """Assemble the keyword arguments for the first draft-model forward pass.

    Args:
        num_tokens: Number of leading entries of ``self.input_ids`` that are
            embedded when multimodal inputs are supported.
        num_input_tokens: Length of every slice handed to the model
            (ids or embeddings, positions, hidden states).
        mm_embed_inputs: Optional ``(multimodal_embeddings, is_multimodal)``
            pair forwarded to ``self.model.embed_input_ids``; a falsy value
            is treated as ``(None, None)``.

    Returns:
        A ``(model_kwargs, size)`` pair. ``model_kwargs`` always holds
        ``input_ids``, ``positions`` and ``inputs_embeds`` (exactly one of
        ``input_ids`` / ``inputs_embeds`` is ``None``) and additionally
        ``hidden_states`` when ``self.pass_hidden_states_to_model`` is set.
        ``size`` is ``num_input_tokens`` unchanged.
    """
    token_ids = None
    token_embeds = None

    if not self.supports_mm_inputs:
        # Text-only drafter: feed the raw token ids.
        token_ids = self.input_ids[:num_input_tokens]
    else:
        # Multimodal-capable drafter: embed the first num_tokens ids into the
        # persistent buffer, then feed the model embeddings instead of ids.
        mm_embeds, is_mm_embed = (
            mm_embed_inputs if mm_embed_inputs else (None, None)
        )
        embedded = self.model.embed_input_ids(
            self.input_ids[:num_tokens],
            multimodal_embeddings=mm_embeds,
            is_multimodal=is_mm_embed,
        )
        self.inputs_embeds[:num_tokens] = embedded
        token_embeds = self.inputs_embeds[:num_input_tokens]

    model_kwargs = {
        "input_ids": token_ids,
        "positions": self._get_positions(num_input_tokens),
        "inputs_embeds": token_embeds,
    }
    if self.pass_hidden_states_to_model:
        model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]

    return model_kwargs, num_input_tokens
def build_per_layer_attn_metadata(
    self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
) -> dict[str, object]:
    """Build drafting attention metadata and index it by layer name.

    For each group in ``self.draft_attn_groups`` the group's metadata builder
    is asked once for drafting metadata; that single object is then shared by
    every layer name the group lists.

    Args:
        common_attn_metadata: Metadata passed through to each builder's
            ``build_for_drafting``.
        draft_index: Draft step index passed through to ``build_for_drafting``.

    Returns:
        Mapping from layer name to the metadata object of its group.
    """
    metadata_by_layer: dict[str, object] = {}
    for group in self.draft_attn_groups:
        builder = group.get_metadata_builder()
        group_metadata = builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=draft_index,
        )
        # All layers of a group point at the same metadata instance.
        metadata_by_layer.update(dict.fromkeys(group.layer_names, group_metadata))
    return metadata_by_layer
def model_returns_tuple(self) -> bool:
    """Tell whether ``self.method`` is one whose model output is a tuple.

    Returns:
        ``False`` when ``self.method`` is ``"mtp"``, ``"draft_model"`` or
        ``"dflash"``; ``True`` for every other method.
    """
    # Exactly one return statement: a second, narrower membership test placed
    # before this one would make the "dflash" exclusion unreachable.
    return self.method not in ("mtp", "draft_model", "dflash")
def
prepare_next_token_ids_cpu
(
self
,
...
...
@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer:
self
.
_maybe_share_embeddings
(
target_language_model
)
self
.
_maybe_share_lm_head
(
target_language_model
)
if
self
.
parallel_drafting
and
self
.
pass_hidden_states_to_model
:
assert
self
.
parallel_drafting_hidden_state_tensor
is
not
None
if
(
self
.
parallel_drafting
and
self
.
pass_hidden_states_to_model
and
self
.
parallel_drafting_hidden_state_tensor
is
not
None
):
flat_mask
=
self
.
model
.
mask_hidden
.
view
(
-
1
)
if
self
.
eagle3_use_aux_hidden_state
:
# EAGLE3: mask_hidden stores all aux hidden states,
# project through combine_hidden_states
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
self
.
model
.
combine_hidden_states
(
self
.
model
.
mask_hidden
.
view
(
3
*
self
.
hidden_size
)
)
if
self
.
eagle3_use_aux_hidden_state
else
self
.
model
.
mask_hidden
.
view
(
self
.
hidden_size
)
self
.
model
.
combine_hidden_states
(
flat_mask
)
)
else
:
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
flat_mask
)
def
_maybe_share_embeddings
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
"""
...
...
@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer:
)
->
None
:
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
only_one_forward_pass
=
is_graph_capturing
or
self
.
parallel_drafting
for
fwd_idx
in
range
(
self
.
num_speculative_tokens
if
not
is_graph_capturing
else
1
1
if
only_one_forward_pass
else
self
.
num_speculative_tokens
):
if
fwd_idx
<=
1
:
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
...
...
vllm/v1/spec_decode/utils.py
View file @
494636b2
...
...
@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel(
)
@triton.jit
def copy_and_expand_dflash_inputs_kernel(
    # Inputs
    next_token_ids_ptr,  # [num_reqs]
    target_positions_ptr,  # [num_context]
    # Outputs
    out_input_ids_ptr,  # [num_query_total] (output)
    out_context_positions_ptr,  # [num_context] (output)
    out_query_positions_ptr,  # [num_query_total] (output)
    out_context_slot_mapping_ptr,  # [num_context] (output)
    out_query_slot_mapping_ptr,  # [num_query_total] (output)
    out_token_indices_ptr,  # [num_reqs * num_speculative_tokens] (output)
    # Block table
    block_table_ptr,  # [max_reqs, max_blocks]
    block_table_stride,  # stride of block_table dim 0 (in elements)
    # Metadata
    query_start_loc_ptr,  # [num_reqs + 1]
    num_rejected_tokens_ptr,  # [num_reqs] or null (0) when not padded
    # Scalars
    parallel_drafting_token_id,  # tl.int32
    block_size,  # tl.int32
    num_query_per_req,  # tl.int32
    num_speculative_tokens,  # tl.int32
    total_input_tokens,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
    HAS_NUM_REJECTED: tl.constexpr = False,
):
    """
    Fused first-pass input preparation for DFlash drafting.

    Program axis 0 selects the request; program axis 1 selects a tile of
    BLOCK_SIZE consecutive lanes. Within a request, lanes [0, num_ctx) are
    context tokens and lanes [num_ctx, num_ctx + num_query_per_req) are query
    tokens; lanes past that do nothing.

    For its tile, each program:
    1. Copies the context positions out of target_positions into
       out_context_positions.
    2. Writes query positions, computed as
       (last valid context position + 1 + offset within the query part).
    3. Resolves every live position to a KV-cache slot through the request's
       block_table row and stores it in the context or query slot-mapping
       buffer.
    4. Writes query input ids: next_token_ids[req] at query offset 0 and
       parallel_drafting_token_id at every later offset.
    5. Records, for every query offset > 0, the flat query index in
       out_token_indices (the entries to sample from).
    """
    req = tl.program_id(0)
    tile = tl.program_id(1)

    # Context span of this request inside the flattened context buffers.
    ctx_begin = tl.load(query_start_loc_ptr + req)
    ctx_stop = tl.load(query_start_loc_ptr + req + 1)
    ctx_len = ctx_stop - ctx_begin

    lane = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live_lane = lane < ctx_len + num_query_per_req
    ctx_lane = lane < ctx_len
    query_lane = (~ctx_lane) & live_lane
    q_off = lane - ctx_len  # 0-based offset inside the query part

    # Destination indices: context lanes keep their flattened location, query
    # lanes are packed num_query_per_req per request.
    ctx_dst = ctx_begin + lane
    query_dst = req * num_query_per_req + q_off

    # --- Positions ---
    # Context positions come straight from target_positions; the index is
    # clamped to the last input token so masked-off lanes never form an
    # address past the buffer.
    ctx_src = tl.minimum(ctx_dst, total_input_tokens - 1)
    ctx_position = tl.load(target_positions_ptr + ctx_src, mask=ctx_lane, other=0)

    # Query positions continue after the last *accepted* context token. With
    # HAS_NUM_REJECTED the context span still contains rejected tokens, so
    # they are subtracted before locating that token.
    if HAS_NUM_REJECTED:
        last_ctx_idx = ctx_stop - tl.load(num_rejected_tokens_ptr + req) - 1
    else:
        last_ctx_idx = ctx_stop - 1
    query_position = tl.load(target_positions_ptr + last_ctx_idx) + 1 + q_off

    lane_position = tl.where(ctx_lane, ctx_position, query_position)

    # Context and query positions are written to separate buffers.
    tl.store(out_context_positions_ptr + ctx_dst, ctx_position, mask=ctx_lane)
    tl.store(out_query_positions_ptr + query_dst, query_position, mask=query_lane)

    # --- Slot mapping (block_table lookup for every live lane) ---
    # The logical block number is clamped to block_table_stride - 1 so the
    # lookup cannot run past this request's row.
    logical_block = tl.minimum(lane_position // block_size, block_table_stride - 1)
    physical_block = tl.load(
        block_table_ptr + req * block_table_stride + logical_block,
        mask=live_lane,
        other=0,
    ).to(tl.int64)
    kv_slot = physical_block * block_size + (lane_position % block_size)
    tl.store(out_context_slot_mapping_ptr + ctx_dst, kv_slot, mask=ctx_lane)
    tl.store(out_query_slot_mapping_ptr + query_dst, kv_slot, mask=query_lane)

    # --- Input IDs (query lanes only) ---
    # First query slot carries the request's next token; the rest carry the
    # parallel-drafting (mask) token id.
    next_token = tl.load(next_token_ids_ptr + req)
    first_query_lane = query_lane & (q_off == 0)
    query_token = tl.where(first_query_lane, next_token, parallel_drafting_token_id)
    tl.store(out_input_ids_ptr + query_dst, query_token, mask=query_lane)

    # --- Token indices to sample (every query lane except the first) ---
    sampled_lane = query_lane & (q_off > 0)
    sample_dst = req * num_speculative_tokens + (q_off - 1)
    tl.store(out_token_indices_ptr + sample_dst, query_dst, mask=sampled_lane)
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
update_num_computed_tokens_for_batch_change
(
num_computed_tokens
:
torch
.
Tensor
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
494636b2
...
...
@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.dflash
import
DFlashProposer
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
...
...
@@ -515,6 +516,7 @@ class GPUModelRunner(
|
NgramProposerGPU
|
SuffixDecodingProposer
|
EagleProposer
|
DFlashProposer
|
DraftModelProposer
|
MedusaProposer
|
ExtractHiddenStatesProposer
...
...
@@ -546,6 +548,9 @@ class GPUModelRunner(
self
.
_ngram_pinned_val_buf
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
elif
self
.
speculative_config
.
use_dflash
():
self
.
drafter
=
DFlashProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
use_aux_hidden_state_outputs
=
True
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
...
...
@@ -2289,7 +2294,7 @@ class GPUModelRunner(
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
isinstance
(
self
.
drafter
,
(
EagleProposer
,
DFlashProposer
)
):
if
self
.
drafter
.
kv_cache_gid
==
kv_cache_gid
:
spec_decode_common_attn_metadata
=
cm
else
:
...
...
@@ -4202,7 +4207,10 @@ class GPUModelRunner(
# as inputs, and does not need to wait for bookkeeping to finish.
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
input_fits_in_drafter
:
...
...
@@ -4589,8 +4597,14 @@ class GPUModelRunner(
next_token_ids
,
valid_sampled_tokens_count
)
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
elif
(
spec_config
.
use_eagle
()
or
spec_config
.
use_dflash
()
or
spec_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
)
if
spec_config
.
disable_padded_drafter_batch
:
# When padded-batch is disabled, the sampled_token_ids should be
...
...
@@ -4889,10 +4903,13 @@ class GPUModelRunner(
return
None
hf_config
=
self
.
speculative_config
.
draft_model_config
.
hf_config
if
not
hasattr
(
hf_config
,
"eagle_aux_hidden_state_layer_ids"
):
return
None
layer_ids
=
hf_config
.
eagle_aux_hidden_state_layer_ids
layer_ids
=
getattr
(
hf_config
,
"eagle_aux_hidden_state_layer_ids"
,
None
)
if
not
layer_ids
:
dflash_config
=
getattr
(
hf_config
,
"dflash_config"
,
None
)
if
dflash_config
and
isinstance
(
dflash_config
,
dict
):
layer_ids
=
dflash_config
.
get
(
"target_layer_ids"
)
if
layer_ids
and
isinstance
(
layer_ids
,
(
list
,
tuple
)):
return
tuple
(
layer_ids
)
...
...
@@ -5479,7 +5496,10 @@ class GPUModelRunner(
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
|
ExtractHiddenStatesProposer
,
)
assert
self
.
speculative_config
is
not
None
# Eagle currently only supports PIECEWISE cudagraphs.
...
...
@@ -6236,7 +6256,9 @@ class GPUModelRunner(
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_draft_model
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DFlashProposer
|
DraftModelProposer
)
self
.
drafter
.
initialize_attn_backend
(
kv_cache_config
,
kernel_block_sizes
)
def
_check_and_update_cudagraph_mode
(
...
...
@@ -6420,7 +6442,10 @@ class GPUModelRunner(
self
.
speculative_config
.
use_eagle
()
or
self
.
speculative_config
.
uses_extract_hidden_states
()
):
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
ExtractHiddenStatesProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DFlashProposer
|
ExtractHiddenStatesProposer
,
)
self
.
drafter
.
initialize_cudagraph_keys
(
cudagraph_mode
)
def
calculate_reorder_batch_threshold
(
self
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment