Unverified Commit d8874c61 authored by Ronald's avatar Ronald Committed by GitHub
Browse files

[Core] Async Scheduling X Spec Decoding Compatibility (#24799)


Signed-off-by: default avatarRonald1995 <ronaldautomobile@163.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
parent f8b19c0f
......@@ -15,7 +15,7 @@ from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B"
MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base"
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
first_prompt = (
......@@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict(
temperature=0.0, # greedy
max_tokens=20,
max_tokens=23,
min_tokens=18,
)
......@@ -69,15 +70,9 @@ def test_without_spec_decoding(
(True, "uni", True, None, True),
]
run_tests(
monkeypatch,
MODEL,
test_configs,
test_sampling_params,
)
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
@pytest.mark.skip("MTP model too big to run in fp32 in CI")
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
......@@ -85,8 +80,9 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""
spec_config = {
"method": "mtp",
"method": "eagle3",
"num_speculative_tokens": 2,
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
}
spec_config_short = spec_config | {"max_model_len": 50}
......@@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]
run_tests(
monkeypatch,
MTP_MODEL,
test_configs,
[{}],
)
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
@dynamo_config.patch(cache_size_limit=16)
......@@ -182,15 +173,13 @@ def run_tests(
and test_acceptance_rate is not None
):
if "spec_mml=None" in test_config:
# because the acceptance rate can vary, we use a looser
# tolerance here.
assert (
pytest.approx(test_acceptance_rate, rel=5e-2)
== base_acceptance_rate
)
else:
# Currently the reported acceptance rate is expected to be
# lower when we skip drafting altogether.
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.05
print(
f"PASSED: config=[{test_config}], params={params}"
......@@ -220,6 +209,7 @@ def run_test(
):
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
# Force preemptions
dict(num_gpu_blocks_override=32)
if test_preemption
else dict(gpu_memory_utilization=0.9)
......@@ -238,6 +228,7 @@ def run_test(
model,
max_model_len=512,
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True,
async_scheduling=async_scheduling,
......@@ -255,10 +246,7 @@ def run_test(
results.append(
vllm_model.generate(
example_prompts,
sampling_params=SamplingParams(
**default_params,
**override_params,
),
sampling_params=SamplingParams(**default_params, **override_params),
return_logprobs=True,
)
)
......@@ -270,9 +258,7 @@ def run_test(
if test_preemption:
preemptions = _get_count(
metrics_before,
metrics_after,
"vllm:num_preemptions",
metrics_before, metrics_after, "vllm:num_preemptions"
)
assert preemptions > 0, "preemption test had no preemptions"
......
......@@ -3,7 +3,7 @@
import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
......@@ -29,31 +29,25 @@ else:
logger = init_logger(__name__)
SpeculativeMethod = Literal[
"ngram",
"eagle",
"eagle3",
"medusa",
"mlp_speculator",
"draft_model",
"deepseek_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"mimo_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
MTPModelTypes = Literal[
"deepseek_mtp",
"mimo_mtp",
"glm4_moe_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
)
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]
@config
......@@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.method in MTP_MODEL_TYPES:
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method
)
......@@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
......
......@@ -14,13 +14,14 @@ from dataclasses import replace
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
......@@ -374,10 +375,22 @@ class VllmConfig:
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
# Currently, async scheduling only supports eagle speculative
# decoding.
if self.speculative_config is not None:
raise ValueError(
"Async scheduling is not yet compatible with speculative decoding."
)
if self.speculative_config.method not in get_args(EagleModelTypes):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"async scheduling for EAGLE/MTP kind of speculative "
"decoding is enabled, but disable_padded_drafter_batch=True "
"disable_padded_drafter_batch=True is not supported for "
"this situation now. please set "
"disable_padded_drafter_batch=Fasle"
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
......
......@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if (
request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders
== request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
):
# The request will generate a new token in this scheduling step.
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
......
......@@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface):
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (
num_new_tokens + request.num_computed_tokens - request.num_tokens
num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
......@@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface):
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens.
request.num_computed_tokens -= num_rejected
if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0:
request.num_output_placeholders -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
......
......@@ -198,6 +198,7 @@ class EngineCore:
self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue
)
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
......@@ -341,7 +342,10 @@ class EngineCore:
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
def post_step(self, model_executed: bool) -> None:
if self.use_spec_decode and model_executed:
# When using async scheduling we can't get draft token ids in advance,
# so we update draft token ids in the worker process and don't
# need to update draft token ids here.
if not self.async_scheduling and self.use_spec_decode and model_executed:
# Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None:
......
......@@ -150,6 +150,23 @@ class Processor:
raise ValueError(
"vLLM V1 does not support per request user provided logits processors."
)
# Async scheduling + spec decode currently incompatible with some
# sampling parameters.
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.scheduler_config.async_scheduling
and (
params.frequency_penalty != 0.0
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
or params.structured_outputs
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters."
)
def _validate_params(
self,
......
......@@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logitsprocs
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
"Custom logits processors are not supportedwhen speculative decoding is enabled."
"Custom logits processors are not supported when speculative decoding is enabled."
)
LOGITSPROCS_GROUP = "vllm.logits_processors"
......
......@@ -397,10 +397,13 @@ class EagleProposer:
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# For data integrity when async scheduling, we shouldn't use in place
# operations in case they are modified in next step's `prepare_input`
# of main model.
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# This is an out-of-place operation to avoid modifying the original tensor.
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
......
......@@ -46,6 +46,9 @@ class CachedRequestState:
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment