Unverified Commit 494636b2 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Feat][Spec Decode] DFlash (#36847)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent ab1a6a43
......@@ -1163,6 +1163,14 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
# "JackFram/llama-160m",
# speculative_model="ibm-ai-platform/llama-160m-accelerator"
# ),
# [DFlash]
"DFlashDraftModel": _HfExamplesInfo(
"Qwen/Qwen3.5-4B",
speculative_model="z-lab/Qwen3.5-4B-DFlash",
use_original_num_layers=True, # Need all layers since DFlash has >1 layer,
max_model_len=8192, # Reduce max len to ensure test runs in low-VRAM CI env
max_num_seqs=32,
),
# [Eagle]
"EagleDeepSeekMTPModel": _HfExamplesInfo(
"eagle618/deepseek-v3-random",
......
......@@ -7,6 +7,7 @@ from typing import Any
import pytest
import torch
from tqdm import tqdm
from tests.evals.gsm8k.gsm8k_eval import _build_gsm8k_prompts, evaluate_gsm8k_offline
from tests.utils import (
......@@ -1105,19 +1106,178 @@ def some_high_acceptance_metrics() -> dict:
}
def compute_acceptance_rate(metrics: list[Metric]) -> float:
def compute_acceptance_rate(
metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
name2metric = {metric.name: metric for metric in metrics}
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value
if n_draft_toks == 0:
return float("nan")
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
if prev_metrics is not None:
prev_name2metric = {metric.name: metric for metric in prev_metrics}
n_draft_toks -= prev_name2metric["vllm:spec_decode_num_draft_tokens"].value
n_accepted_toks -= prev_name2metric[
"vllm:spec_decode_num_accepted_tokens"
].value
if n_draft_toks <= 0:
return float("nan")
return n_accepted_toks / n_draft_toks
def compute_acceptance_len(metrics: list[Metric]) -> float:
def compute_acceptance_len(
metrics: list[Metric], prev_metrics: list[Metric] | None = None
) -> float:
name2metric = {metric.name: metric for metric in metrics}
n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value
if n_drafts == 0:
return 1
if prev_metrics is not None:
prev_name2metric = {metric.name: metric for metric in prev_metrics}
n_drafts -= prev_name2metric["vllm:spec_decode_num_drafts"].value
n_accepted_toks -= prev_name2metric[
"vllm:spec_decode_num_accepted_tokens"
].value
if n_drafts <= 0:
return 1
return 1 + (n_accepted_toks / n_drafts)
# Datasets in the format used in DFlash validations
def load_and_process_dataset(data_name: str):
from datasets import load_dataset
if data_name == "gsm8k":
dataset = load_dataset("openai/gsm8k", "main", split="test")
prompt_fmt = (
"{question}\nPlease reason step by step,"
" and put your final answer within \\boxed{{}}."
)
dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
elif data_name == "mt-bench":
dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
dataset = dataset.map(lambda x: {"turns": x["prompt"]})
elif data_name == "humaneval":
dataset = load_dataset("openai/openai_humaneval", split="test")
prompt_fmt = (
"Write a solution to the following problem and make sure"
" that it passes the tests:\n```python\n{prompt}\n```"
)
dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]})
return dataset
@pytest.fixture
def dflash_config():
target_model = "Qwen/Qwen3-8B"
draft_model = "z-lab/Qwen3-8B-DFlash-b16"
return dict(
model=target_model,
trust_remote_code=True,
speculative_config={
"method": "dflash",
"model": draft_model,
"num_speculative_tokens": 16,
"max_model_len": 32768,
},
max_model_len=32768,
max_num_seqs=128,
gpu_memory_utilization=0.85,
enforce_eager=False,
disable_log_stats=False,
)
def test_dflash_acceptance_rates(dflash_config):
"""
E2E test for DFlash (block diffusion) speculative decoding.
Runs acceptance rate validation on GSM8k, MT-Bench, and HumanEval
comparing against baseline results from the paper (Table 1).
See https://github.com/z-lab/dflash/blob/main/benchmark_sglang.py for methodology.
"""
spec_llm = LLM(**dflash_config)
max_prompts_per_dataset = 200 # mt-bench has 80, humaneval has 164, truncates gsm8k
# All scores from Table 1 in https://arxiv.org/pdf/2602.06036
expected_acceptance_lengths = {
"mt-bench": 4.24,
"humaneval": 6.50,
"gsm8k": 6.54 * 0.95, # runs with a subset of prompts so extra wide tol here
}
tokenizer = spec_llm.get_tokenizer()
for dataset_name, expected_len in expected_acceptance_lengths.items():
dataset = load_and_process_dataset(dataset_name)
prev_metrics = None
acceptance_lengths = []
for i in tqdm(
range(min(max_prompts_per_dataset, len(dataset))),
desc=f"Processing {dataset_name}",
):
user_content = dataset[i]["turns"][0]
prompt_text = tokenizer.apply_chat_template(
[{"role": "user", "content": user_content}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
# Temp=0, MaxTokens=2048 from the paper
spec_llm.generate(
[prompt_text],
SamplingParams(temperature=0, max_tokens=2048),
use_tqdm=False,
)
current_metrics = spec_llm.get_metrics()
acceptance_len = compute_acceptance_len(current_metrics, prev_metrics)
prev_metrics = current_metrics
acceptance_lengths.append(acceptance_len)
mean_acceptance_length = sum(acceptance_lengths) / len(acceptance_lengths)
expected_len = expected_len * 0.9
print(
f"DFlash acceptance_len for {dataset_name}: {mean_acceptance_length:.2f}"
f" (expected at least {expected_len:.2f})"
)
assert mean_acceptance_length >= expected_len, (
f"DFlash acceptance_len for {dataset_name} is below expected threshold:"
f"{mean_acceptance_length:.2f} < {expected_len:.2f}"
)
del spec_llm
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
def test_dflash_correctness(dflash_config):
"""
E2E test for DFlash (block diffusion) speculative decoding.
Ensures output correctness on GSM8k, with cudagraphs and batching on.
"""
spec_llm = LLM(**dflash_config)
# Evaluate GSM8k accuracy (Qwen3-8B ref: ~87-92% on GSM8k)
evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
current_metrics = spec_llm.get_metrics()
acceptance_len = compute_acceptance_len(current_metrics)
# AR is thoroughly validated in test_dflash_acceptance_rates, in a manner consistent
# with the DFlash paper. However, that test measures AL per-request and thus runs
# with a batch size of 1. To ensure that AL does not collapse with large batch sizes
# we enforce a baseline on the AL over the full lm-eval-style GSM8k test.
expected_len = 3.5 # Measured is 3.9 to 4.0
print(f"DFlash GSM8k correctness test got AL {acceptance_len}")
assert acceptance_len >= expected_len, (
"DFlash correctness check failed with"
f" {acceptance_len=}, expected at least {expected_len}"
)
del spec_llm
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
......@@ -27,6 +27,7 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
......@@ -36,6 +37,8 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting
dflash_target_dir = "Qwen/Qwen3-8B"
dflash_dir = "z-lab/Qwen3-8B-DFlash-b16"
BLOCK_SIZE = 16
......@@ -47,18 +50,29 @@ def _create_proposer(
speculative_token_tree: list[tuple[int, ...]] | None = None,
parallel_drafting: bool = False,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
# Method-dependent setup
if method == "eagle":
target_model_dir = model_dir
draft_model_dir = eagle_dir
elif method == "eagle3":
target_model_dir = model_dir
draft_model_dir = eagle3_dir
elif method == "draft_model":
target_model_dir = model_dir
draft_model_dir = ar_draft_model_dir
elif method == "dflash":
target_model_dir = dflash_target_dir
draft_model_dir = dflash_dir
else:
raise ValueError(f"Unknown method: {method}")
model_config = ModelConfig(
model=target_model_dir,
runner="generate",
max_model_len=100,
trust_remote_code=(method == "dflash"),
)
spec_token_tree_str = None
if speculative_token_tree is not None:
assert num_speculative_tokens == len(speculative_token_tree)
......@@ -92,7 +106,9 @@ def _create_proposer(
attention_config=AttentionConfig(backend=attention_backend),
)
if "eagle" in method:
if method == "dflash":
proposer = DFlashProposer(vllm_config=vllm_config, device=device)
elif "eagle" in method:
proposer = EagleProposer(vllm_config=vllm_config, device=device)
else:
proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
......@@ -1152,3 +1168,136 @@ def test_propose_tree(spec_token_tree):
# Verify that the draft tokens match our expectations.
assert torch.equal(result, expected_tokens)
def test_set_inputs_first_pass_dflash():
"""
Test for DFlash set_inputs_first_pass.
DFlash uses cross-attention: context tokens become K/V and only
query tokens (bonus + mask) are Q. This tests the DFlash-specific
input preparation where:
- Context hidden states are stored by reference (no copy)
- Query input_ids are [next_token, mask, mask, ...] per request
- Context and query positions are written to separate buffers
- token_indices_to_sample points to mask token positions only
- A new CommonAttentionMetadata is returned with causal=False
Setup:
- 3 requests with query_lens [3, 2, 4]
- num_speculative_tokens = 3
- num_query_per_req = 4 (1 bonus + 3 mask tokens)
- next_token_ids: [100, 200, 300]
Expected output layout (query tokens only, 12 total):
Request 0 (indices 0-3): [100, mask, mask, mask]
Request 1 (indices 4-7): [200, mask, mask, mask]
Request 2 (indices 8-11): [300, mask, mask, mask]
Expected positions layout (separate buffers):
Context (_context_positions_buffer, 9 tokens): copied from target_positions
Query (positions, 12 tokens):
Request 0: last_pos=9, query=[10, 11, 12, 13]
Request 1: last_pos=7, query=[8, 9, 10, 11]
Request 2: last_pos=11, query=[12, 13, 14, 15]
"""
device = torch.device(current_platform.device_type)
num_speculative_tokens = 3
proposer = _create_proposer("dflash", num_speculative_tokens)
mask_token_id = proposer.parallel_drafting_token_id
# Setup batch with 3 requests
batch_spec = BatchSpec(
seq_lens=[10, 8, 12],
query_lens=[3, 2, 4],
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=BLOCK_SIZE,
device=device,
arange_block_indices=True,
)
# Input tensors
# Request 0: tokens [10, 11, 12] at positions [7, 8, 9]
# Request 1: tokens [20, 21] at positions [6, 7]
# Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11]
target_token_ids = torch.tensor(
[10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device
)
target_positions = torch.tensor(
[7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device
)
target_hidden_states = torch.randn(
9, proposer.hidden_size, dtype=proposer.dtype, device=device
)
next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)
num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=None,
cad=common_attn_metadata,
num_rejected_tokens_gpu=None,
)
num_query_per_req = 1 + num_speculative_tokens # 4
num_context = 9
# num_tokens is the query-only count
assert num_tokens == 3 * num_query_per_req # 12
# Verify input_ids (query tokens only)
# Each request: [next_token, mask, mask, mask]
M = mask_token_id
expected_input_ids = torch.tensor(
[100, M, M, M, 200, M, M, M, 300, M, M, M],
dtype=torch.int32,
device=device,
)
assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids)
# Verify context positions (separate buffer): copied from target_positions
assert torch.equal(
proposer._context_positions_buffer[:num_context], target_positions
)
# Verify query positions (separate buffer, starts at index 0):
# req0: last_pos=9, query=[10, 11, 12, 13]
# req1: last_pos=7, query=[8, 9, 10, 11]
# req2: last_pos=11, query=[12, 13, 14, 15]
expected_query_positions = torch.tensor(
[10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15],
dtype=torch.int64,
device=device,
)
assert torch.equal(
proposer.positions[:num_tokens],
expected_query_positions,
)
# Verify token_indices_to_sample (mask tokens only, skip bonus at offset 0)
# req0: query indices 0-3, mask at 1,2,3
# req1: query indices 4-7, mask at 5,6,7
# req2: query indices 8-11, mask at 9,10,11
expected_token_indices_to_sample = torch.tensor(
[1, 2, 3, 5, 6, 7, 9, 10, 11], dtype=torch.int32, device=device
)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
# Verify the new CAD has DFlash-specific properties
assert output_cad.causal is False # DFlash requires non-causal attention
assert output_cad.num_actual_tokens == num_tokens # query-only count
assert output_cad.max_query_len == num_query_per_req
expected_query_start_loc = torch.tensor(
[0, 4, 8, 12], dtype=torch.int32, device=device
)
assert torch.equal(output_cad.query_start_loc, expected_query_start_loc)
# Verify hidden states (stored by reference, not copied)
assert proposer._dflash_hidden_states is target_hidden_states
......@@ -47,8 +47,11 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]
DFlashModelTypes = Literal["dflash"]
EagleModelTypes = Literal[
"eagle", "eagle3", "extract_hidden_states", MTPModelTypes, DFlashModelTypes
]
SpeculativeMethod = Literal[
"ngram",
"medusa",
......@@ -206,7 +209,11 @@ class SpeculativeConfig:
factors: list[Any] = []
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
uses_aux_hidden_states = self.method in (
"eagle3",
"extract_hidden_states",
"dflash",
)
factors.append(uses_aux_hidden_states)
# The specific layers used also affect the computation graph
......@@ -490,7 +497,7 @@ class SpeculativeConfig:
)
# Automatically detect the method
if self.method in ("eagle", "eagle3"):
if self.method in ("eagle", "eagle3", "dflash"):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
......@@ -500,6 +507,8 @@ class SpeculativeConfig:
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif "dflash" in self.draft_model_config.model.lower():
self.method = "dflash"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
......@@ -532,7 +541,7 @@ class SpeculativeConfig:
)
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
if self.method in ("eagle", "eagle3", "dflash"):
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.speculators import (
SpeculatorsConfig,
......@@ -552,6 +561,9 @@ class SpeculativeConfig:
self.draft_model_config.hf_config = eagle_config
self.update_arch_()
if self.method == "dflash":
self.parallel_drafting = True
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
):
......@@ -807,7 +819,7 @@ class SpeculativeConfig:
"kimi_k25",
]
if (
self.method in ("eagle3", "extract_hidden_states")
self.method in ("eagle3", "extract_hidden_states", "dflash")
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
......@@ -855,7 +867,10 @@ class SpeculativeConfig:
return slots_per_req
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
return self.method in ("eagle", "eagle3", "mtp", "dflash")
def use_dflash(self) -> bool:
return self.method == "dflash"
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
......
......@@ -1327,6 +1327,26 @@ class VllmConfig:
max_num_batched_tokens - scheduled_token_delta
)
if self.scheduler_config.max_num_scheduled_tokens <= 0:
raise ValueError(
"max_num_scheduled_tokens is set to"
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
" the speculative decoding settings, which does not allow"
" any tokens to be scheduled. Increase max_num_batched_tokens"
" to accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs."
)
if self.scheduler_config.max_num_scheduled_tokens < 8192:
logger.warning_once(
"max_num_scheduled_tokens is set to"
f" {self.scheduler_config.max_num_scheduled_tokens} based on"
" the speculative decoding settings. This may lead to suboptimal"
" performance. Consider increasing max_num_batched_tokens to"
" accommodate the additional draft token slots, or decrease"
" num_speculative_tokens or max_num_seqs.",
scope="local",
)
max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
if max_num_batched_tokens < max_num_scheduled_tokens + (
self.speculative_config.max_num_new_slots_for_drafting
......
......@@ -285,6 +285,7 @@ class Qwen3ForCausalLM(
self.config = config
self.vllm_config = vllm_config
self.quant_config = quant_config
self.model = Qwen3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
......
This diff is collapsed.
......@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from .interfaces import (
EagleModelMixin,
HasInnerState,
IsHybrid,
MixtureOfExperts,
......@@ -454,7 +455,7 @@ class Qwen3NextDecoderLayer(nn.Module):
@support_torch_compile
class Qwen3NextModel(nn.Module):
class Qwen3NextModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......@@ -492,8 +493,6 @@ class Qwen3NextModel(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -515,20 +514,19 @@ class Qwen3NextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
......
......@@ -546,6 +546,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
......
......@@ -62,9 +62,20 @@ class EAGLEConfig(PretrainedConfig):
else f"Eagle3{arch}"
for arch in self.model.architectures
]
elif method == "dflash":
assert self.model is not None, (
"model should not be None when method is dflash"
)
kwargs["architectures"] = [
arch
if arch.startswith("DFlash") or arch.endswith("DFlash")
else f"DFlash{arch}"
for arch in self.model.architectures
]
else:
raise ValueError(
f"Invalid method {method}. Supported methods are eagle and eagle3."
f"Invalid method {method}. Supported methods are "
"eagle, eagle3, and dflash."
)
super().__init__(**kwargs)
......
......@@ -220,6 +220,17 @@ class AttentionBackend(ABC):
def supports_per_head_quant_scales(cls) -> bool:
return False
@classmethod
def supports_non_causal(cls) -> bool:
"""Check if backend supports non-causal (bidirectional) attention
for decoder models.
Unlike ENCODER_ONLY attention type which implies a different
execution model, this refers to non-causal attention within the
standard paged-KV-cache decoder path.
"""
return False
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""Check if backend supports a given attention type.
......@@ -261,6 +272,7 @@ class AttentionBackend(ABC):
use_per_head_quant_scales: bool,
device_capability: "DeviceCapability",
attn_type: str,
use_non_causal: bool = False,
) -> list[str]:
invalid_reasons = []
if not cls.supports_head_size(head_size):
......@@ -293,6 +305,8 @@ class AttentionBackend(ABC):
invalid_reasons.append("compute capability not supported")
if not cls.supports_attn_type(attn_type):
invalid_reasons.append(f"attention type {attn_type} not supported")
if use_non_causal and not cls.supports_non_causal():
invalid_reasons.append("non-causal attention not supported")
combination_reason = cls.supports_combination(
head_size,
dtype,
......
......@@ -101,6 +101,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLASH_ATTN"
@classmethod
def supports_non_causal(cls) -> bool:
return True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
......
......@@ -29,6 +29,7 @@ class AttentionSelectorConfig(NamedTuple):
use_mm_prefix: bool = False
use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER
use_non_causal: bool = False
def __repr__(self):
return (
......@@ -41,7 +42,8 @@ class AttentionSelectorConfig(NamedTuple):
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type})"
f"attn_type={self.attn_type}, "
f"use_non_causal={self.use_non_causal})"
)
......@@ -76,6 +78,11 @@ def get_attn_backend(
else:
block_size = None
speculative_config = vllm_config.speculative_config
use_non_causal = (
speculative_config is not None and speculative_config.method == "dflash"
)
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
......@@ -87,6 +94,7 @@ def get_attn_backend(
use_mm_prefix=use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER,
use_non_causal=use_non_causal,
)
return _cached_get_attn_backend(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from typing_extensions import override
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import triton
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import copy_and_expand_dflash_inputs_kernel
logger = init_logger(__name__)
class DFlashProposer(SpecDecodeBaseProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.method == "dflash"
super().__init__(
vllm_config=vllm_config,
device=device,
pass_hidden_states_to_model=True,
runner=runner,
)
# Only next_token_ids and mask tokens are query tokens, all other context is K/V
self.max_query_tokens = self.max_batch_size * (1 + self.num_speculative_tokens)
# Positions covers both context states + query states
self.max_positions = self.max_num_tokens + self.max_query_tokens
# Separate context buffers to keep query buffer addresses stable for CUDA graphs
self._context_slot_mapping_buffer = torch.zeros(
self.max_num_tokens,
dtype=torch.int64,
device=device,
)
self._slot_mapping_buffer = torch.zeros(
self.max_query_tokens,
dtype=torch.int64,
device=device,
)
self._context_positions_buffer = torch.zeros(
self.max_num_tokens,
dtype=torch.int64,
device=device,
)
self.positions = torch.zeros(
self.max_query_tokens,
dtype=torch.int64,
device=device,
)
self.arange = torch.arange(
self.max_positions + 1, device=device, dtype=torch.int32
)
# For DFlash we use the input embeddings to embed the mask token
self.parallel_drafting_hidden_state_tensor = None
@override
def _raise_if_multimodal(self):
# Override to allow multimodal inputs since DFlash supports Qwen3.5 models
# Support for multimodal inputs has not been tested.
pass
@override
def set_inputs_first_pass(
self,
target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
# DFlash cross-attention: context K/V from target hidden states,
# Q from query embeddings (bonus + mask tokens).
batch_size = cad.batch_size()
num_context = target_token_ids.shape[0]
num_query_per_req = 1 + self.num_speculative_tokens
num_query_total = batch_size * num_query_per_req
# Store for build_model_inputs_first_pass to use
self._dflash_num_context = num_context
# We don't need to copy into a buffer here since the context preprocessing
# does not run in a CUDA graph
self._dflash_hidden_states = target_hidden_states
token_indices_to_sample = torch.empty(
batch_size * self.num_speculative_tokens,
dtype=torch.int32,
device=self.device,
)
# Launch fused triton kernel for input_ids, positions, slot_mapping,
# and token_indices_to_sample
max_ctx_per_req = cad.max_query_len
max_tokens_per_req = max_ctx_per_req + num_query_per_req
BLOCK_SIZE = min(256, triton.next_power_of_2(max_tokens_per_req))
num_blocks = triton.cdiv(max_tokens_per_req, BLOCK_SIZE)
grid = (batch_size, num_blocks)
has_num_rejected = num_rejected_tokens_gpu is not None
copy_and_expand_dflash_inputs_kernel[grid](
# Inputs
next_token_ids_ptr=next_token_ids,
target_positions_ptr=target_positions,
# Outputs
out_input_ids_ptr=self.input_ids,
out_context_positions_ptr=self._context_positions_buffer,
out_query_positions_ptr=self.positions,
out_context_slot_mapping_ptr=self._context_slot_mapping_buffer,
out_query_slot_mapping_ptr=self._slot_mapping_buffer,
out_token_indices_ptr=token_indices_to_sample,
# Block table
block_table_ptr=cad.block_table_tensor,
block_table_stride=cad.block_table_tensor.stride(0),
# Metadata
query_start_loc_ptr=cad.query_start_loc,
num_rejected_tokens_ptr=(
num_rejected_tokens_gpu if has_num_rejected else 0
),
# Scalars
parallel_drafting_token_id=self.parallel_drafting_token_id,
block_size=self.block_size,
num_query_per_req=num_query_per_req,
num_speculative_tokens=self.num_speculative_tokens,
total_input_tokens=num_context,
BLOCK_SIZE=BLOCK_SIZE,
HAS_NUM_REJECTED=has_num_rejected,
)
query_slot_mapping = self._slot_mapping_buffer[:num_query_total]
new_query_start_loc = self.arange[: batch_size + 1] * num_query_per_req
# In padded mode, cad.seq_lens includes rejected tokens. Subtract
# them so attention only sees the valid prefix of context states.
effective_seq_lens = cad.seq_lens
if has_num_rejected:
effective_seq_lens = effective_seq_lens - num_rejected_tokens_gpu
new_cad = CommonAttentionMetadata(
query_start_loc=new_query_start_loc,
seq_lens=effective_seq_lens + num_query_per_req,
query_start_loc_cpu=(
torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone()
* num_query_per_req
),
_seq_lens_cpu=None,
_num_computed_tokens_cpu=None,
num_reqs=cad.num_reqs,
num_actual_tokens=num_query_total,
max_query_len=num_query_per_req,
max_seq_len=cad.max_seq_len + num_query_per_req,
block_table_tensor=cad.block_table_tensor,
slot_mapping=query_slot_mapping,
causal=False, # Non-causal attention is required for DFlash
)
return num_query_total, token_indices_to_sample, new_cad
@override
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
"""
Key differences to default dummy_run:
- Only one forward pass due to parallel drafting
- DFlash uses context states as unpadded metadata, so hidden_states will
use the unpadded num_tokens instead of num_input_tokens
- max_query_tokens is quite small, DFlash only sees spec tokens as queries
- Multimodal inputs are not currently supported
"""
num_query_tokens = min(num_tokens, self.max_query_tokens)
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_query_tokens, use_cudagraphs=use_cudagraphs
)
)
# Slot mapping sized to num_input_tokens (query only), matching
# the K/V tensor size from the model forward. Context KVs are
# pre-inserted separately and don't flow through the model.
if (
self._draft_attn_layer_names
and slot_mappings is not None
and next(iter(self._draft_attn_layer_names)) in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
# Context and query positions use separate buffers; no copy needed.
context_positions = self._context_positions_buffer[:num_tokens]
# Context states will be passed directly to the precomputation without
# going through the buffer, since no CUDA graph is used for the precomputation.
# For the dummy run, we use the dummy buffer.
context_states = self.hidden_states[:num_tokens]
# Run the KV projection (GEMM + norms + RoPE) for memory profiling,
self.model.precompute_and_store_context_kv(context_states, context_positions)
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self._get_positions(num_input_tokens),
inputs_embeds=None,
)
@override
def build_model_inputs_first_pass(
self,
num_tokens: int,
num_input_tokens: int,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
) -> tuple[dict[str, Any], int]:
# Context and query positions/slots were written to separate
# buffers by the kernel — no copy needed.
num_context = self._dflash_num_context
# Pre-insert context KVs directly into cache
self.model.precompute_and_store_context_kv(
self._dflash_hidden_states, # Shape is already [num_context, hidden_size]
self._context_positions_buffer[:num_context],
self._context_slot_mapping_buffer[:num_context],
)
return (
dict(
input_ids=self.input_ids[:num_input_tokens],
positions=self._get_positions(num_input_tokens),
inputs_embeds=None,
),
num_input_tokens,
)
@override
def build_per_layer_attn_metadata(
self, cad: CommonAttentionMetadata, draft_index: int = 0
) -> dict[str, object]:
per_layer_attention_metadata = super().build_per_layer_attn_metadata(
cad, draft_index
)
for layer_name, attn_metadata in per_layer_attention_metadata.items():
assert getattr(attn_metadata, "causal", None) is False, (
f"Attention metadata for layer {layer_name} does not have"
" non-causal support, which is required for DFlash."
" Consider using a different attention backend, such as FlashAttention."
)
return per_layer_attention_metadata
@override
def _get_eagle3_use_aux_hidden_state_from_config(self):
use_aux_hidden_state = True
dflash_config = getattr(
self.draft_model_config.hf_config, "dflash_config", None
)
if dflash_config is not None:
use_aux_hidden_state = dflash_config.get("use_aux_hidden_state", True)
return use_aux_hidden_state
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from importlib.util import find_spec
from typing import cast
from typing import Any, cast
import numpy as np
import torch
......@@ -23,6 +23,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
......@@ -83,13 +84,15 @@ class SpecDecodeBaseProposer:
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# Unifying eagle, draft model, and parallel drafting support
# Unifying eagle, draft model, and parallel drafting support.
# DFlash always uses parallel drafting (all tokens in one pass),
# but has an additional slot for the next_token_id (does not shift like EAGLE)
self.parallel_drafting: bool = self.speculative_config.parallel_drafting
self.extra_slots_per_request = (
1 if not self.parallel_drafting else self.num_speculative_tokens
)
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
1 if self.pass_hidden_states_to_model else 0
1 if (self.pass_hidden_states_to_model and self.method != "dflash") else 0
)
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
......@@ -101,10 +104,14 @@ class SpecDecodeBaseProposer:
self.speculative_config.use_local_argmax_reduction
)
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens)
# Can be specialized by methods like DFlash to reduce the limit
self.max_query_tokens = self.max_num_tokens
self.max_positions = self.max_num_tokens
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
......@@ -146,18 +153,20 @@ class SpecDecodeBaseProposer:
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device
(3, self.max_positions + 1), dtype=torch.int64, device=device
)
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions = torch.zeros(
(self.uses_xdrope_dim, self.max_num_tokens + 1),
(self.uses_xdrope_dim, self.max_positions + 1),
dtype=torch.int64,
device=device,
)
else:
# RoPE need (max_num_tokens,)
self.positions = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
self.max_positions,
dtype=torch.int64,
device=device,
)
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
......@@ -168,7 +177,7 @@ class SpecDecodeBaseProposer:
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
max_num_slots_for_arange = max(self.max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange(
max_num_slots_for_arange, device=device, dtype=torch.int32
)
......@@ -200,7 +209,7 @@ class SpecDecodeBaseProposer:
)
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
self.max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
......@@ -208,7 +217,9 @@ class SpecDecodeBaseProposer:
)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
self.max_positions,
dtype=torch.int64,
device=device,
)
# Determine allowed attention backends once during initialization.
......@@ -275,7 +286,7 @@ class SpecDecodeBaseProposer:
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1)
).repeat(self.max_batch_size, 1)
def _raise_if_padded_drafter_batch_disabled(self):
if self.speculative_config.disable_padded_drafter_batch:
......@@ -305,14 +316,19 @@ class SpecDecodeBaseProposer:
# for those masked slots.
model_hf_config = self.draft_model_config.hf_config
if hasattr(model_hf_config, "pard_token"):
# DFlash stores mask_token_id in dflash_config
dflash_config = getattr(model_hf_config, "dflash_config", None)
if dflash_config and "mask_token_id" in dflash_config:
self.parallel_drafting_token_id = dflash_config["mask_token_id"]
elif hasattr(model_hf_config, "pard_token"):
self.parallel_drafting_token_id = model_hf_config.pard_token
elif hasattr(model_hf_config, "ptd_token_id"):
self.parallel_drafting_token_id = model_hf_config.ptd_token_id
else:
raise ValueError(
"For parallel drafting, the draft model config must have "
"`pard_token` or `ptd_token_id` specified in its config.json."
"`pard_token`, `ptd_token_id`, or "
"`dflash_config.mask_token_id` specified in its config.json."
)
if self.pass_hidden_states_to_model:
......@@ -402,9 +418,14 @@ class SpecDecodeBaseProposer:
) -> torch.Tensor:
batch_size = common_attn_metadata.batch_size()
if self.method == "eagle3":
if self.method in ("eagle3", "dflash"):
assert isinstance(
self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
self.model,
(
Eagle3LlamaForCausalLM,
Eagle3DeepseekV2ForCausalLM,
DFlashQwen3ForCausalLM,
),
)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states
......@@ -423,40 +444,17 @@ class SpecDecodeBaseProposer:
)
)
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
common_attn_metadata
)
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
model_kwargs, slot_mapping_size = self.build_model_inputs_first_pass(
num_tokens, num_input_tokens, mm_embed_inputs
)
with set_forward_context(
per_layer_attn_metadata,
......@@ -465,7 +463,7 @@ class SpecDecodeBaseProposer:
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
slot_mapping_size, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
......@@ -488,7 +486,10 @@ class SpecDecodeBaseProposer:
positions = self.positions[token_indices_to_sample]
hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata):
if any(
isinstance(attn_metadata, TreeAttentionMetadata)
for attn_metadata in per_layer_attn_metadata.values()
):
# Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids_list = self.propose_tree(
......@@ -504,15 +505,16 @@ class SpecDecodeBaseProposer:
draft_token_ids = self._greedy_sample(sample_hidden_states)
if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types
):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}"
)
for attn_metadata in per_layer_attn_metadata.values():
if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types
):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}"
)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
......@@ -593,13 +595,9 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1
# Rebuild attention metadata
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
common_attn_metadata, draft_index=token_index + 1
)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
......@@ -780,8 +778,51 @@ class SpecDecodeBaseProposer:
return total_num_output_tokens, token_indices_to_sample, new_cad
def build_model_inputs_first_pass(
self,
num_tokens: int,
num_input_tokens: int,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
) -> tuple[dict[str, Any], int]:
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
return model_kwargs, num_input_tokens
def build_per_layer_attn_metadata(
self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
) -> dict[str, object]:
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=draft_index
)
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
return per_layer_attn_metadata
def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model")
return self.method not in ("mtp", "draft_model", "dflash")
def prepare_next_token_ids_cpu(
self,
......@@ -1310,15 +1351,20 @@ class SpecDecodeBaseProposer:
self._maybe_share_embeddings(target_language_model)
self._maybe_share_lm_head(target_language_model)
if self.parallel_drafting and self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(
self.model.mask_hidden.view(3 * self.hidden_size)
if (
self.parallel_drafting
and self.pass_hidden_states_to_model
and self.parallel_drafting_hidden_state_tensor is not None
):
flat_mask = self.model.mask_hidden.view(-1)
if self.eagle3_use_aux_hidden_state:
# EAGLE3: mask_hidden stores all aux hidden states,
# project through combine_hidden_states
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(flat_mask)
)
if self.eagle3_use_aux_hidden_state
else self.model.mask_hidden.view(self.hidden_size)
)
else:
self.parallel_drafting_hidden_state_tensor.copy_(flat_mask)
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
"""
......@@ -1493,8 +1539,9 @@ class SpecDecodeBaseProposer:
) -> None:
# FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree.
only_one_forward_pass = is_graph_capturing or self.parallel_drafting
for fwd_idx in range(
self.num_speculative_tokens if not is_graph_capturing else 1
1 if only_one_forward_pass else self.num_speculative_tokens
):
if fwd_idx <= 1:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
......
......@@ -441,6 +441,114 @@ def copy_and_expand_eagle_inputs_kernel(
)
@triton.jit
def copy_and_expand_dflash_inputs_kernel(
# Inputs
next_token_ids_ptr, # [num_reqs]
target_positions_ptr, # [num_context]
# Outputs
out_input_ids_ptr, # [num_query_total] (output)
out_context_positions_ptr, # [num_context] (output)
out_query_positions_ptr, # [num_query_total] (output)
out_context_slot_mapping_ptr, # [num_context] (output)
out_query_slot_mapping_ptr, # [num_query_total] (output)
out_token_indices_ptr, # [num_reqs * num_speculative_tokens] (output)
# Block table
block_table_ptr, # [max_reqs, max_blocks]
block_table_stride, # stride of block_table dim 0 (in elements)
# Metadata
query_start_loc_ptr, # [num_reqs + 1]
num_rejected_tokens_ptr, # [num_reqs] or null (0) when not padded
# Scalars
parallel_drafting_token_id, # tl.int32
block_size, # tl.int32
num_query_per_req, # tl.int32
num_speculative_tokens, # tl.int32
total_input_tokens, # tl.int32
BLOCK_SIZE: tl.constexpr,
HAS_NUM_REJECTED: tl.constexpr = False,
):
"""
Fused kernel for DFlash first-pass input setup.
Per request, this kernel:
1. Copies context positions from target_positions to
out_context_positions.
2. Computes query positions (last_target_pos + 1 + offset) and writes
them to out_query_positions.
3. Writes input_ids for query tokens: [next_token, mask, mask, ...].
4. Computes slot_mapping for context and query positions into separate
buffers via block_table lookup.
5. Writes token_indices_to_sample for the mask (speculative) tokens.
"""
req_idx = tl.program_id(axis=0)
block_idx = tl.program_id(axis=1)
# Load context token range for this request
ctx_start = tl.load(query_start_loc_ptr + req_idx)
ctx_end = tl.load(query_start_loc_ptr + req_idx + 1)
num_ctx = ctx_end - ctx_start
total_tokens = num_ctx + num_query_per_req
j = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
in_bounds = j < total_tokens
is_ctx = j < num_ctx
is_query = (~is_ctx) & in_bounds
query_off = j - num_ctx # offset within query portion (0-indexed)
# --- Positions ---
# Context: load from target_positions
ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1)
ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0)
# Query: last_valid_pos + 1 + query_off
# In padded mode, ctx_end includes rejected tokens; use valid_ctx_end
# to find the last accepted context position.
if HAS_NUM_REJECTED:
num_rejected = tl.load(num_rejected_tokens_ptr + req_idx)
valid_ctx_end = ctx_end - num_rejected
else:
valid_ctx_end = ctx_end
last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1)
query_pos = last_pos + 1 + query_off
positions = tl.where(is_ctx, ctx_pos, query_pos)
# Context and query positions go to separate buffers.
ctx_pos_out = ctx_start + j
tl.store(out_context_positions_ptr + ctx_pos_out, ctx_pos, mask=is_ctx)
query_out = req_idx * num_query_per_req + query_off
tl.store(out_query_positions_ptr + query_out, query_pos, mask=is_query)
# --- Slot mapping (block_table lookup for all positions) ---
block_num = positions // block_size
# # Clamp block_number to avoid OOB when position is at max
block_num = tl.minimum(block_num, block_table_stride - 1)
block_id = tl.load(
block_table_ptr + req_idx * block_table_stride + block_num,
mask=in_bounds,
other=0,
).to(tl.int64)
slot = block_id * block_size + (positions % block_size)
tl.store(out_context_slot_mapping_ptr + ctx_pos_out, slot, mask=is_ctx)
tl.store(out_query_slot_mapping_ptr + query_out, slot, mask=is_query)
# --- Input IDs (query tokens only) ---
bonus_token = tl.load(next_token_ids_ptr + req_idx)
is_bonus = is_query & (query_off == 0)
input_id = tl.where(is_bonus, bonus_token, parallel_drafting_token_id)
tl.store(out_input_ids_ptr + query_out, input_id, mask=is_query)
# --- Token indices to sample (mask tokens, skip the bonus token) ---
is_sample = is_query & (query_off > 0)
sample_out_idx = req_idx * num_speculative_tokens + (query_off - 1)
tl.store(
out_token_indices_ptr + sample_out_idx,
query_out,
mask=is_sample,
)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change(
num_computed_tokens: torch.Tensor,
......
......@@ -160,6 +160,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.dflash import DFlashProposer
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
......@@ -515,6 +516,7 @@ class GPUModelRunner(
| NgramProposerGPU
| SuffixDecodingProposer
| EagleProposer
| DFlashProposer
| DraftModelProposer
| MedusaProposer
| ExtractHiddenStatesProposer
......@@ -546,6 +548,9 @@ class GPUModelRunner(
self._ngram_pinned_val_buf = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=True
)
elif self.speculative_config.use_dflash():
self.drafter = DFlashProposer(self.vllm_config, self.device, self)
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
......@@ -2289,7 +2294,7 @@ class GPUModelRunner(
cm.slot_mapping = slot_mappings[kv_cache_gid]
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if isinstance(self.drafter, (EagleProposer, DFlashProposer)):
if self.drafter.kv_cache_gid == kv_cache_gid:
spec_decode_common_attn_metadata = cm
else:
......@@ -4202,7 +4207,10 @@ class GPUModelRunner(
# as inputs, and does not need to wait for bookkeeping to finish.
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
EagleProposer
| DFlashProposer
| DraftModelProposer
| ExtractHiddenStatesProposer,
)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
......@@ -4589,8 +4597,14 @@ class GPUModelRunner(
next_token_ids, valid_sampled_tokens_count
)
elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
elif (
spec_config.use_eagle()
or spec_config.use_dflash()
or spec_config.uses_draft_model()
):
assert isinstance(
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
)
if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
......@@ -4889,10 +4903,13 @@ class GPUModelRunner(
return None
hf_config = self.speculative_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
dflash_config = getattr(hf_config, "dflash_config", None)
if dflash_config and isinstance(dflash_config, dict):
layer_ids = dflash_config.get("target_layer_ids")
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
......@@ -5479,7 +5496,10 @@ class GPUModelRunner(
):
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
EagleProposer
| DFlashProposer
| DraftModelProposer
| ExtractHiddenStatesProposer,
)
assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs.
......@@ -6236,7 +6256,9 @@ class GPUModelRunner(
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert isinstance(
self.drafter, EagleProposer | DFlashProposer | DraftModelProposer
)
self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
def _check_and_update_cudagraph_mode(
......@@ -6420,7 +6442,10 @@ class GPUModelRunner(
self.speculative_config.use_eagle()
or self.speculative_config.uses_extract_hidden_states()
):
assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer)
assert isinstance(
self.drafter,
EagleProposer | DFlashProposer | ExtractHiddenStatesProposer,
)
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
def calculate_reorder_batch_threshold(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment