Unverified Commit 0b790a25 authored by zhanqiuhu's avatar zhanqiuhu Committed by GitHub
Browse files

[Speculative Decoding] Add DFlash speculators config parsing (#38300)


Signed-off-by: default avatarZhanqiu Hu <zhu@redhat.com>
parent 41488f2a
...@@ -42,3 +42,16 @@ steps: ...@@ -42,3 +42,16 @@ steps:
- tests/v1/e2e/spec_decode/ - tests/v1/e2e/spec_decode/
commands: commands:
- pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference" - pytest -v -s v1/e2e/spec_decode -k "draft_model or no_sync or batch_inference"
- label: DFlash Speculators Correctness
timeout_in_minutes: 30
device: h100
optional: true
num_devices: 1
source_file_dependencies:
- vllm/v1/spec_decode/
- vllm/model_executor/models/qwen3_dflash.py
- tests/v1/spec_decode/test_speculators_dflash.py
commands:
- export VLLM_ALLOW_INSECURE_SERIALIZATION=1
- pytest -v -s v1/spec_decode/test_speculators_dflash.py -m slow_test
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k_offline
from tests.utils import large_gpu_mark
from vllm import LLM
from vllm.config import SpeculativeConfig
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_PATH = "nm-testing/dflash-qwen3-8b-speculators"
EXPECTED_GSM8K_ACCURACY = 0.885
ACCURACY_RTOL = 0.03
EXPECTED_ACCEPTANCE_LEN = 3.45
ACCEPTANCE_LEN_RTOL = 0.15
# Expected per-position acceptance rates (accepted_at_pos / num_drafts)
# Based on GSM8K evaluation with Qwen3-8B dflash speculators.
EXPECTED_PER_POS_ACCEPTANCE_RATES = [0.795, 0.611, 0.429, 0.282]
PER_POS_RTOL = 0.15
def compute_spec_decode_stats(
metrics,
) -> dict:
"""Extract all spec-decode metrics and compute derived stats."""
name2metric = {m.name: m for m in metrics}
n_drafts = name2metric["vllm:spec_decode_num_drafts"].value
n_draft_tokens = name2metric["vllm:spec_decode_num_draft_tokens"].value
n_accepted = name2metric["vllm:spec_decode_num_accepted_tokens"].value
per_pos_vec = name2metric["vllm:spec_decode_num_accepted_tokens_per_pos"].values
acceptance_len = 1 + (n_accepted / n_drafts) if n_drafts > 0 else 1.0
draft_tokens_per_step = (n_draft_tokens / n_drafts) if n_drafts > 0 else 0
overall_acceptance_rate = (n_accepted / n_draft_tokens) if n_draft_tokens > 0 else 0
per_pos_rates = [v / n_drafts for v in per_pos_vec] if n_drafts > 0 else []
return {
"num_drafts": n_drafts,
"num_draft_tokens": n_draft_tokens,
"num_accepted_tokens": n_accepted,
"acceptance_len": acceptance_len,
"draft_tokens_per_step": draft_tokens_per_step,
"overall_acceptance_rate": overall_acceptance_rate,
"per_pos_accepted": list(per_pos_vec),
"per_pos_acceptance_rates": per_pos_rates,
}
def print_spec_decode_stats(stats: dict) -> None:
"""Print all spec-decode metrics and derived values."""
print("\n===== Spec Decode Metrics =====")
print(f" num_drafts: {stats['num_drafts']}")
print(f" num_draft_tokens: {stats['num_draft_tokens']}")
print(f" num_accepted_tokens: {stats['num_accepted_tokens']}")
print(f" draft_tokens_per_step: {stats['draft_tokens_per_step']:.2f}")
print(f" overall_acceptance_rate: {stats['overall_acceptance_rate']:.4f}")
print(f" acceptance_len (1+acc/drafts): {stats['acceptance_len']:.4f}")
print(" per-position accepted tokens:", stats["per_pos_accepted"])
print(" per-position acceptance rates:")
for i, rate in enumerate(stats["per_pos_acceptance_rates"]):
print(f" pos {i}: {rate:.4f}")
print("===============================\n")
def test_dflash_speculators_model(vllm_runner, example_prompts, monkeypatch):
"""
Test DFlash speculators model properly initializes speculative decoding.
Verifies:
1. Speculative config is automatically initialized from speculators config
2. Method is detected as 'dflash'
3. The draft model path is correctly set
4. Speculative tokens count is valid (num_speculative_tokens=8)
5. Text generation works with speculative decoding enabled
"""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(
MODEL_PATH,
dtype=torch.bfloat16,
enforce_eager=True,
quantization="fp8",
) as vllm_model:
vllm_config = vllm_model.llm.llm_engine.vllm_config
assert isinstance(vllm_config.speculative_config, SpeculativeConfig), (
"Speculative config should be initialized for speculators model"
)
spec_config = vllm_config.speculative_config
assert spec_config.method == "dflash", (
f"Expected method='dflash', got '{spec_config.method}'"
)
assert spec_config.num_speculative_tokens > 0, (
f"Expected positive speculative tokens, "
f"got {spec_config.num_speculative_tokens}"
)
assert spec_config.model == MODEL_PATH, (
f"Draft model should be {MODEL_PATH}, got {spec_config.model}"
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20)
assert vllm_outputs, f"No outputs generated for speculators model {MODEL_PATH}"
@pytest.mark.slow_test
@large_gpu_mark(min_gb=40)
def test_dflash_speculators_correctness(monkeypatch):
"""
E2E correctness test for DFlash via the speculators auto-detect path.
Evaluates GSM8k accuracy to ensure the speculators-format model produces
correct outputs, and checks that acceptance length does not collapse under
batched inference (lm-eval style).
Observed per-position acceptance rates on GSM8K (1319 prompts):
pos 0: 0.795, pos 1: 0.611, pos 2: 0.429, pos 3: 0.282,
pos 4: 0.169, pos 5: 0.093, pos 6: 0.048, pos 7: 0.023
Observed mean AL: 3.45 (GSM8K dataset, max_num_seqs=128)
"""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
spec_llm = LLM(
model=MODEL_PATH,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=128,
gpu_memory_utilization=0.85,
enforce_eager=False,
disable_log_stats=False,
)
results = evaluate_gsm8k_offline(spec_llm)
accuracy = results["accuracy"]
accuracy_threshold = EXPECTED_GSM8K_ACCURACY * (1 - ACCURACY_RTOL)
assert accuracy >= accuracy_threshold, (
f"Expected GSM8K accuracy >= {accuracy_threshold:.3f}, got {accuracy:.3f}"
)
current_metrics = spec_llm.get_metrics()
stats = compute_spec_decode_stats(current_metrics)
print_spec_decode_stats(stats)
acceptance_len = stats["acceptance_len"]
al_threshold = EXPECTED_ACCEPTANCE_LEN * (1 - ACCEPTANCE_LEN_RTOL)
assert acceptance_len >= al_threshold, (
f"DFlash speculators acceptance length too low: "
f"{acceptance_len:.2f} < {al_threshold:.2f}"
)
# Check per-position acceptance rates for the first few positions.
per_pos_rates = stats["per_pos_acceptance_rates"]
for i, expected_rate in enumerate(EXPECTED_PER_POS_ACCEPTANCE_RATES):
assert i < len(per_pos_rates), (
f"Missing per-position acceptance rate for position {i}"
)
threshold = expected_rate * (1 - PER_POS_RTOL)
assert per_pos_rates[i] >= threshold, (
f"Per-position acceptance rate at pos {i} too low: "
f"{per_pos_rates[i]:.4f} < {threshold:.4f} "
f"(expected ~{expected_rate:.4f})"
)
del spec_llm
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
...@@ -523,6 +523,13 @@ class DFlashQwen3ForCausalLM(Qwen3ForCausalLM): ...@@ -523,6 +523,13 @@ class DFlashQwen3ForCausalLM(Qwen3ForCausalLM):
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.config.draft_vocab_size, scale=logit_scale self.config.draft_vocab_size, scale=logit_scale
) )
target_vocab_size = vllm_config.model_config.get_vocab_size()
if self.config.draft_vocab_size != target_vocab_size:
self.draft_id_to_target_id = nn.Parameter(
torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
requires_grad=False,
)
else:
self.draft_id_to_target_id = None self.draft_id_to_target_id = None
def embed_input_ids( def embed_input_ids(
......
...@@ -41,3 +41,34 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None: ...@@ -41,3 +41,34 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
"eagle_aux_hidden_state_layer_ids" "eagle_aux_hidden_state_layer_ids"
] ]
@register_speculator("dflash")
def update_dflash(config_dict: dict, pre_trained_config: dict) -> None:
"""
Apply DFlash specific configuration transformations to the `dict` used to
construct the Transformers PreTrainedConfig.
DFlash specific fields:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- mask_token_id (required): Token ID used for parallel drafting mask
placeholders
- aux_hidden_state_layer_ids (required): Layer indices from the target
model whose intermediate hidden states are used as context for the
DFlash drafter. Mapped to both eagle_aux_hidden_state_layer_ids
(for gpu_model_runner) and dflash_config.target_layer_ids (for the
DFlash model).
"""
pre_trained_config["architectures"] = ["DFlashDraftModel"]
pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
if config_dict.get("target_hidden_size") is not None:
pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
pre_trained_config["dflash_config"] = {
"mask_token_id": config_dict["mask_token_id"],
"target_layer_ids": aux_layer_ids,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment