Unverified Commit 4a5299c9 authored by Tomas Ruiz's avatar Tomas Ruiz Committed by GitHub
Browse files

feat: spec decode with draft models (#24322)


Signed-off-by: default avatarTomas Ruiz <tomas.ruiz.te@gmail.com>
parent 73f2a81c
...@@ -54,7 +54,7 @@ def parse_args(): ...@@ -54,7 +54,7 @@ def parse_args():
"--method", "--method",
type=str, type=str,
default="eagle", default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"], choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
) )
parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5) parser.add_argument("--prompt-lookup-max", type=int, default=5)
...@@ -70,7 +70,11 @@ def parse_args(): ...@@ -70,7 +70,11 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
return parser.parse_args() return parser.parse_args()
...@@ -111,6 +115,7 @@ def main(args): ...@@ -111,6 +115,7 @@ def main(args):
"method": args.method, "method": args.method,
"model": eagle_dir, "model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens, "num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
} }
elif args.method == "ngram": elif args.method == "ngram":
speculative_config = { speculative_config = {
...@@ -119,6 +124,15 @@ def main(args): ...@@ -119,6 +124,15 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min, "prompt_lookup_min": args.prompt_lookup_min,
} }
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
}
elif args.method == "mtp": elif args.method == "mtp":
speculative_config = { speculative_config = {
"method": "mtp", "method": "mtp",
...@@ -133,12 +147,13 @@ def main(args): ...@@ -133,12 +147,13 @@ def main(args):
tensor_parallel_size=args.tp, tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill, enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager, enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.9, gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config, speculative_config=speculative_config,
disable_log_stats=False, disable_log_stats=False,
max_model_len=args.max_model_len, max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5}, limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True, disable_chunked_mm_input=True,
max_num_seqs=args.max_num_seqs,
) )
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
......
...@@ -4,13 +4,13 @@ import asyncio ...@@ -4,13 +4,13 @@ import asyncio
import copy import copy
import logging import logging
import os import os
import re
import socket import socket
import threading import threading
import uuid import uuid
import aiohttp import aiohttp
import msgpack import msgpack
import regex as re
import zmq import zmq
from quart import Quart, make_response, request from quart import Quart, make_response, request
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random import random
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any from typing import Any
import pytest import pytest
...@@ -10,15 +12,22 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark ...@@ -10,15 +12,22 @@ from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config.vllm import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.draft_model import (
create_vllm_config_for_draft_model,
merge_toks_kernel,
)
MTP_SIMILARITY_RATE = 0.8 MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int): def _skip_if_insufficient_gpus_for_tp(tp_size: int):
"""Skip test if available GPUs < tp_size on ROCm.""" """Skip test if available GPUs < tp_size on ROCm."""
if current_platform.is_rocm():
available_gpus = torch.cuda.device_count() available_gpus = torch.cuda.device_count()
if available_gpus < tp_size: if available_gpus < tp_size:
pytest.skip( pytest.skip(
...@@ -26,15 +35,21 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int): ...@@ -26,15 +35,21 @@ def _skip_if_insufficient_gpus_for_tp(tp_size: int):
) )
def get_test_prompts(mm_enabled: bool): Messages = list[dict[str, Any]]
def get_test_prompts(
mm_enabled: bool, quiet: bool = False, num_prompts: int = 100
) -> list[Messages]:
prompt_types = ["repeat", "sentence"] prompt_types = ["repeat", "sentence"]
if mm_enabled: if mm_enabled:
prompt_types.append("mm") prompt_types.append("mm")
num_prompts = 100
prompts = [] prompts = []
random.seed(0) random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
if not quiet:
print(f"Prompt types: {random_prompt_type_choices}") print(f"Prompt types: {random_prompt_type_choices}")
# Generate a mixed batch of prompts, some of which can be easily # Generate a mixed batch of prompts, some of which can be easily
...@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool): ...@@ -75,11 +90,27 @@ def get_test_prompts(mm_enabled: bool):
return prompts return prompts
def get_instruct_coder_messages(n: int) -> list[Messages]:
dataset = InstructCoderDataset(
dataset_path="likaixin/InstructCoder", dataset_split="train"
)
prompts: Iterable[str] = dataset.sample_prompts(n=n)
return [[{"role": "user", "content": prompt}] for prompt in prompts]
@pytest.fixture @pytest.fixture
def sampling_config(): def sampling_config():
return greedy_sampling()
def greedy_sampling() -> SamplingParams:
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
def stochastic_sampling() -> SamplingParams:
return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)
@pytest.fixture @pytest.fixture
def model_name(): def model_name():
return "meta-llama/Llama-3.1-8B-Instruct" return "meta-llama/Llama-3.1-8B-Instruct"
...@@ -583,3 +614,269 @@ def test_mtp_correctness( ...@@ -583,3 +614,269 @@ def test_mtp_correctness(
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.cuda.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@dataclass
class ArgsTest:
target_model: str
draft_model: str
sampling_config: SamplingParams
num_speculative_tokens: int
expected_acceptance_rate: float
expected_acceptance_len: float
# Defaults
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
gpu_memory_utilization: float = 0.5
dataset: str = "test_prompts"
num_prompts: int = 100
cases = [
# Same model for draft and target, greedy sampling.
ArgsTest(
target_model="Qwen/Qwen3-0.6B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=greedy_sampling(),
num_speculative_tokens=3, # K
expected_acceptance_len=3 + 1, # K + 1
expected_acceptance_rate=1.0,
),
# Smaller draft model, stochastic sampling.
ArgsTest(
target_model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=stochastic_sampling(),
num_speculative_tokens=3,
expected_acceptance_len=2.8 + 1,
expected_acceptance_rate=0.9,
),
]
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
assert_draft_model_correctness(args, enforce_eager)
def test_draft_model_realistic_example():
args = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
dataset="likaixin/InstructCoder",
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
# values below are not derived, but just prevent a regression
expected_acceptance_len=2.8,
expected_acceptance_rate=0.55,
)
assert_draft_model_correctness(args, enforce_eager=False)
@pytest.mark.parametrize(
"models",
[
# target_model, draft_model
("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"), # target quantized
("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"), # draft quantized
],
ids=["target_quantized", "draft_quantized"],
)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
tgt_model, draft_model = models
sd_case = ArgsTest(
target_model=tgt_model,
draft_model=draft_model,
**some_high_acceptance_metrics(),
)
assert_draft_model_correctness(sd_case, enforce_eager)
def test_draft_model_tensor_parallelism():
"""Ensure spec decode works when running with TP > 1."""
_skip_if_insufficient_gpus_for_tp(2)
sd_case = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
target_tensor_parallel_size=2,
draft_model="Qwen/Qwen3-0.6B",
draft_tensor_parallel_size=2,
**some_high_acceptance_metrics(),
)
assert_draft_model_correctness(sd_case, enforce_eager=False)
def test_draft_model_engine_args_tensor_parallelism():
"""Ensure the vllm_config for the draft model is created correctly,
and independently of the target model (quantization, TP, etc.)"""
_skip_if_insufficient_gpus_for_tp(2)
engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized
tensor_parallel_size=2,
speculative_config={
"model": "Qwen/Qwen3-0.6B", # <<< draft not quantized
"method": "draft_model",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1, # <<< valid arg name
},
)
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
assert tgt_vllm_config.quant_config.get_name() == "fp8"
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config)
assert draft_vllm_config.parallel_config.tensor_parallel_size == 1
assert draft_vllm_config.quant_config is None
def test_draft_model_engine_args_rejects_invalid_tp_argname():
"""The user should pass "draft_tensor_parallel_size" rather than
"tensor_parallel_size". We enforce this with validation."""
engine_args = EngineArgs(
model="Qwen/Qwen3-1.7B",
tensor_parallel_size=1,
speculative_config={
"model": "Qwen/Qwen3-0.6B",
"method": "draft_model",
"num_speculative_tokens": 3,
"tensor_parallel_size": 1, # <<< invalid arg name
},
)
with pytest.raises(ValueError):
engine_args.create_engine_config()
def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts: list[Messages] = get_messages(
dataset=args.dataset, n=args.num_prompts
)
spec_llm = LLM(
model=args.target_model,
speculative_config={
"model": args.draft_model,
"method": "draft_model",
"num_speculative_tokens": args.num_speculative_tokens,
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"max_num_seqs": 100, # limit cudagraph capture runtime
},
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
# we don't check the outputs, only check the metrics
spec_llm.chat(test_prompts, args.sampling_config)
metrics = spec_llm.get_metrics()
acceptance_rate: float = compute_acceptance_rate(metrics)
acceptance_len: float = compute_acceptance_len(metrics)
del spec_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
print(
f"spec-decode: target={args.target_model}, draft={args.draft_model}, "
f"temperature={args.sampling_config.temperature:.2f}, "
f"acceptance_rate={acceptance_rate:.2f}, "
f"acceptance_len={acceptance_len:.2f}, "
)
assert acceptance_rate >= args.expected_acceptance_rate
assert acceptance_len >= args.expected_acceptance_len
def get_messages(dataset: str, n: int) -> list[Messages]:
if dataset == "test_prompts":
return get_test_prompts(mm_enabled=False, quiet=True, num_prompts=n)
elif dataset == "likaixin/InstructCoder":
return get_instruct_coder_messages(n=n)
else:
raise NotImplementedError(f"Dataset '{dataset}' not implemented")
def some_high_acceptance_metrics() -> dict:
return {
"sampling_config": greedy_sampling(),
"num_speculative_tokens": 3,
"expected_acceptance_len": 2.90 + 1,
"expected_acceptance_rate": 0.90,
}
def test_merge_toks_kernel():
device = "cuda"
merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2
merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary
is_rejected_tok = torch.full((merged_len,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 3], device=device),
query_end_locs_ptr=torch.tensor([2, 4], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=5,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
assert torch.allclose(merged, expected_merged)
expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def test_merge_toks_kernel_with_rejected_tokens():
device = "cuda"
merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2
merged = torch.full((merged_size,), -100, device=device)
is_rejected_tok = torch.full((merged_size,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 6], device=device),
query_end_locs_ptr=torch.tensor([2, 7], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=9,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
assert torch.allclose(merged, expected_merged)
expected_rejected_toks = torch.tensor(
[False, False, False, False, True, True, True, False, False, False, True],
device=device,
)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)
def compute_acceptance_rate(metrics: list[Metric]) -> float:
name2metric = {metric.name: metric for metric in metrics}
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore
if n_draft_toks == 0:
return float("nan")
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
return n_accepted_toks / n_draft_toks
def compute_acceptance_len(metrics: list[Metric]) -> float:
name2metric = {metric.name: metric for metric in metrics}
n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore
n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore
if n_drafts == 0:
return 1
return 1 + (n_accepted_toks / n_drafts)
...@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config): ...@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
def test_bind_kv_cache_draft_model(default_vllm_config):
from vllm.attention.layer import Attention
layer_names = [
"model.layers.0.attn",
"model.layers.1.attn",
"draft_model.layers.0.attn",
"draft_model.layers.1.attn",
]
ctx = {
layer_name: Attention(32, 128, 0.1, prefix=layer_name)
for layer_name in layer_names
}
kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
assert (
ctx["draft_model.layers.0.attn"].kv_cache[0]
is kv_cache["draft_model.layers.0.attn"]
)
assert (
ctx["draft_model.layers.1.attn"].kv_cache[0]
is kv_cache["draft_model.layers.1.attn"]
)
# caches are ordered by layer_index, interleaving target and draft model
assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]
...@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -2593,17 +2593,10 @@ class InstructCoderDataset(HuggingFaceDataset):
request_id_prefix: str = "", request_id_prefix: str = "",
no_oversample: bool = False, no_oversample: bool = False,
**kwargs, **kwargs,
) -> list: ) -> list[SampleRequest]:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = [] sampled_requests = []
for i, item in enumerate(self.data): for i, prompt in enumerate(self.sample_prompts(n=num_requests)):
if len(sampled_requests) >= num_requests:
break
prompt = (
f"{item['input']}\n\n{item['instruction']} Just output "
"the code, do not include any explanation."
)
# apply template # apply template
if not skip_chat_template: if not skip_chat_template:
prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
...@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset): ...@@ -2626,6 +2619,14 @@ class InstructCoderDataset(HuggingFaceDataset):
) )
return sampled_requests return sampled_requests
def sample_prompts(self, n: int) -> Iterator[str]:
for item in self.data.take(n):
prompt = (
f"{item['input']}\n\n{item['instruction']} Just output "
"the code, do not include any explanation."
)
yield prompt
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# MT-Bench Dataset Implementation # MT-Bench Dataset Implementation
......
...@@ -8,8 +8,12 @@ import time ...@@ -8,8 +8,12 @@ import time
import aiohttp import aiohttp
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from vllm.logger import init_logger
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
logger = init_logger(__name__)
async def wait_for_endpoint( async def wait_for_endpoint(
request_func: RequestFunc, request_func: RequestFunc,
...@@ -61,6 +65,8 @@ async def wait_for_endpoint( ...@@ -61,6 +65,8 @@ async def wait_for_endpoint(
if output.success: if output.success:
pbar.close() pbar.close()
return output return output
else:
logger.warning("Endpoint is not ready. Error='%s'", output.error)
except aiohttp.ClientConnectorError: except aiohttp.ClientConnectorError:
pass pass
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
import torch import torch
...@@ -709,3 +710,6 @@ class ParallelConfig: ...@@ -709,3 +710,6 @@ class ParallelConfig:
) )
return self return self
def replace(self, **kwargs) -> Self:
return replace(self, **kwargs)
...@@ -77,6 +77,9 @@ class SpeculativeConfig: ...@@ -77,6 +77,9 @@ class SpeculativeConfig:
draft_tensor_parallel_size: int | None = Field(default=None, ge=1) draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
tensor_parallel_size: int | None = None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
# Draft model configuration # Draft model configuration
quantization: me_quant.QuantizationMethods | None = None quantization: me_quant.QuantizationMethods | None = None
...@@ -397,13 +400,11 @@ class SpeculativeConfig: ...@@ -397,13 +400,11 @@ class SpeculativeConfig:
"one layer. Might need some code changes " "one layer. Might need some code changes "
"to support multiple layers." "to support multiple layers."
) )
elif self.method == "draft_model":
pass
else: else:
self.method = "draft_model"
raise NotImplementedError( raise NotImplementedError(
"Speculative decoding with draft model is not " f"Unsupported speculative method: '{self.method}'"
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or mtp."
) )
# Replace hf_config for EAGLE draft_model # Replace hf_config for EAGLE draft_model
...@@ -631,6 +632,12 @@ class SpeculativeConfig: ...@@ -631,6 +632,12 @@ class SpeculativeConfig:
@model_validator(mode="after") @model_validator(mode="after")
def _verify_args(self) -> Self: def _verify_args(self) -> Self:
if self.tensor_parallel_size is not None:
raise ValueError(
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
if self.num_speculative_tokens is None: if self.num_speculative_tokens is None:
raise ValueError( raise ValueError(
"num_speculative_tokens must be provided with " "num_speculative_tokens must be provided with "
...@@ -669,12 +676,32 @@ class SpeculativeConfig: ...@@ -669,12 +676,32 @@ class SpeculativeConfig:
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}" f"Got {self.target_model_config.hf_text_config.model_type=}"
) )
self.verify_equal_vocab_size_if_draft_model()
return self return self
def verify_equal_vocab_size_if_draft_model(self):
if (
self.method == "draft_model"
and self.target_model_config is not None
and self.draft_model_config is not None
):
target_vocab_size = self.target_model_config.get_vocab_size()
draft_vocab_size = self.draft_model_config.get_vocab_size()
if target_vocab_size != draft_vocab_size:
raise ValueError(
f"Target and draft model should have the same vocabulary size. "
f"Target model vocab_size={target_vocab_size}. "
f"Draft model vocab_size={draft_vocab_size}. "
f"Using models with different tokenizers can cause out-of-bounds "
f"errors during speculative decoding."
)
def use_eagle(self) -> bool: def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp") return self.method in ("eagle", "eagle3", "mtp")
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model model = None if method in ("ngram", "suffix") else self.draft_model_config.model
......
...@@ -1214,10 +1214,19 @@ class VllmConfig: ...@@ -1214,10 +1214,19 @@ class VllmConfig:
compilation_config = self.compilation_config compilation_config = self.compilation_config
computed_compile_ranges_split_points = [] computed_compile_ranges_split_points = []
# The upper bound of the compile ranges is the max_num_batched_tokens # The upper bound of the compile ranges is the max_num_batched_tokens.
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens # For speculative decoding with draft model, the compile range must be extended
if max_num_batched_tokens is not None: # by 1 for each sequence.
computed_compile_ranges_split_points.append(max_num_batched_tokens) compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
do_extend: bool = (
self.speculative_config is not None
and self.speculative_config.uses_draft_model()
)
if do_extend:
compile_range_end += self.scheduler_config.max_num_seqs
computed_compile_ranges_split_points.append(compile_range_end)
# Add the compile ranges for flashinfer # Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms: if compilation_config.pass_config.fuse_allreduce_rms:
...@@ -1228,10 +1237,7 @@ class VllmConfig: ...@@ -1228,10 +1237,7 @@ class VllmConfig:
self.model_config.get_hidden_size() self.model_config.get_hidden_size()
* self.model_config.dtype.itemsize * self.model_config.dtype.itemsize
) )
if ( if compile_range_end is not None and max_token_num < compile_range_end:
max_num_batched_tokens is not None
and max_token_num < max_num_batched_tokens
):
computed_compile_ranges_split_points.append(max_token_num) computed_compile_ranges_split_points.append(max_token_num)
else: else:
logger.debug( logger.debug(
...@@ -1243,11 +1249,7 @@ class VllmConfig: ...@@ -1243,11 +1249,7 @@ class VllmConfig:
for x in compilation_config.compile_ranges_split_points: for x in compilation_config.compile_ranges_split_points:
assert isinstance(x, int) assert isinstance(x, int)
assert x > 0, f"Invalid compile range split point: {x}" assert x > 0, f"Invalid compile range split point: {x}"
if ( if compile_range_end is not None and x < compile_range_end and x > 1:
max_num_batched_tokens is not None
and x < max_num_batched_tokens
and x > 1
):
computed_compile_ranges_split_points.append(x) computed_compile_ranges_split_points.append(x)
compilation_config.compile_ranges_split_points = sorted( compilation_config.compile_ranges_split_points = sorted(
computed_compile_ranges_split_points computed_compile_ranges_split_points
...@@ -1316,6 +1318,14 @@ class VllmConfig: ...@@ -1316,6 +1318,14 @@ class VllmConfig:
path = self.compilation_config.debug_dump_path / append_path path = self.compilation_config.debug_dump_path / append_path
return path return path
def replace(self, **kwargs):
"""
Replace attributes of the config, and 'recompute' the config.
dataclass.replace() calls __init__() and __post_init__(), source:
https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
"""
return replace(self, **kwargs)
def __str__(self): def __str__(self):
return ( return (
f"model={self.model_config.model!r}, " f"model={self.model_config.model!r}, "
......
...@@ -1776,21 +1776,6 @@ class EngineArgs: ...@@ -1776,21 +1776,6 @@ class EngineArgs:
): ):
_raise_unsupported_error(feature_name="Concurrent Partial Prefill") _raise_unsupported_error(feature_name="Concurrent Partial Prefill")
# N-gram, Medusa, and Eagle are supported for speculative decoding.
if self.speculative_config is not None:
# speculative_config could still be a dict at this point
if isinstance(self.speculative_config, dict):
method = self.speculative_config.get("method", None)
else:
method = self.speculative_config.method
if method == "draft_model":
raise NotImplementedError(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or mtp."
)
if self.pipeline_parallel_size > 1: if self.pipeline_parallel_size > 1:
supports_pp = getattr( supports_pp = getattr(
self.distributed_executor_backend, "supports_pp", False self.distributed_executor_backend, "supports_pp", False
......
...@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: ...@@ -124,12 +124,17 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
def get_model( def get_model(
*, vllm_config: VllmConfig, model_config: ModelConfig | None = None *,
vllm_config: VllmConfig,
model_config: ModelConfig | None = None,
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
loader = get_model_loader(vllm_config.load_config) loader = get_model_loader(vllm_config.load_config)
if model_config is None: if model_config is None:
model_config = vllm_config.model_config model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config, model_config=model_config) return loader.load_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
__all__ = [ __all__ = [
......
...@@ -36,7 +36,7 @@ class BaseModelLoader(ABC): ...@@ -36,7 +36,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError raise NotImplementedError
def load_model( def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module: ) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
device_config = vllm_config.device_config device_config = vllm_config.device_config
...@@ -48,7 +48,7 @@ class BaseModelLoader(ABC): ...@@ -48,7 +48,7 @@ class BaseModelLoader(ABC):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = initialize_model( model = initialize_model(
vllm_config=vllm_config, model_config=model_config vllm_config=vllm_config, model_config=model_config, prefix=prefix
) )
log_model_inspection(model) log_model_inspection(model)
......
...@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -335,7 +335,7 @@ class GGUFModelLoader(BaseModelLoader):
) )
def load_model( def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module: ) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config) local_model_path = self._prepare_weights(model_config)
...@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -364,7 +364,7 @@ class GGUFModelLoader(BaseModelLoader):
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = initialize_model(vllm_config=vllm_config) model = initialize_model(vllm_config=vllm_config, prefix=prefix)
self.load_weights(model, model_config) self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device) process_weights_after_loading(model, model_config, target_device)
......
...@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -68,6 +68,7 @@ class TensorizerLoader(BaseModelLoader):
def _load_model_serialized_cpu( def _load_model_serialized_cpu(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU. """Load a serialized model with tensorizer to the CPU.
...@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -80,7 +81,7 @@ class TensorizerLoader(BaseModelLoader):
model_config = vllm_config.model_config model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config) model = initialize_model(vllm_config=vllm_config, prefix=prefix)
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
return model.eval() return model.eval()
...@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -112,7 +113,7 @@ class TensorizerLoader(BaseModelLoader):
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
def load_model( def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module: ) -> nn.Module:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config) self._verify_config(model_config, parallel_config)
...@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader): ...@@ -134,7 +135,7 @@ class TensorizerLoader(BaseModelLoader):
) )
self.load_weights(model, model_config) self.load_weights(model, model_config)
return model return model
return self._load_model_serialized_cpu(vllm_config=vllm_config) return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)
@staticmethod @staticmethod
def save_model( def save_model(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, replace
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
...@@ -329,6 +329,16 @@ class CommonAttentionMetadata: ...@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
_num_computed_tokens_cache: torch.Tensor | None = None _num_computed_tokens_cache: torch.Tensor | None = None
def batch_size(self) -> int:
return self.seq_lens.shape[0]
def naive_query_lens(self) -> torch.Tensor:
"""Naive because it assumes that query ends where the next query starts."""
return self.query_start_loc[1:] - self.query_start_loc[:-1]
def replace(self, **kwargs) -> "CommonAttentionMetadata":
return replace(self, **kwargs)
@property @property
@deprecated( @deprecated(
""" """
......
...@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens( ...@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
) )
dcp_local_seq_lens = base + remainder dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1) return dcp_local_seq_lens.squeeze(1)
def extend_all_queries_by_1(
common_attn_metadata: CommonAttentionMetadata,
arange: torch.Tensor,
new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by 1.
Also all seq lens are increased by 1.
This is useful e.g. in speculative decoding with draft models, where we
extend each sequence by 1 token.
The slot mapping is computed externally, as it requires more information.
"""
cad = common_attn_metadata
# query start loc must be increased by [+0, +1, +2, ..., +batch_size]
new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)]
new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange(
len(cad.query_start_loc_cpu), dtype=torch.int32
)
new_cad = cad.replace(
query_start_loc=new_query_start_loc,
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=cad.seq_lens + 1,
# each request is extended by 1 token -> batch_size tokens are added
num_actual_tokens=cad.num_actual_tokens + cad.batch_size(),
# All query lens increase by 1, so max query len increases by 1
max_query_len=cad.max_query_len + 1,
max_seq_len=cad.max_seq_len + 1,
slot_mapping=new_slot_mapping,
)
return new_cad
...@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface): ...@@ -208,6 +208,8 @@ class Scheduler(SchedulerInterface):
if speculative_config.use_eagle(): if speculative_config.use_eagle():
self.use_eagle = True self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens self.num_lookahead_tokens = self.num_spec_tokens
if speculative_config.uses_draft_model():
self.num_lookahead_tokens = self.num_spec_tokens
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
extend_all_queries_by_1,
)
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer
logger = init_logger(__name__)
class DraftModelProposer(SpecDecodeBaseProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(
vllm_config=vllm_config,
device=device,
pass_hidden_states_to_model=False,
runner=runner,
)
self._raise_if_multimodal()
self._raise_if_mrope()
self._raise_if_padded_drafter_batch_disabled()
self._raise_if_vocab_size_mismatch()
self._raise_if_draft_tp_mismatch()
def _block_size(self) -> int:
builder = self._get_attention_metadata_builder()
return builder.kv_cache_spec.block_size
def _raise_if_multimodal(self):
if self.supports_mm_inputs:
raise NotImplementedError(
"Speculative Decoding with draft models "
"does not support multimodal models yet"
)
def _raise_if_mrope(self):
if self.draft_model_config.uses_mrope:
raise NotImplementedError(
"Speculative Decoding with draft models does not support M-RoPE yet"
)
def _raise_if_padded_drafter_batch_disabled(self):
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
raise NotImplementedError(
"Speculative Decoding with draft models only supports "
"padded drafter batch. Please don't pass --disable-padded-drafter-batch"
" in the speculative_config."
)
def _raise_if_vocab_size_mismatch(self):
self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model()
def _raise_if_draft_tp_mismatch(self):
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
# the draft model with TP = 1, then the different TP ranks collide.
# Specifically when all ranks compile the draft model on rank 0
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# To prevent this error, we assert that both TP sizes must be the same.
spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
if draft_tp != tgt_tp:
raise ValueError(
f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
f"must be the same. Got {draft_tp} and {tgt_tp}. "
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)
def set_inputs_first_pass(
self,
target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor,
target_positions: torch.Tensor,
last_token_indices: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
batch_size = cad.batch_size()
grid = (batch_size,)
start_locs = cad.query_start_loc[:-1]
end_locs = cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
end_locs -= num_rejected_tokens_gpu
num_tokens = target_token_ids.shape[0] + batch_size
is_rejected_tok = torch.empty(
(num_tokens,), device=self.input_ids.device, dtype=torch.bool
)
merge_toks_kernel[grid](
target_toks_ptr=target_token_ids,
next_toks_ptr=next_token_ids,
query_start_locs_ptr=start_locs,
query_end_locs_ptr=end_locs,
out_ptr_merged_toks=self.input_ids,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=target_token_ids.shape[0],
# passing a negative rejected_tok_fill value will raise an error
# when the value is used to index into embeddings.
# Therefore, we pass a valid integer, e.g. 0.
rejected_tok_fill=0,
)
merge_toks_kernel[grid](
target_toks_ptr=target_positions,
next_toks_ptr=target_positions[end_locs] + 1,
query_start_locs_ptr=start_locs,
query_end_locs_ptr=end_locs,
out_ptr_merged_toks=self.positions,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=target_positions.shape[0],
rejected_tok_fill=0,
)
# recompute slot mapping
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:num_tokens],
is_rejected_token_mask=is_rejected_tok,
block_size=self._block_size(),
max_model_len=self.max_model_len,
)
# update common_attn_metadata
new_cad: CommonAttentionMetadata = extend_all_queries_by_1(
cad,
arange=self.arange,
new_slot_mapping=new_slot_mapping,
)
new_last_token_indices = new_cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
new_last_token_indices -= num_rejected_tokens_gpu
return num_tokens, new_last_token_indices, new_cad
def load_model(self, target_model: Any) -> None:
"""Takes target_model to satisfy the type checker."""
# This must be computed before loading the draft model
# because that mutates the forward_context of the vllm_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
)
from vllm.compilation.backends import set_model_tag
draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
target_model_vllm_config=self.vllm_config
)
logger.info(
"Starting to load draft model %s. TP=%d, rank=%d",
draft_vllm_config.model_config.model,
draft_vllm_config.parallel_config.tensor_parallel_size,
draft_vllm_config.parallel_config.rank,
)
with set_model_tag("draft_model"):
self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")
# This must be computed after loading the draft model
# because that mutates the forward_context of the vllm_config
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
- target_attn_layer_names
)
self.attn_layer_names = list(draft_attn_layer_names)
def create_vllm_config_for_draft_model(
target_model_vllm_config: VllmConfig,
) -> VllmConfig:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the draft model.
The vllm_config is useful when loading the draft model with get_model().
"""
old = target_model_vllm_config
new_parallel_config = old.speculative_config.draft_parallel_config.replace(
rank=old.parallel_config.rank
)
new: VllmConfig = old.replace(
quant_config=None, # quant_config is recomputed in __init__()
model_config=old.speculative_config.draft_model_config,
parallel_config=new_parallel_config,
)
return new
def compute_new_slot_mapping(
cad: CommonAttentionMetadata,
new_positions: torch.Tensor,
is_rejected_token_mask: torch.Tensor,
block_size: int,
max_model_len: int,
):
batch_size, n_blocks_per_req = cad.block_table_tensor.shape
req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
req_indices = torch.repeat_interleave(
req_indices, cad.naive_query_lens() + 1, output_size=len(new_positions)
)
# Clamp the positions to prevent an out-of-bounds error when indexing
# into block_table_tensor.
clamped_positions = torch.clamp(new_positions, max=max_model_len - 1)
block_table_indices = (
req_indices * n_blocks_per_req + clamped_positions // block_size
)
block_nums = cad.block_table_tensor.view(-1)[block_table_indices]
block_offsets = clamped_positions % block_size
new_slot_mapping = block_nums * block_size + block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len = new_positions >= max_model_len
new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
return new_slot_mapping
@triton.jit
def merge_toks_kernel(
target_toks_ptr,
next_toks_ptr,
query_start_locs_ptr,
query_end_locs_ptr,
out_ptr_merged_toks,
out_ptr_is_rejected_tok,
target_toks_size,
rejected_tok_fill,
):
"""
Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor
called `out_ptr_merged_toks`. Rejected tokens are those after the
`query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the
rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask
of the rejected tokens in `out_ptr_is_rejected_tok`.
"""
pid = tl.program_id(0)
start_loc = tl.load(query_start_locs_ptr + pid)
is_last_program = pid == tl.num_programs(0) - 1
if is_last_program:
next_start_loc = target_toks_size.to(tl.int32)
else:
next_start_loc = tl.load(query_start_locs_ptr + pid + 1).to(tl.int32)
end_loc = tl.load(query_end_locs_ptr + pid)
new_val = tl.load(next_toks_ptr + pid)
for i in range(start_loc, next_start_loc + 1):
if i <= end_loc: # copy existing tokens
old_val = tl.load(target_toks_ptr + i)
tl.store(out_ptr_merged_toks + pid + i, old_val)
tl.store(out_ptr_is_rejected_tok + pid + i, False)
elif i == end_loc + 1: # copy bonus token
tl.store(out_ptr_merged_toks + pid + i, new_val)
tl.store(out_ptr_is_rejected_tok + pid + i, False)
else: # fill rejected tokens
tl.store(out_ptr_merged_toks + pid + i, rejected_tok_fill)
tl.store(out_ptr_is_rejected_tok + pid + i, True)
...@@ -53,11 +53,12 @@ logger = init_logger(__name__) ...@@ -53,11 +53,12 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1 PADDING_SLOT_ID = -1
class EagleProposer: class SpecDecodeBaseProposer:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
pass_hidden_states_to_model: bool,
runner=None, runner=None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -65,6 +66,7 @@ class EagleProposer: ...@@ -65,6 +66,7 @@ class EagleProposer:
assert self.speculative_config is not None assert self.speculative_config is not None
self.draft_model_config = self.speculative_config.draft_model_config self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.runner = runner self.runner = runner
self.device = device self.device = device
...@@ -72,7 +74,11 @@ class EagleProposer: ...@@ -72,7 +74,11 @@ class EagleProposer:
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens # The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.token_arange_np = np.arange(self.max_num_tokens) self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because # We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's # the draft model's hidden size can be different from the target model's
...@@ -143,7 +149,6 @@ class EagleProposer: ...@@ -143,7 +149,6 @@ class EagleProposer:
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange( self.arange = torch.arange(
max_num_slots_for_arange, device=device, dtype=torch.int32 max_num_slots_for_arange, device=device, dtype=torch.int32
...@@ -245,11 +250,7 @@ class EagleProposer: ...@@ -245,11 +250,7 @@ class EagleProposer:
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] batch_size = common_attn_metadata.batch_size()
batch_size = next_token_ids.shape[0]
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3": if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM) assert isinstance(self.model, Eagle3LlamaForCausalLM)
...@@ -257,12 +258,17 @@ class EagleProposer: ...@@ -257,12 +258,17 @@ class EagleProposer:
target_hidden_states target_hidden_states
) )
assert target_hidden_states.shape[-1] == self.hidden_size assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] num_tokens, last_token_indices, common_attn_metadata = (
self.input_ids[: num_tokens - 1] = target_token_ids[1:] self.set_inputs_first_pass(
# Replace the last token with the next token. target_token_ids=target_token_ids,
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] next_token_ids=next_token_ids,
self.input_ids[last_token_indices] = next_token_ids target_positions=target_positions,
last_token_indices=last_token_indices,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
)
assert self.runner is not None assert self.runner is not None
...@@ -311,8 +317,9 @@ class EagleProposer: ...@@ -311,8 +317,9 @@ class EagleProposer:
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens num_tokens_across_dp[self.dp_rank] = num_input_tokens
# copy inputs to buffer for cudagraph if self.pass_hidden_states_to_model:
self._set_positions(num_tokens, target_positions) # target_hidden_states and self.hidden_states can have different
# hidden dims. E.g. large target model and small draft model.
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if self.supports_mm_inputs: if self.supports_mm_inputs:
...@@ -330,6 +337,14 @@ class EagleProposer: ...@@ -330,6 +337,14 @@ class EagleProposer:
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context( with set_forward_context(
per_layer_attn_metadata, per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -337,17 +352,13 @@ class EagleProposer: ...@@ -337,17 +352,13 @@ class EagleProposer:
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
): ):
ret_hidden_states = self.model( ret_hidden_states = self.model(**model_kwargs)
input_ids=input_ids, if not self.model_returns_tuple():
positions=self._get_positions(num_input_tokens),
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states hidden_states = last_hidden_states
else: else:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
...@@ -357,9 +368,9 @@ class EagleProposer: ...@@ -357,9 +368,9 @@ class EagleProposer:
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1)
if self.uses_mrope: if self.uses_mrope:
positions = target_positions[:, last_token_indices] positions = self.positions[:, last_token_indices]
else: else:
positions = target_positions[last_token_indices] positions = self.positions[last_token_indices]
if self.method in ( if self.method in (
"deepseek_mtp", "deepseek_mtp",
"ernie_mtp", "ernie_mtp",
...@@ -527,6 +538,14 @@ class EagleProposer: ...@@ -527,6 +538,14 @@ class EagleProposer:
inputs_embeds = None inputs_embeds = None
# Run the model. # Run the model.
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(input_batch_size),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]
with set_forward_context( with set_forward_context(
per_layer_attn_metadata, per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -534,17 +553,13 @@ class EagleProposer: ...@@ -534,17 +553,13 @@ class EagleProposer:
num_tokens_across_dp=batch_size_across_dp, num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
): ):
ret_hidden_states = self.model( ret_hidden_states = self.model(**model_kwargs)
input_ids=input_ids, if not self.model_returns_tuple():
positions=self._get_positions(input_batch_size),
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states hidden_states = ret_hidden_states
else: else:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size]) logits = self.model.compute_logits(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
...@@ -554,6 +569,34 @@ class EagleProposer: ...@@ -554,6 +569,34 @@ class EagleProposer:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1) draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids return draft_token_ids
def set_inputs_first_pass(
self,
target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor,
target_positions: torch.Tensor,
last_token_indices: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
if last_token_indices is None:
last_token_indices = cad.query_start_loc[1:] - 1
num_tokens = target_token_ids.shape[0]
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
return num_tokens, last_token_indices, cad
def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model")
def prepare_next_token_ids_cpu( def prepare_next_token_ids_cpu(
self, self,
sampled_token_ids: list[list[int]], sampled_token_ids: list[list[int]],
...@@ -1214,12 +1257,14 @@ class EagleProposer: ...@@ -1214,12 +1257,14 @@ class EagleProposer:
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None inputs_embeds = None
self.model( kwargs = dict(
input_ids=input_ids, input_ids=input_ids,
positions=self._get_positions(num_input_tokens), positions=self._get_positions(num_input_tokens),
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if self.pass_hidden_states_to_model:
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
self.model(**kwargs)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers. """Find and return the attention metadata builders for EAGLE layers.
...@@ -1264,8 +1309,8 @@ class EagleProposer: ...@@ -1264,8 +1309,8 @@ class EagleProposer:
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
""" """
Validate that all eagle layers belong to the same KVCacheGroup. Validate that all drafting layers belong to the same KVCacheGroup.
Need this assumption to ensure all eagle layers can use the Need this assumption to ensure all drafting layers can use the
same AttentionMetadata. same AttentionMetadata.
May extend to multiple AttentionMetadata in the future. May extend to multiple AttentionMetadata in the future.
""" """
...@@ -1283,7 +1328,7 @@ class EagleProposer: ...@@ -1283,7 +1328,7 @@ class EagleProposer:
) )
) )
== 1 == 1
), "All eagle layers should belong to the same kv cache group" ), "All drafting layers should belong to the same kv cache group"
def _pad_batch_across_dp( def _pad_batch_across_dp(
self, self,
...@@ -1308,6 +1353,21 @@ class EagleProposer: ...@@ -1308,6 +1353,21 @@ class EagleProposer:
return num_tokens_dp_padded, num_toks_across_dp return num_tokens_dp_padded, num_toks_across_dp
class EagleProposer(SpecDecodeBaseProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(
vllm_config,
device,
pass_hidden_states_to_model=True,
runner=runner,
)
# NOTE(woosuk): Currently, the below code is not used and we always use argmax # NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage # to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor. # the draft prob tensor.
......
...@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor ...@@ -145,6 +145,7 @@ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...@@ -432,10 +433,20 @@ class GPUModelRunner( ...@@ -432,10 +433,20 @@ class GPUModelRunner(
# layers in the draft model. # layers in the draft model.
if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: ( self.drafter: (
NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer NgramProposer
| SuffixDecodingProposer
| EagleProposer
| DraftModelProposer
| MedusaProposer
) )
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config) self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.uses_draft_model():
self.drafter = DraftModelProposer(
vllm_config=self.vllm_config,
device=self.device,
runner=self,
)
elif self.speculative_config.method == "suffix": elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config) self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
...@@ -3443,10 +3454,13 @@ class GPUModelRunner( ...@@ -3443,10 +3454,13 @@ class GPUModelRunner(
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len <= self.effective_drafter_max_model_len
) )
if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch: use_gpu_toks = (
# EAGLE speculative decoding can use the GPU sampled tokens spec_config.use_eagle() or spec_config.uses_draft_model()
) and not spec_config.disable_padded_drafter_batch
if use_gpu_toks:
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish. # as inputs, and does not need to wait for bookkeeping to finish.
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter: if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids) propose_draft_token_ids(sampled_token_ids)
...@@ -3679,8 +3693,8 @@ class GPUModelRunner( ...@@ -3679,8 +3693,8 @@ class GPUModelRunner(
target_hidden_states=hidden_states, target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
elif spec_config.use_eagle(): elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
if spec_config.disable_padded_drafter_batch: if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be # When padded-batch is disabled, the sampled_token_ids should be
...@@ -4475,8 +4489,12 @@ class GPUModelRunner( ...@@ -4475,8 +4489,12 @@ class GPUModelRunner(
else: else:
hidden_states = outputs hidden_states = outputs
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and (
assert isinstance(self.drafter, EagleProposer) self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs. # Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE # Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up. # NOTE(lucas): this is a hack, need to clean up.
...@@ -5652,8 +5670,11 @@ class GPUModelRunner( ...@@ -5652,8 +5670,11 @@ class GPUModelRunner(
kv_cache_config, kernel_block_sizes kv_cache_config, kernel_block_sizes
) )
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and (
assert isinstance(self.drafter, EagleProposer) self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
# validate all draft model layers belong to the same kv cache # validate all draft model layers belong to the same kv cache
# group # group
self.drafter.validate_same_kv_cache_group(kv_cache_config) self.drafter.validate_same_kv_cache_group(kv_cache_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment