Unverified commit 494636b2, authored by Benjamin Chislett and committed by GitHub
Browse files

[Feat][Spec Decode] DFlash (#36847)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent ab1a6a43
...@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
# "JackFram/llama-160m", # "JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator" # speculative_model="ibm-ai-platform/llama-160m-accelerator"
# ), # ),
# [DFlash]
"DFlashDraftModel": _HfExamplesInfo(
"Qwen/Qwen3.5-4B",
speculative_model="z-lab/Qwen3.5-4B-DFlash",
use_original_num_layers=True, # Need all layers since DFlash has >1 layer,
max_model_len=8192, # Reduce max len to ensure test runs in low-VRAM CI env
max_num_seqs=32,
),
# [Eagle] # [Eagle]
"EagleDeepSeekMTPModel": _HfExamplesInfo( "EagleDeepSeekMTPModel": _HfExamplesInfo(
"eagle618/deepseek-v3-random", "eagle618/deepseek-v3-random",
......
...@@ -7,6 +7,7 @@ from typing import Any ...@@ -7,6 +7,7 @@ from typing import Any
import pytest import pytest
import torch import torch
from tqdm import tqdm
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import ( from tests.utils import (
...@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict: ...@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
} }
def compute_acceptance_rate(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """
    Return the fraction of drafted tokens that were accepted.

    The ratio is built from the ``vllm:spec_decode_num_accepted_tokens`` and
    ``vllm:spec_decode_num_draft_tokens`` entries of ``metrics``. When
    ``prev_metrics`` is given, its values for the same two entries are
    subtracted first, so the result covers only the interval between the two
    snapshots.

    Returns NaN when no draft tokens were produced (in total, or within the
    interval when ``prev_metrics`` is given).
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value
    if n_draft_toks == 0:
        return float("nan")
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value

    if prev_metrics is not None:
        # Turn the two snapshots into a delta over the interval between them.
        prev_name2metric = {metric.name: metric for metric in prev_metrics}
        n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value
        n_accepted_toks -= prev_name2metric[
            "vllm:spec_decode_num_accepted_tokens"
        ].value

    # Nothing was drafted in the interval: the rate is undefined.
    if n_draft_toks <= 0:
        return float("nan")
    return n_accepted_toks / n_draft_toks
def compute_acceptance_len(
    metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
    """
    Return the mean acceptance length: 1 + accepted tokens per draft.

    Values come from the ``vllm:spec_decode_num_drafts`` and
    ``vllm:spec_decode_num_accepted_tokens`` entries of ``metrics``. When
    ``prev_metrics`` is given, its values for the same entries are subtracted
    first, so the result covers only the interval between the two snapshots.

    Returns 1.0 when no drafts were made (in total, or within the interval).
    """
    name2metric = {metric.name: metric for metric in metrics}
    n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
    if n_drafts == 0:
        return 1.0

    if prev_metrics is not None:
        # Turn the two snapshots into a delta over the interval between them.
        prev_name2metric = {metric.name: metric for metric in prev_metrics}
        n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value
        n_accepted_toks -= prev_name2metric[
            "vllm:spec_decode_num_accepted_tokens"
        ].value

    # No drafts in the interval: fall back to the no-speculation length.
    if n_drafts <= 0:
        return 1.0
    return 1 + (n_accepted_toks / n_drafts)
# Datasets in the format used in DFlash validations
def load_and_process_dataset(data_name: str):
    """
    Load one of the DFlash validation datasets and add a ``turns`` column.

    Supported names are ``"gsm8k"``, ``"mt-bench"`` and ``"humaneval"``. For
    gsm8k and humaneval, ``turns`` is a one-element list holding the formatted
    prompt; for mt-bench it is the row's ``prompt`` field.

    Raises:
        ValueError: if ``data_name`` is not one of the supported names.
    """
    supported = ("gsm8k", "mt-bench", "humaneval")
    # Reject unknown names before the (optional) `datasets` import so the
    # caller gets a clear error instead of an unbound-variable failure.
    if data_name not in supported:
        raise ValueError(
            f"Unsupported dataset {data_name!r}; expected one of {supported}"
        )

    from datasets import load_dataset

    if data_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", split="test")
        prompt_fmt = (
            "{question}\nPlease reason step by step,"
            " and put your final answer within \\boxed{{}}."
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
    elif data_name == "mt-bench":
        dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
        dataset = dataset.map(lambda x: {"turns": x["prompt"]})
    else:  # "humaneval"
        dataset = load_dataset("openai/openai_humaneval", split="test")
        prompt_fmt = (
            "Write a solution to the following problem and make sure"
            " that it passes the tests:\n```python\n{prompt}\n```"
        )
        dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
    return dataset
@pytest.fixture
def dflash_config():
    """
    Keyword arguments for building an ``LLM`` with DFlash speculative decoding.

    Pairs the Qwen3-8B target with its DFlash draft checkpoint, 16 speculative
    tokens, and a 32768-token context on both the target and the draft side.
    """
    target_model = "Qwen/Qwen3-8B"
    draft_model = "z-lab/Qwen3-8B-DFlash-b16"
    speculative_config = {
        "method": "dflash",
        "model": draft_model,
        "num_speculative_tokens": 16,
        "max_model_len": 32768,
    }
    return {
        "model": target_model,
        "trust_remote_code": True,
        "speculative_config": speculative_config,
        "max_model_len": 32768,
        "max_num_seqs": 128,
        "gpu_memory_utilization": 0.85,
        "enforce_eager": False,
        "disable_log_stats": False,
    }
def test_dflash_acceptance_rates(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.
    Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
    comparing against baseline results from the paper (Table 1).
    See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.

    Prompts are issued one at a time so that an acceptance length can be
    computed per request from the difference between consecutive metric
    snapshots; the per-dataset mean must reach 90% of the reference value.
    """
    spec_llm = LLM(**dflash_config)
    try:
        # mt-bench has 80, humaneval has 164, truncates gsm8k
        max_prompts_per_dataset = 200

        # All scores from Table 1 in https://arxiv.org/pdf/2602.06036
        expected_acceptance_lengths = {
            "mt-bench": 4.24,
            "humaneval": 6.50,
            # runs with a subset of prompts so extra wide tol here
            "gsm8k": 6.54 * 0.95,
        }

        tokenizer = spec_llm.get_tokenizer()

        # The snapshots are differenced against the previous one, i.e. the
        # counters keep accumulating for the lifetime of the engine. The
        # baseline therefore has to carry over from one dataset to the next;
        # resetting it per dataset would fold every earlier request into the
        # first sample of each later dataset.
        prev_metrics = None

        for dataset_name, expected_len in expected_acceptance_lengths.items():
            dataset = load_and_process_dataset(dataset_name)
            acceptance_lengths = []
            for i in tqdm(
                range(min(max_prompts_per_dataset, len(dataset))),
                desc=f"Processing {dataset_name}",
            ):
                user_content = dataset[i]["turns"][0]
                prompt_text = tokenizer.apply_chat_template(
                    [{"role": "user", "content": user_content}],
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                )

                # Temp=0, MaxTokens=2048 from the paper
                spec_llm.generate(
                    [prompt_text],
                    SamplingParams(temperature=0, max_tokens=2048),
                    use_tqdm=False,
                )
                current_metrics = spec_llm.get_metrics()
                acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
                prev_metrics = current_metrics
                acceptance_lengths.append(acceptance_len)

            mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)

            # Allow a 10% shortfall relative to the reference number.
            expected_len = expected_len * 0.9
            print(
                f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
                f" (expected at least {expected_len:.2f})"
            )

            assert mean_acceptance_length >= expected_len, (
                f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
                f"{mean_acceptance_length:.2f} < {expected_len:.2f}"
            )
    finally:
        # Release the engine even when an assertion above fails, so later
        # tests do not inherit its memory footprint.
        del spec_llm
        torch.accelerator.empty_cache()
        cleanup_dist_env_and_memory()
def test_dflash_correctness(dflash_config):
    """
    E2E test for DFlash (block diffusion) speculative decoding.
    Ensures output correctness on GSM8k, with cudagraphs and batching on.

    Besides the accuracy threshold, the overall acceptance length of the run
    is checked against a floor.
    """
    spec_llm = LLM(**dflash_config)
    try:
        # Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
        evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)

        current_metrics = spec_llm.get_metrics()
        acceptance_len = compute_acceptance_len(current_metrics)

        # AR is thoroughly validated in test_dflash_acceptance_rates, in a manner
        # consistent with the DFlash paper. However, that test measures AL
        # per-request and thus runs with a batch size of 1. To ensure that AL
        # does not collapse with large batch sizes we enforce a baseline on the
        # AL over the full lm-eval-style GSM8k test.
        expected_len = 3.5  # Measured is 3.9 to 4.0
        print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
        assert acceptance_len >= expected_len, (
            "DFlash correctness check failed with"
            f" {acceptance_len=}, expected at least {expected_len}"
        )
    finally:
        # Release the engine even when a check above fails, so later tests do
        # not inherit its memory footprint.
        del spec_llm
        torch.accelerator.empty_cache()
        cleanup_dist_env_and_memory()
...@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig ...@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct" ...@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
dflash_target_dir = "Qwen/Qwen3-8B"
dflash_dir = "z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE = 16 BLOCK_SIZE = 16
...@@ -47,18 +50,29 @@ def _create_proposer( ...@@ -47,18 +50,29 @@ def _create_proposer(
speculative_token_tree: list[tuple[int, ...]] | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None,
parallel_drafting: bool = False, parallel_drafting: bool = False,
) -> EagleProposer: ) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
# Method-dependent setup # Method-dependent setup
if method == "eagle": if method == "eagle":
target_model_dir = model_dir
draft_model_dir = eagle_dir draft_model_dir = eagle_dir
elif method == "eagle3": elif method == "eagle3":
target_model_dir = model_dir
draft_model_dir = eagle3_dir draft_model_dir = eagle3_dir
elif method == "draft_model": elif method == "draft_model":
target_model_dir = model_dir
draft_model_dir = ar_draft_model_dir draft_model_dir = ar_draft_model_dir
elif method == "dflash":
target_model_dir = dflash_target_dir
draft_model_dir = dflash_dir
else: else:
raise ValueError(f"Unknown method: {method}") raise ValueError(f"Unknown method: {method}")
model_config = ModelConfig(
model=target_model_dir,
runner="generate",
max_model_len=100,
trust_remote_code=(method == "dflash"),
)
spec_token_tree_str = None spec_token_tree_str = None
if speculative_token_tree is not None: if speculative_token_tree is not None:
assert num_speculative_tokens == len(speculative_token_tree) assert num_speculative_tokens == len(speculative_token_tree)
...@@ -92,7 +106,9 @@ def _create_proposer( ...@@ -92,7 +106,9 @@ def _create_proposer(
attention_config=AttentionConfig(backend=attention_backend), attention_config=AttentionConfig(backend=attention_backend),
) )
if "eagle" in method: if method == "dflash":
proposer = DFlashProposer(vllm_config=vllm_config, device=device)
elif "eagle" in method:
proposer = EagleProposer(vllm_config=vllm_config, device=device) proposer = EagleProposer(vllm_config=vllm_config, device=device)
else: else:
proposer = DraftModelProposer(vllm_config=vllm_config, device=device) proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
...@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree): ...@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree):
# Verify that the draft tokens match our expectations. # Verify that the draft tokens match our expectations.
assert torch.equal(result, expected_tokens) assert torch.equal(result, expected_tokens)
def test_set_inputs_first_pass_dflash():
    """
    Check the buffers and metadata produced by DFlash's set_inputs_first_pass.

    Scenario: three requests with query_lens [3, 2, 4] (seq_lens [10, 8, 12]),
    three speculative tokens, and next tokens [100, 200, 300]. Each request
    therefore gets 4 query slots (1 bonus + 3 mask), 12 in total:

        input_ids:  [100, M, M, M | 200, M, M, M | 300, M, M, M]
        positions:  [10..13 | 8..11 | 12..15]   (continuing after 9, 7, 11)
        sampled:    the mask slots only -> [1, 2, 3, 5, 6, 7, 9, 10, 11]

    The 9 context positions are expected in the separate
    _context_positions_buffer, the returned attention metadata is expected to
    be non-causal and sized for the query tokens only, and the target hidden
    states are expected to be kept by reference rather than copied.
    """
    device = torch.device(current_platform.device_type)
    num_speculative_tokens = 3
    proposer = _create_proposer("dflash", num_speculative_tokens)
    mask_token_id = proposer.parallel_drafting_token_id

    def on_device(values, dtype=torch.int32):
        # All fixture/expectation tensors live on the test device.
        return torch.tensor(values, dtype=dtype, device=device)

    # Batch of 3 requests
    common_attn_metadata = create_common_attn_metadata(
        BatchSpec(
            seq_lens=[10, 8, 12],
            query_lens=[3, 2, 4],
        ),
        block_size=BLOCK_SIZE,
        device=device,
        arange_block_indices=True,
    )

    # Flattened target-side inputs:
    #   request 0: tokens [10, 11, 12]     at positions [7, 8, 9]
    #   request 1: tokens [20, 21]         at positions [6, 7]
    #   request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
    target_token_ids = on_device([10, 11, 12, 20, 21, 30, 31, 32, 33])
    target_positions = on_device([7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64)
    target_hidden_states = torch.randn(
        9, proposer.hidden_size, dtype=proposer.dtype, device=device
    )
    bonus_tokens = [100, 200, 300]
    next_token_ids = on_device(bonus_tokens)

    num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
        target_token_ids=target_token_ids,
        next_token_ids=next_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        token_indices_to_sample=None,
        cad=common_attn_metadata,
        num_rejected_tokens_gpu=None,
    )

    num_query_per_req = 1 + num_speculative_tokens  # 4
    num_context = 9

    # The returned count covers query tokens only.
    assert num_tokens == 3 * num_query_per_req  # 12

    # input_ids: one [bonus, mask, mask, mask] group per request.
    expected_input_ids = on_device(
        [
            token
            for bonus in bonus_tokens
            for token in [bonus] + [mask_token_id] * num_speculative_tokens
        ]
    )
    assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)

    # Context positions sit in their own buffer and mirror target_positions.
    assert torch.equal(
        proposer._context_positions_buffer[:num_context], target_positions
    )

    # Query positions start at index 0 of `positions` and continue each
    # request's sequence: after 9 -> 10..13, after 7 -> 8..11, after 11 -> 12..15.
    expected_query_positions = on_device(
        [10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15],
        dtype=torch.int64,
    )
    assert torch.equal(
        proposer.positions[:num_tokens],
        expected_query_positions,
    )

    # Only mask slots are sampled; the bonus slot at offset 0 of each group
    # is skipped: [1, 2, 3, 5, 6, 7, 9, 10, 11].
    expected_token_indices_to_sample = on_device(
        [
            req * num_query_per_req + offset
            for req in range(3)
            for offset in range(1, num_query_per_req)
        ]
    )
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)

    # The returned metadata must be non-causal and sized for the query tokens.
    assert output_cad.causal is False
    assert output_cad.num_actual_tokens == num_tokens
    assert output_cad.max_query_len == num_query_per_req
    expected_query_start_loc = on_device(
        [req * num_query_per_req for req in range(4)]  # [0, 4, 8, 12]
    )
    assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)

    # Hidden states are kept by reference, not copied.
    assert proposer._dflash_hidden_states is target_hidden_states
...@@ -47,8 +47,11 @@ MTPModelTypes = Literal[ ...@@ -47,8 +47,11 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp", "pangu_ultra_moe_mtp",
"step3p5_mtp", "step3p5_mtp",
] ]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"] NgramGPUTypes = Literal["ngram_gpu"]
DFlashModelTypes = Literal["dflash"]
EagleModelTypes = Literal[
"eagle", "eagle3", "extract_hidden_states", MTPModelTypes, DFlashModelTypes
]
SpeculativeMethod = Literal[ SpeculativeMethod = Literal[
"ngram", "ngram",
"medusa", "medusa",
...@@ -206,7 +209,11 @@ class SpeculativeConfig: ...@@ -206,7 +209,11 @@ class SpeculativeConfig:
factors: list[Any] = [] factors: list[Any] = []
# Eagle3 and extract_hidden_states affect the computation graph because # Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state. # they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states") uses_aux_hidden_states = self.method in (
"eagle3",
"extract_hidden_states",
"dflash",
)
factors.append(uses_aux_hidden_states) factors.append(uses_aux_hidden_states)
# The specific layers used also affect the computation graph # The specific layers used also affect the computation graph
...@@ -490,7 +497,7 @@ class SpeculativeConfig: ...@@ -490,7 +497,7 @@ class SpeculativeConfig:
) )
# Automatically detect the method # Automatically detect the method
if self.method in ("eagle", "eagle3"): if self.method in ("eagle", "eagle3", "dflash"):
pass pass
# examples: # examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B # yuhuili/EAGLE-LLaMA3-Instruct-8B
...@@ -500,6 +507,8 @@ class SpeculativeConfig: ...@@ -500,6 +507,8 @@ class SpeculativeConfig:
self.method = "eagle" self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower(): elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3" self.method = "eagle3"
elif "dflash" in self.draft_model_config.model.lower():
self.method = "dflash"
elif self.draft_model_config.hf_config.model_type == "medusa": elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa" self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator": elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
...@@ -532,7 +541,7 @@ class SpeculativeConfig: ...@@ -532,7 +541,7 @@ class SpeculativeConfig:
) )
# Replace hf_config for EAGLE draft_model # Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"): if self.method in ("eagle", "eagle3", "dflash"):
from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.speculators import ( from vllm.transformers_utils.configs.speculators import (
SpeculatorsConfig, SpeculatorsConfig,
...@@ -552,6 +561,9 @@ class SpeculativeConfig: ...@@ -552,6 +561,9 @@ class SpeculativeConfig:
self.draft_model_config.hf_config = eagle_config self.draft_model_config.hf_config = eagle_config
self.update_arch_() self.update_arch_()
if self.method == "dflash":
self.parallel_drafting = True
if self.num_speculative_tokens is not None and hasattr( if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens" self.draft_model_config.hf_config, "num_lookahead_tokens"
): ):
...@@ -807,7 +819,7 @@ class SpeculativeConfig: ...@@ -807,7 +819,7 @@ class SpeculativeConfig:
"kimi_k25", "kimi_k25",
] ]
if ( if (
self.method in ("eagle3", "extract_hidden_states") self.method in ("eagle3", "extract_hidden_states", "dflash")
and self.target_model_config and self.target_model_config
and not any( and not any(
supported_model in self.target_model_config.hf_text_config.model_type supported_model in self.target_model_config.hf_text_config.model_type
...@@ -855,7 +867,10 @@ class SpeculativeConfig: ...@@ -855,7 +867,10 @@ class SpeculativeConfig:
return slots_per_req return slots_per_req
def use_eagle(self) -> bool:
    """Return True when the method is one of eagle, eagle3, mtp or dflash."""
    return self.method in ("eagle", "eagle3", "mtp", "dflash")
def use_dflash(self) -> bool:
    """Return True when the configured speculative method is "dflash"."""
    is_dflash = self.method == "dflash"
    return is_dflash
def uses_draft_model(self) -> bool:
    """Return True when the configured speculative method is "draft_model"."""
    return self.method == "draft_model"
......
...@@ -1327,6 +1327,26 @@ class VllmConfig: ...@@ -1327,6 +1327,26 @@ class VllmConfig:
max_num_batched_tokens - scheduled_token_delta max_num_batched_tokens - scheduled_token_delta
) )
if self.scheduler_config.max_num_scheduled_tokens <= 0:
raise ValueError(
"max_num_scheduled_tokens is set to"
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
" the speculative decoding settings, which does not allow"
" any tokens to be scheduled. Increase max_num_batched_tokens"
" to accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
)
if self.scheduler_config.max_num_scheduled_tokens < 8192:
logger.warning_once(
"max_num_scheduled_tokens is set to"
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
" the speculative decoding settings. This may lead to suboptimal"
" performance. Consider increasing max_num_batched_tokens to"
" accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs.",
scope="local",
)
max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
if max_num_batched_tokens < max_num_scheduled_tokens + ( if max_num_batched_tokens < max_num_scheduled_tokens + (
self.speculative_config.max_num_new_slots_for_drafting self.speculative_config.max_num_new_slots_for_drafting
......
...@@ -285,6 +285,7 @@ class Qwen3ForCausalLM( ...@@ -285,6 +285,7 @@ class Qwen3ForCausalLM(
self.config = config self.config = config
self.vllm_config = vllm_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen3Model( self.model = Qwen3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
......
This diff is collapsed.
...@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors ...@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from .interfaces import ( from .interfaces import (
EagleModelMixin,
HasInnerState, HasInnerState,
IsHybrid, IsHybrid,
MixtureOfExperts, MixtureOfExperts,
...@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Qwen3NextModel(nn.Module): class Qwen3NextModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module): ...@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module): ...@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate( for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer), islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer, start=self.start_layer,
): ):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=residual, residual=residual,
) )
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
......
...@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
......
...@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig): ...@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig):
else f"Eagle3{arch}" else f"Eagle3{arch}"
for arch in self.model.architectures for arch in self.model.architectures
] ]
elif method == "dflash":
assert self.model is not None, (
"model should not be None when method is dflash"
)
kwargs["architectures"] = [
arch
if arch.startswith("DFlash") or arch.endswith("DFlash")
else f"DFlash{arch}"
for arch in self.model.architectures
]
else: else:
raise ValueError( raise ValueError(
f"Invalid method {method}. Supported methods are eagle and eagle3." f"Invalid method {method}. Supported methods are "
"eagle, eagle3, and dflash."
) )
super().__init__(**kwargs) super().__init__(**kwargs)
......
...@@ -220,6 +220,17 @@ class AttentionBackend(ABC): ...@@ -220,6 +220,17 @@ class AttentionBackend(ABC):
def supports_per_head_quant_scales(cls) -> bool: def supports_per_head_quant_scales(cls) -> bool:
return False return False
@classmethod
def supports_non_causal(cls) -> bool:
    """Report whether the backend can run non-causal (bidirectional)
    attention for decoder models.

    This is distinct from the ENCODER_ONLY attention type, which implies a
    different execution model: the question here is about non-causal
    attention inside the standard paged-KV-cache decoder path.

    The default answer is False.
    """
    return False
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type. """Check if backend supports a given attention type.
...@@ -261,6 +272,7 @@ class AttentionBackend(ABC): ...@@ -261,6 +272,7 @@ class AttentionBackend(ABC):
use_per_head_quant_scales: bool, use_per_head_quant_scales: bool,
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str, attn_type: str,
use_non_causal: bool = False,
) -> list[str]: ) -> list[str]:
invalid_reasons = [] invalid_reasons = []
if not cls.supports_head_size(head_size): if not cls.supports_head_size(head_size):
...@@ -293,6 +305,8 @@ class AttentionBackend(ABC): ...@@ -293,6 +305,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("compute capability not supported") invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type): if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported") invalid_reasons.append(f"attention type {attn_type} not supported")
if use_non_causal and not cls.supports_non_causal():
invalid_reasons.append("non-causal attention not supported")
combination_reason = cls.supports_combination( combination_reason = cls.supports_combination(
head_size, head_size,
dtype, dtype,
......
...@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
@classmethod
def supports_non_causal(cls) -> bool:
    """This backend reports support for non-causal attention."""
    return True
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types.""" """FlashAttention supports all attention types."""
......
...@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple): ...@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple):
use_mm_prefix: bool = False use_mm_prefix: bool = False
use_per_head_quant_scales: bool = False use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER attn_type: str = AttentionType.DECODER
use_non_causal: bool = False
def __repr__(self): def __repr__(self):
return ( return (
...@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple): ...@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple):
f"use_sparse={self.use_sparse}, " f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, " f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, " f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type})" f"attn_type={self.attn_type}, "
f"use_non_causal={self.use_non_causal})"
) )
...@@ -76,6 +78,11 @@ def get_attn_backend( ...@@ -76,6 +78,11 @@ def get_attn_backend(
else: else:
block_size = None block_size = None
speculative_config = vllm_config.speculative_config
use_non_causal = (
speculative_config is not None and speculative_config.method == "dflash"
)
attn_selector_config = AttentionSelectorConfig( attn_selector_config = AttentionSelectorConfig(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
...@@ -87,6 +94,7 @@ def get_attn_backend( ...@@ -87,6 +94,7 @@ def get_attn_backend(
use_mm_prefix=use_mm_prefix, use_mm_prefix=use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales, use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER, attn_type=attn_type or AttentionType.DECODER,
use_non_causal=use_non_causal,
) )
return _cached_get_attn_backend( return _cached_get_attn_backend(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from typing_extensions import override
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import triton
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel
logger = init_logger(__name__)
class DFlashProposer(SpecDecodeBaseProposer):
    """Draft-token proposer for the ``dflash`` speculative decoding method.

    Target hidden states ("context") are converted to K/V entries and written
    to the KV cache ahead of the draft forward pass. Only the next (bonus)
    token plus ``num_speculative_tokens`` mask tokens per request are fed to
    the draft model as query tokens, and they attend non-causally.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Set up the base proposer and allocate DFlash-specific buffers.

        Args:
            vllm_config: Engine configuration; its speculative config must
                select the ``dflash`` method.
            device: Device on which the buffers are allocated.
            runner: Forwarded unchanged to the base proposer.
        """
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.method == "dflash"
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

        # Only next_token_ids and mask tokens are query tokens, all other context is K/V
        self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
        # Positions covers both context states + query states
        self.max_positions = self.max_num_tokens + self.max_query_tokens

        # Separate context buffers to keep query buffer addresses stable for CUDA graphs
        self._context_slot_mapping_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._slot_mapping_buffer = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )
        self._context_positions_buffer = torch.zeros(
            self.max_num_tokens,
            dtype=torch.int64,
            device=device,
        )
        self.positions = torch.zeros(
            self.max_query_tokens,
            dtype=torch.int64,
            device=device,
        )

        self.arange = torch.arange(
            self.max_positions + 1, device=device, dtype=torch.int32
        )

        # For DFlash we use the input embeddings to embed the mask token
        self.parallel_drafting_hidden_state_tensor = None

        # Per-step state handed from set_inputs_first_pass() to
        # build_model_inputs_first_pass(). Declared here so the attributes
        # always exist and a wrong call order fails with a clear error.
        self._dflash_num_context: int = 0
        self._dflash_hidden_states: torch.Tensor | None = None

    @override
    def _raise_if_multimodal(self):
        """Skip the base-class multimodal check.

        Overridden to allow multimodal inputs since DFlash supports Qwen3.5
        models. Support for multimodal inputs has not been tested.
        """
        pass

    @override
    def set_inputs_first_pass(
        self,
        target_token_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        cad: CommonAttentionMetadata,
        num_rejected_tokens_gpu: torch.Tensor | None,
    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
        """Fill the query/context buffers and build query-only attention metadata.

        Args:
            target_token_ids: Target-model token ids; only the length of the
                first dimension is used, as the number of context tokens.
            next_token_ids: One next (bonus) token id per request.
            target_positions: Positions of the context tokens.
            target_hidden_states: Context hidden states; kept by reference
                for ``build_model_inputs_first_pass``.
            token_indices_to_sample: Ignored; a fresh tensor is allocated
                and filled by the kernel.
            cad: Attention metadata describing the context tokens.
            num_rejected_tokens_gpu: Per-request rejected-token counts, or
                ``None`` when there are none to account for.

        Returns:
            ``(num_query_total, token_indices_to_sample, new_cad)`` where
            ``new_cad`` describes only the query tokens and is non-causal.
        """
        # DFlash cross-attention: context K/V from target hidden states,
        # Q from query embeddings (bonus + mask tokens).
        batch_size = cad.batch_size()
        num_context = target_token_ids.shape[0]
        num_query_per_req = 1 + self.num_speculative_tokens
        num_query_total = batch_size * num_query_per_req

        # Store for build_model_inputs_first_pass to use
        self._dflash_num_context = num_context

        # We don't need to copy into a buffer here since the context preprocessing
        # does not run in a CUDA graph
        self._dflash_hidden_states = target_hidden_states

        token_indices_to_sample = torch.empty(
            batch_size * self.num_speculative_tokens,
            dtype=torch.int32,
            device=self.device,
        )

        # Launch fused triton kernel for input_ids, positions, slot_mapping,
        # and token_indices_to_sample
        max_ctx_per_req = cad.max_query_len
        max_tokens_per_req = max_ctx_per_req + num_query_per_req
        BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req))
        num_blocks = triton.cdiv(max_tokens_per_req, BLOCK_SIZE)
        grid = (batch_size, num_blocks)

        has_num_rejected = num_rejected_tokens_gpu is not None
        copy_and_expand_dflash_inputs_kernel[grid](
            # Inputs
            next_token_ids_ptr=next_token_ids,
            target_positions_ptr=target_positions,
            # Outputs
            out_input_ids_ptr=self.input_ids,
            out_context_positions_ptr=self._context_positions_buffer,
            out_query_positions_ptr=self.positions,
            out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
            out_query_slot_mapping_ptr=self._slot_mapping_buffer,
            out_token_indices_ptr=token_indices_to_sample,
            # Block table
            block_table_ptr=cad.block_table_tensor,
            block_table_stride=cad.block_table_tensor.stride(0),
            # Metadata
            query_start_loc_ptr=cad.query_start_loc,
            num_rejected_tokens_ptr=(
                num_rejected_tokens_gpu if has_num_rejected else 0
            ),
            # Scalars
            parallel_drafting_token_id=self.parallel_drafting_token_id,
            block_size=self.block_size,
            num_query_per_req=num_query_per_req,
            num_speculative_tokens=self.num_speculative_tokens,
            total_input_tokens=num_context,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_NUM_REJECTED=has_num_rejected,
        )

        query_slot_mapping = self._slot_mapping_buffer[:num_query_total]
        new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req

        # In padded mode, cad.seq_lens includes rejected tokens. Subtract
        # them so attention only sees the valid prefix of context states.
        effective_seq_lens = cad.seq_lens
        if has_num_rejected:
            effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu

        new_cad = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc,
            seq_lens=effective_seq_lens + num_query_per_req,
            # The multiplication already allocates a new tensor, so the
            # result never aliases token_arange_np and no clone is needed.
            query_start_loc_cpu=(
                torch.from_numpy(self.token_arange_np[: batch_size + 1])
                * num_query_per_req
            ),
            _seq_lens_cpu=None,
            _num_computed_tokens_cpu=None,
            num_reqs=cad.num_reqs,
            num_actual_tokens=num_query_total,
            max_query_len=num_query_per_req,
            max_seq_len=cad.max_seq_len + num_query_per_req,
            block_table_tensor=cad.block_table_tensor,
            slot_mapping=query_slot_mapping,
            causal=False,  # Non-causal attention is required for DFlash
        )

        return num_query_total, token_indices_to_sample, new_cad

    @override
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """
        Key differences to default dummy_run:
        - Only one forward pass due to parallel drafting
        - DFlash uses context states as unpadded metadata, so hidden_states will
        use the unpadded num_tokens instead of num_input_tokens
        - max_query_tokens is quite small, DFlash only sees spec tokens as queries
        - Multimodal inputs are not currently supported
        """
        num_query_tokens = min(num_tokens, self.max_query_tokens)
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_query_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        # Slot mapping sized to num_input_tokens (query only), matching
        # the K/V tensor size from the model forward. Context KVs are
        # pre-inserted separately and don't flow through the model.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        # Context and query positions use separate buffers; no copy needed.
        context_positions = self._context_positions_buffer[:num_tokens]
        # Context states will be passed directly to the precomputation without
        # going through the buffer, since no CUDA graph is used for the precomputation.
        # For the dummy run, we use the dummy buffer.
        context_states = self.hidden_states[:num_tokens]

        # Run the KV projection (GEMM + norms + RoPE) for memory profiling.
        # No slot mapping is passed here, unlike the real first pass.
        self.model.precompute_and_store_context_kv(context_states, context_positions)
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            )

    @override
    def build_model_inputs_first_pass(
        self,
        num_tokens: int,
        num_input_tokens: int,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
    ) -> tuple[dict[str, Any], int]:
        """Store the context K/V and return the query-only model inputs.

        Must be called after ``set_inputs_first_pass``, which records the
        context hidden states and fills the position/slot buffers.

        Args:
            num_tokens: Unused by this override.
            num_input_tokens: Number of query tokens passed to the model.
            mm_embed_inputs: Unused by this override.

        Returns:
            ``(model_kwargs, num_input_tokens)``; ``model_kwargs`` carries
            ``input_ids``, ``positions`` and ``inputs_embeds=None``.

        Raises:
            RuntimeError: If no context hidden states have been recorded yet.
        """
        context_states = self._dflash_hidden_states
        if context_states is None:
            raise RuntimeError(
                "build_model_inputs_first_pass() was called before "
                "set_inputs_first_pass(); no context hidden states are available."
            )

        # Context and query positions/slots were written to separate
        # buffers by the kernel — no copy needed.
        num_context = self._dflash_num_context

        # Pre-insert context KVs directly into cache
        self.model.precompute_and_store_context_kv(
            context_states,  # Shape is already [num_context, hidden_size]
            self._context_positions_buffer[:num_context],
            self._context_slot_mapping_buffer[:num_context],
        )
        return (
            dict(
                input_ids=self.input_ids[:num_input_tokens],
                positions=self._get_positions(num_input_tokens),
                inputs_embeds=None,
            ),
            num_input_tokens,
        )

    @override
    def build_per_layer_attn_metadata(
        self, cad: CommonAttentionMetadata, draft_index: int = 0
    ) -> dict[str, object]:
        """Build per-layer attention metadata and require it to be non-causal.

        Args:
            cad: Common attention metadata for the draft pass.
            draft_index: Forwarded to the base implementation.

        Returns:
            Mapping from layer name to the metadata built by the base class.

        Raises:
            ValueError: If any layer's metadata does not expose
                ``causal=False``.
        """
        per_layer_attention_metadata = super().build_per_layer_attn_metadata(
            cad, draft_index
        )
        for layer_name, attn_metadata in per_layer_attention_metadata.items():
            # Explicit raise instead of `assert`: this validates the chosen
            # attention backend and must not be stripped under `python -O`.
            if getattr(attn_metadata, "causal", None) is not False:
                raise ValueError(
                    f"Attention metadata for layer {layer_name} does not have"
                    " non-causal support, which is required for DFlash."
                    " Consider using a different attention backend, such as FlashAttention."
                )
        return per_layer_attention_metadata

    @override
    def _get_eagle3_use_aux_hidden_state_from_config(self):
        """Read ``use_aux_hidden_state`` from the draft model's ``dflash_config``.

        Returns ``True`` when the config has no ``dflash_config`` or the key
        is absent.
        """
        dflash_config = getattr(
            self.draft_model_config.hf_config, "dflash_config", None
        )
        if dflash_config is None:
            return True
        return dflash_config.get("use_aux_hidden_state", True)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
from importlib.util import find_spec from importlib.util import find_spec
from typing import cast from typing import Any, cast
import numpy as np import numpy as np
import torch import torch
...@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal ...@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
...@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer: ...@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer:
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# Unifying eagle, draft model, and parallel drafting support # Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE)
self.parallel_drafting: bool = self.speculative_config.parallel_drafting self.parallel_drafting: bool = self.speculative_config.parallel_drafting
self.extra_slots_per_request = ( self.extra_slots_per_request = (
1 if not self.parallel_drafting else self.num_speculative_tokens 1 if not self.parallel_drafting else self.num_speculative_tokens
) )
self.net_num_new_slots_per_request = self.extra_slots_per_request - ( self.net_num_new_slots_per_request = self.extra_slots_per_request - (
1 if self.pass_hidden_states_to_model else 0 1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0
) )
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
...@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer: ...@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer:
self.speculative_config.use_local_argmax_reduction self.speculative_config.use_local_argmax_reduction
) )
max_batch_size = vllm_config.scheduler_config.max_num_seqs self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens) self.token_arange_np = np.arange(self.max_num_tokens)
# Can be specialized by methods like DFlash to reduce the limit
self.max_query_tokens = self.max_num_tokens
self.max_positions = self.max_num_tokens
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
...@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer: ...@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer:
# 1D-RoPE. # 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191 # See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros( self.mrope_positions = torch.zeros(
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device (3, self.max_positions + 1), dtype=torch.int64, device=device
) )
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions = torch.zeros( self.xdrope_positions = torch.zeros(
(self.uses_xdrope_dim, self.max_num_tokens + 1), (self.uses_xdrope_dim, self.max_positions + 1),
dtype=torch.int64, dtype=torch.int64,
device=device, device=device,
) )
else: else:
# RoPE need (max_num_tokens,) # RoPE need (max_num_tokens,)
self.positions = torch.zeros( self.positions = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device self.max_positions,
dtype=torch.int64,
device=device,
) )
self.hidden_states = torch.zeros( self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
...@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer: ...@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer:
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) max_num_slots_for_arange = max(self.max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange( self.arange = torch.arange(
max_num_slots_for_arange, device=device, dtype=torch.int32 max_num_slots_for_arange, device=device, dtype=torch.int32
) )
...@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer: ...@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer:
) )
self.backup_next_token_ids = CpuGpuBuffer( self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size, self.max_batch_size,
dtype=torch.int32, dtype=torch.int32,
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
device=device, device=device,
...@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer: ...@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer:
) )
self._slot_mapping_buffer = torch.zeros( self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device self.max_positions,
dtype=torch.int64,
device=device,
) )
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
...@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer: ...@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer:
# Precompute draft position offsets in flattened tree. # Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange( self.tree_draft_pos_offsets = torch.arange(
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1) ).repeat(self.max_batch_size, 1)
def _raise_if_padded_drafter_batch_disabled(self): def _raise_if_padded_drafter_batch_disabled(self):
if self.speculative_config.disable_padded_drafter_batch: if self.speculative_config.disable_padded_drafter_batch:
...@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer: ...@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer:
# for those masked slots. # for those masked slots.
model_hf_config = self.draft_model_config.hf_config model_hf_config = self.draft_model_config.hf_config
if hasattr(model_hf_config, "pard_token"): # DFlash stores mask_token_id in dflash_config
dflash_config = getattr(model_hf_config, "dflash_config", None)
if dflash_config and "mask_token_id" in dflash_config:
self.parallel_drafting_token_id = dflash_config["mask_token_id"]
elif hasattr(model_hf_config, "pard_token"):
self.parallel_drafting_token_id = model_hf_config.pard_token self.parallel_drafting_token_id = model_hf_config.pard_token
elif hasattr(model_hf_config, "ptd_token_id"): elif hasattr(model_hf_config, "ptd_token_id"):
self.parallel_drafting_token_id = model_hf_config.ptd_token_id self.parallel_drafting_token_id = model_hf_config.ptd_token_id
else: else:
raise ValueError( raise ValueError(
"For parallel drafting, the draft model config must have " "For parallel drafting, the draft model config must have "
"`pard_token` or `ptd_token_id` specified in its config.json." "`pard_token`, `ptd_token_id`, or "
"`dflash_config.mask_token_id` specified in its config.json."
) )
if self.pass_hidden_states_to_model: if self.pass_hidden_states_to_model:
...@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer: ...@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer:
) -> torch.Tensor: ) -> torch.Tensor:
batch_size = common_attn_metadata.batch_size() batch_size = common_attn_metadata.batch_size()
if self.method == "eagle3": if self.method in ("eagle3", "dflash"):
assert isinstance( assert isinstance(
self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM) self.model,
(
Eagle3LlamaForCausalLM,
Eagle3DeepseekV2ForCausalLM,
DFlashQwen3ForCausalLM,
),
) )
target_hidden_states = self.model.combine_hidden_states( target_hidden_states = self.model.combine_hidden_states(
target_hidden_states target_hidden_states
...@@ -423,40 +444,17 @@ class SpecDecodeBaseProposer: ...@@ -423,40 +444,17 @@ class SpecDecodeBaseProposer:
) )
) )
per_layer_attn_metadata: dict[str, object] = {} per_layer_attn_metadata = self.build_per_layer_attn_metadata(
for attn_group in self.draft_attn_groups: common_attn_metadata
attn_metadata = attn_group.get_metadata_builder().build_for_drafting( )
common_attn_metadata=common_attn_metadata, draft_index=0
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens) self._determine_batch_execution_and_padding(num_tokens)
) )
if self.supports_mm_inputs: model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) num_tokens, num_input_tokens, mm_embed_inputs
)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context( with set_forward_context(
per_layer_attn_metadata, per_layer_attn_metadata,
...@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer: ...@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer:
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping( slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping slot_mapping_size, common_attn_metadata.slot_mapping
), ),
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
...@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer: ...@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer:
positions = self.positions[token_indices_to_sample] positions = self.positions[token_indices_to_sample]
hidden_states = hidden_states[token_indices_to_sample] hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata): if any(
isinstance(attn_metadata, TreeAttentionMetadata)
for attn_metadata in per_layer_attn_metadata.values()
):
# Draft using tree attention - requires full logits for top-k # Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids_list = self.propose_tree( draft_token_ids_list = self.propose_tree(
...@@ -504,15 +505,16 @@ class SpecDecodeBaseProposer: ...@@ -504,15 +505,16 @@ class SpecDecodeBaseProposer:
draft_token_ids = self._greedy_sample(sample_hidden_states) draft_token_ids = self._greedy_sample(sample_hidden_states)
if self.allowed_attn_types is not None and not isinstance( for attn_metadata in per_layer_attn_metadata.values():
attn_metadata, self.allowed_attn_types if self.allowed_attn_types is not None and not isinstance(
): attn_metadata, self.allowed_attn_types
raise ValueError( ):
f"Unsupported attention metadata type for speculative " raise ValueError(
"decoding with num_speculative_tokens > 1: " f"Unsupported attention metadata type for speculative "
f"{type(attn_metadata)}. Supported types are: " "decoding with num_speculative_tokens > 1: "
f"{self.allowed_attn_types}" f"{type(attn_metadata)}. Supported types are: "
) f"{self.allowed_attn_types}"
)
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
...@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer: ...@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1 common_attn_metadata._num_computed_tokens_cpu += 1
# Rebuild attention metadata # Rebuild attention metadata
for attn_group in self.draft_attn_groups: per_layer_attn_metadata = self.build_per_layer_attn_metadata(
attn_metadata = attn_group.get_metadata_builder().build_for_drafting( common_attn_metadata, draft_index=token_index + 1
common_attn_metadata=common_attn_metadata, )
draft_index=token_index + 1,
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
...@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer: ...@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer:
return total_num_output_tokens, token_indices_to_sample, new_cad return total_num_output_tokens, token_indices_to_sample, new_cad
def build_model_inputs_first_pass(
self,
num_tokens: int,
num_input_tokens: int,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
) -> tuple[dict[str, Any], int]:
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
return model_kwargs, num_input_tokens
def build_per_layer_attn_metadata(
    self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
) -> dict[str, object]:
    """Build drafting attention metadata once per group and map it to layers.

    Each group in ``self.draft_attn_groups`` builds one metadata object via
    its metadata builder; every layer name of that group maps to that same
    object.
    """
    metadata_by_layer: dict[str, object] = {}
    for group in self.draft_attn_groups:
        builder = group.get_metadata_builder()
        group_metadata = builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata,
            draft_index=draft_index,
        )
        # All layers of a group share the single metadata instance.
        metadata_by_layer.update(dict.fromkeys(group.layer_names, group_metadata))
    return metadata_by_layer
def model_returns_tuple(self) -> bool: def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model") return self.method not in ("mtp", "draft_model", "dflash")
def prepare_next_token_ids_cpu( def prepare_next_token_ids_cpu(
self, self,
...@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer: ...@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer:
self._maybe_share_embeddings(target_language_model) self._maybe_share_embeddings(target_language_model)
self._maybe_share_lm_head(target_language_model) self._maybe_share_lm_head(target_language_model)
if self.parallel_drafting and self.pass_hidden_states_to_model: if (
assert self.parallel_drafting_hidden_state_tensor is not None self.parallel_drafting
self.parallel_drafting_hidden_state_tensor.copy_( and self.pass_hidden_states_to_model
self.model.combine_hidden_states( and self.parallel_drafting_hidden_state_tensor is not None
self.model.mask_hidden.view(3 * self.hidden_size) ):
flat_mask = self.model.mask_hidden.view(-1)
if self.eagle3_use_aux_hidden_state:
# EAGLE3: mask_hidden stores all aux hidden states,
# project through combine_hidden_states
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(flat_mask)
) )
if self.eagle3_use_aux_hidden_state else:
else self.model.mask_hidden.view(self.hidden_size) self.parallel_drafting_hidden_state_tensor.copy_(flat_mask)
)
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None: def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
""" """
...@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer: ...@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer:
) -> None: ) -> None:
# FIXME: when using tree-based specdec, adjust number of forward-passes # FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree. # according to the depth of the tree.
only_one_forward_pass = is_graph_capturing or self.parallel_drafting
for fwd_idx in range( for fwd_idx in range(
self.num_speculative_tokens if not is_graph_capturing else 1 1 if only_one_forward_pass else self.num_speculative_tokens
): ):
if fwd_idx <= 1: if fwd_idx <= 1:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
......
...@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel( ...@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel(
) )
@triton.jit
def copy_and_expand_dflash_inputs_kernel(
    # Inputs
    next_token_ids_ptr,  # [num_reqs]
    target_positions_ptr,  # [num_context]
    # Outputs
    out_input_ids_ptr,  # [num_query_total] (output)
    out_context_positions_ptr,  # [num_context] (output)
    out_query_positions_ptr,  # [num_query_total] (output)
    out_context_slot_mapping_ptr,  # [num_context] (output)
    out_query_slot_mapping_ptr,  # [num_query_total] (output)
    out_token_indices_ptr,  # [num_reqs * num_speculative_tokens] (output)
    # Block table
    block_table_ptr,  # [max_reqs, max_blocks]
    block_table_stride,  # stride of block_table dim 0 (in elements)
    # Metadata
    query_start_loc_ptr,  # [num_reqs + 1]
    num_rejected_tokens_ptr,  # [num_reqs] or null (0) when not padded
    # Scalars
    parallel_drafting_token_id,  # tl.int32
    block_size,  # tl.int32
    num_query_per_req,  # tl.int32
    num_speculative_tokens,  # tl.int32
    total_input_tokens,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
    HAS_NUM_REJECTED: tl.constexpr = False,
):
    """
    Fused kernel for DFlash first-pass input setup.

    Program axis 0 selects the request and axis 1 selects a tile of
    BLOCK_SIZE consecutive token indices within that request. A request's
    token indices cover its context tokens first, followed by
    num_query_per_req query tokens.

    Per request, this kernel:
      1. Copies context positions from target_positions to
         out_context_positions.
      2. Computes query positions (last_target_pos + 1 + offset) and writes
         them to out_query_positions.
      3. Writes input_ids for query tokens: [next_token, mask, mask, ...].
      4. Computes slot_mapping for context and query positions into separate
         buffers via block_table lookup.
      5. Writes token_indices_to_sample for the mask (speculative) tokens.

    When HAS_NUM_REJECTED is set, the per-request rejected-token count is
    subtracted from the end of the context range before the last context
    position is read, so query positions continue after the last accepted
    token.
    """
    req_idx = tl.program_id(axis=0)
    block_idx = tl.program_id(axis=1)

    # Load context token range for this request
    ctx_start = tl.load(query_start_loc_ptr + req_idx)
    ctx_end = tl.load(query_start_loc_ptr + req_idx + 1)
    num_ctx = ctx_end - ctx_start
    total_tokens = num_ctx + num_query_per_req

    # j indexes tokens within this request: [0, num_ctx) is context,
    # [num_ctx, total_tokens) is query; anything beyond is out of bounds.
    j = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = j < total_tokens
    is_ctx = j < num_ctx
    is_query = (~is_ctx) & in_bounds
    query_off = j - num_ctx  # offset within query portion (0-indexed)

    # --- Positions ---
    # Context: load from target_positions
    # The index is clamped to the last input token; presumably this keeps
    # the address in range for lanes that are masked off anyway.
    ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1)
    ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0)

    # Query: last_valid_pos + 1 + query_off
    # In padded mode, ctx_end includes rejected tokens; use valid_ctx_end
    # to find the last accepted context position.
    if HAS_NUM_REJECTED:
        num_rejected = tl.load(num_rejected_tokens_ptr + req_idx)
        valid_ctx_end = ctx_end - num_rejected
    else:
        valid_ctx_end = ctx_end
    # NOTE(review): this unmasked load reads index valid_ctx_end - 1, so it
    # relies on every request having at least one valid context token.
    last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1)
    query_pos = last_pos + 1 + query_off

    positions = tl.where(is_ctx, ctx_pos, query_pos)

    # Context and query positions go to separate buffers.
    ctx_pos_out = ctx_start + j
    tl.store(out_context_positions_ptr + ctx_pos_out, ctx_pos, mask=is_ctx)
    query_out = req_idx * num_query_per_req + query_off
    tl.store(out_query_positions_ptr + query_out, query_pos, mask=is_query)

    # --- Slot mapping (block_table lookup for all positions) ---
    block_num = positions // block_size
    # Clamp block_number to avoid OOB when position is at max.
    # The row stride is used as the upper bound on the block index.
    block_num = tl.minimum(block_num, block_table_stride - 1)
    block_id = tl.load(
        block_table_ptr + req_idx * block_table_stride + block_num,
        mask=in_bounds,
        other=0,
    ).to(tl.int64)
    # Slot = start of the physical block + offset of the position within it.
    slot = block_id * block_size + (positions % block_size)
    tl.store(out_context_slot_mapping_ptr + ctx_pos_out, slot, mask=is_ctx)
    tl.store(out_query_slot_mapping_ptr + query_out, slot, mask=is_query)

    # --- Input IDs (query tokens only) ---
    # The first query token is the request's next token; the rest are the
    # parallel-drafting (mask) token.
    bonus_token = tl.load(next_token_ids_ptr + req_idx)
    is_bonus = is_query & (query_off == 0)
    input_id = tl.where(is_bonus, bonus_token, parallel_drafting_token_id)
    tl.store(out_input_ids_ptr + query_out, input_id, mask=is_query)

    # --- Token indices to sample (mask tokens, skip the bonus token) ---
    is_sample = is_query & (query_off > 0)
    sample_out_idx = req_idx * num_speculative_tokens + (query_off - 1)
    tl.store(
        out_token_indices_ptr + sample_out_idx,
        query_out,
        mask=is_sample,
    )
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change( def update_num_computed_tokens_for_batch_change(
num_computed_tokens: torch.Tensor, num_computed_tokens: torch.Tensor,
......
...@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor ...@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
...@@ -515,6 +516,7 @@ class GPUModelRunner( ...@@ -515,6 +516,7 @@ class GPUModelRunner(
| NgramProposerGPU | NgramProposerGPU
| SuffixDecodingProposer | SuffixDecodingProposer
| EagleProposer | EagleProposer
| DFlashProposer
| DraftModelProposer | DraftModelProposer
| MedusaProposer | MedusaProposer
| ExtractHiddenStatesProposer | ExtractHiddenStatesProposer
...@@ -546,6 +548,9 @@ class GPUModelRunner( ...@@ -546,6 +548,9 @@ class GPUModelRunner(
self._ngram_pinned_val_buf = torch.zeros( self._ngram_pinned_val_buf = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=True self.max_num_reqs, dtype=torch.int32, pin_memory=True
) )
elif self.speculative_config.use_dflash():
self.drafter = DFlashProposer(self.vllm_config, self.device, self)
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "suffix": elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config) self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
...@@ -2289,7 +2294,7 @@ class GPUModelRunner( ...@@ -2289,7 +2294,7 @@ class GPUModelRunner(
cm.slot_mapping = slot_mappings[kv_cache_gid] cm.slot_mapping = slot_mappings[kv_cache_gid]
if self.speculative_config and spec_decode_common_attn_metadata is None: if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, (EagleProposer, DFlashProposer)):
if self.drafter.kv_cache_gid == kv_cache_gid: if self.drafter.kv_cache_gid == kv_cache_gid:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
else: else:
...@@ -4202,7 +4207,10 @@ class GPUModelRunner( ...@@ -4202,7 +4207,10 @@ class GPUModelRunner(
# as inputs, and does not need to wait for bookkeeping to finish. # as inputs, and does not need to wait for bookkeeping to finish.
assert isinstance( assert isinstance(
self.drafter, self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, EagleProposer
| DFlashProposer
| DraftModelProposer
| ExtractHiddenStatesProposer,
) )
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter: if input_fits_in_drafter:
...@@ -4589,8 +4597,14 @@ class GPUModelRunner( ...@@ -4589,8 +4597,14 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
elif spec_config.use_eagle() or spec_config.uses_draft_model(): elif (
assert isinstance(self.drafter, EagleProposer | DraftModelProposer) spec_config.use_eagle()
or spec_config.use_dflash()
or spec_config.uses_draft_model()
):
assert isinstance(
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
)
if spec_config.disable_padded_drafter_batch: if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be # When padded-batch is disabled, the sampled_token_ids should be
...@@ -4889,10 +4903,13 @@ class GPUModelRunner( ...@@ -4889,10 +4903,13 @@ class GPUModelRunner(
return None return None
hf_config = self.speculative_config.draft_model_config.hf_config hf_config = self.speculative_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
dflash_config = getattr(hf_config, "dflash_config", None)
if dflash_config and isinstance(dflash_config, dict):
layer_ids = dflash_config.get("target_layer_ids")
if layer_ids and isinstance(layer_ids, (list, tuple)): if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids) return tuple(layer_ids)
...@@ -5479,7 +5496,10 @@ class GPUModelRunner( ...@@ -5479,7 +5496,10 @@ class GPUModelRunner(
): ):
assert isinstance( assert isinstance(
self.drafter, self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, EagleProposer
| DFlashProposer
| DraftModelProposer
| ExtractHiddenStatesProposer,
) )
assert self.speculative_config is not None assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs. # Eagle currently only supports PIECEWISE cudagraphs.
...@@ -6236,7 +6256,9 @@ class GPUModelRunner( ...@@ -6236,7 +6256,9 @@ class GPUModelRunner(
self.speculative_config.use_eagle() self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model() or self.speculative_config.uses_draft_model()
): ):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer) assert isinstance(
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
)
self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
def _check_and_update_cudagraph_mode( def _check_and_update_cudagraph_mode(
...@@ -6420,7 +6442,10 @@ class GPUModelRunner( ...@@ -6420,7 +6442,10 @@ class GPUModelRunner(
self.speculative_config.use_eagle() self.speculative_config.use_eagle()
or self.speculative_config.uses_extract_hidden_states() or self.speculative_config.uses_extract_hidden_states()
): ):
assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer) assert isinstance(
self.drafter,
EagleProposer | DFlashProposer | ExtractHiddenStatesProposer,
)
self.drafter.initialize_cudagraph_keys(cudagraph_mode) self.drafter.initialize_cudagraph_keys(cudagraph_mode)
def calculate_reorder_batch_threshold(self) -> None: def calculate_reorder_batch_threshold(self) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment