Unverified Commit d8874c61 authored by Ronald's avatar Ronald Committed by GitHub
Browse files

[Core] Async Scheduling X Spec Decoding Compatibility (#24799)


Signed-off-by: default avatarRonald1995 <ronaldautomobile@163.com>
Signed-off-by: default avatarNick Hill <nhill@redhat.com>
Signed-off-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
Co-authored-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
parent f8b19c0f
...@@ -15,7 +15,7 @@ from ...conftest import VllmRunner ...@@ -15,7 +15,7 @@ from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal from ...models.utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B" MODEL = "Qwen/Qwen3-0.6B"
MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base" MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
first_prompt = ( first_prompt = (
...@@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [ ...@@ -29,7 +29,8 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict( default_params = dict(
temperature=0.0, # greedy temperature=0.0, # greedy
max_tokens=20, max_tokens=23,
min_tokens=18,
) )
...@@ -69,15 +70,9 @@ def test_without_spec_decoding( ...@@ -69,15 +70,9 @@ def test_without_spec_decoding(
(True, "uni", True, None, True), (True, "uni", True, None, True),
] ]
run_tests( run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
monkeypatch,
MODEL,
test_configs,
test_sampling_params,
)
@pytest.mark.skip("MTP model too big to run in fp32 in CI")
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of """Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking, preemption, executor, async scheduling, prefill chunking,
...@@ -85,8 +80,9 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): ...@@ -85,8 +80,9 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
""" """
spec_config = { spec_config = {
"method": "mtp", "method": "eagle3",
"num_speculative_tokens": 2, "num_speculative_tokens": 2,
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
} }
spec_config_short = spec_config | {"max_model_len": 50} spec_config_short = spec_config | {"max_model_len": 50}
...@@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): ...@@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True), (True, "uni", True, spec_config_short, True),
] ]
run_tests( run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
monkeypatch,
MTP_MODEL,
test_configs,
[{}],
)
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)
...@@ -182,15 +173,13 @@ def run_tests( ...@@ -182,15 +173,13 @@ def run_tests(
and test_acceptance_rate is not None and test_acceptance_rate is not None
): ):
if "spec_mml=None" in test_config: if "spec_mml=None" in test_config:
# because the acceptance rate can vary, we use a looser
# tolerance here.
assert ( assert (
pytest.approx(test_acceptance_rate, rel=5e-2) pytest.approx(test_acceptance_rate, rel=5e-2)
== base_acceptance_rate == base_acceptance_rate
) )
else: else:
# Currently the reported acceptance rate is expected to be # Currently the reported acceptance rate is expected to be
# lower when we skip drafting altogether. # lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.05 assert test_acceptance_rate > 0.05
print( print(
f"PASSED: config=[{test_config}], params={params}" f"PASSED: config=[{test_config}], params={params}"
...@@ -220,6 +209,7 @@ def run_test( ...@@ -220,6 +209,7 @@ def run_test(
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
# Force preemptions
dict(num_gpu_blocks_override=32) dict(num_gpu_blocks_override=32)
if test_preemption if test_preemption
else dict(gpu_memory_utilization=0.9) else dict(gpu_memory_utilization=0.9)
...@@ -238,6 +228,7 @@ def run_test( ...@@ -238,6 +228,7 @@ def run_test(
model, model,
max_model_len=512, max_model_len=512,
enable_chunked_prefill=test_prefill_chunking, enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None, max_num_batched_tokens=48 if test_prefill_chunking else None,
# enforce_eager=True, # enforce_eager=True,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
...@@ -255,10 +246,7 @@ def run_test( ...@@ -255,10 +246,7 @@ def run_test(
results.append( results.append(
vllm_model.generate( vllm_model.generate(
example_prompts, example_prompts,
sampling_params=SamplingParams( sampling_params=SamplingParams(**default_params, **override_params),
**default_params,
**override_params,
),
return_logprobs=True, return_logprobs=True,
) )
) )
...@@ -270,9 +258,7 @@ def run_test( ...@@ -270,9 +258,7 @@ def run_test(
if test_preemption: if test_preemption:
preemptions = _get_count( preemptions = _get_count(
metrics_before, metrics_before, metrics_after, "vllm:num_preemptions"
metrics_after,
"vllm:num_preemptions",
) )
assert preemptions > 0, "preemption test had no preemptions" assert preemptions > 0, "preemption test had no preemptions"
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import ast import ast
import hashlib import hashlib
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
...@@ -29,31 +29,25 @@ else: ...@@ -29,31 +29,25 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
SpeculativeMethod = Literal[ MTPModelTypes = Literal[
"ngram",
"eagle",
"eagle3",
"medusa",
"mlp_speculator",
"draft_model",
"deepseek_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"mimo_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
"deepseek_mtp", "deepseek_mtp",
"mimo_mtp", "mimo_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"ernie_mtp", "ernie_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
"longcat_flash_mtp", "longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp", "pangu_ultra_moe_mtp",
) ]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]
@config @config
...@@ -244,7 +238,7 @@ class SpeculativeConfig: ...@@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by # can not be detected, it will be considered as the "draft_model" by
# default. # default.
if self.method in MTP_MODEL_TYPES: if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning( logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method "method `%s` is deprecated and replaced with mtp.", self.method
) )
...@@ -361,7 +355,9 @@ class SpeculativeConfig: ...@@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa" self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator": elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator" self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES: elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: if self.num_speculative_tokens > 1:
logger.warning( logger.warning(
......
...@@ -14,13 +14,14 @@ from dataclasses import replace ...@@ -14,13 +14,14 @@ from dataclasses import replace
from datetime import datetime from datetime import datetime
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch import torch
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -374,9 +375,21 @@ class VllmConfig: ...@@ -374,9 +375,21 @@ class VllmConfig:
"Async scheduling is not yet compatible with " "Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1." "pipeline_parallel_size > 1."
) )
# Currently, async scheduling only support eagle speculative
# decoding.
if self.speculative_config is not None: if self.speculative_config is not None:
if self.speculative_config.method not in get_args(EagleModelTypes):
raise ValueError( raise ValueError(
"Async scheduling is not yet compatible with speculative decoding." "Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"async scheduling for EAGLE/MTP kind of speculative "
"decoding is enabled, but disable_padded_drafter_batch=True "
"disable_padded_drafter_batch=True is not supported for "
"this situation now. please set "
"disable_padded_drafter_batch=Fasle"
) )
if not executor_supports_async_sched: if not executor_supports_async_sched:
raise ValueError( raise ValueError(
......
...@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler): ...@@ -16,18 +16,25 @@ class AsyncScheduler(Scheduler):
) -> None: ) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens: for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id] request = self.requests[req_id]
pending_structured_output_tokens |= ( pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0 request.use_structured_output and request.num_output_placeholders > 0
) )
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if ( if (
request.num_computed_tokens request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders == request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
): ):
# The request will generate a new token in this scheduling step. # The request will generate a new token plus num_spec_tokens
# TODO(woosuk): Support speculative decoding. # in this scheduling step.
request.num_output_placeholders += 1 request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# Wwe will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = ( scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens pending_structured_output_tokens
......
...@@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface): ...@@ -348,7 +348,10 @@ class Scheduler(SchedulerInterface):
# Speculative decode related. # Speculative decode related.
if request.spec_token_ids: if request.spec_token_ids:
num_scheduled_spec_tokens = ( num_scheduled_spec_tokens = (
num_new_tokens + request.num_computed_tokens - request.num_tokens num_new_tokens
+ request.num_computed_tokens
- request.num_tokens
- request.num_output_placeholders
) )
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. # Trim spec_token_ids list to num_scheduled_spec_tokens.
...@@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface): ...@@ -1024,7 +1027,12 @@ class Scheduler(SchedulerInterface):
# tokens and rejections. If some tokens are rejected, # tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens. # tokens.
if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected request.num_computed_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0:
request.num_output_placeholders -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
......
...@@ -198,6 +198,7 @@ class EngineCore: ...@@ -198,6 +198,7 @@ class EngineCore:
self.step_fn = ( self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue self.step if self.batch_queue is None else self.step_with_batch_queue
) )
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
# Mark the startup heap as static so that it's ignored by GC. # Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections. # Reduces pause times of oldest generation collections.
...@@ -341,7 +342,10 @@ class EngineCore: ...@@ -341,7 +342,10 @@ class EngineCore:
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
def post_step(self, model_executed: bool) -> None: def post_step(self, model_executed: bool) -> None:
if self.use_spec_decode and model_executed: # When using async scheduling we can't get draft token ids in advance,
# so we update draft token ids in the worker process and don't
# need to update draft token ids here.
if not self.async_scheduling and self.use_spec_decode and model_executed:
# Take the draft token ids. # Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids() draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None: if draft_token_ids is not None:
......
...@@ -150,6 +150,23 @@ class Processor: ...@@ -150,6 +150,23 @@ class Processor:
raise ValueError( raise ValueError(
"vLLM V1 does not support per request user provided logits processors." "vLLM V1 does not support per request user provided logits processors."
) )
# Async scheduling + spec decode currently incompatible with some
# sampling parameters.
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.scheduler_config.async_scheduling
and (
params.frequency_penalty != 0.0
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
or params.structured_outputs
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters."
)
def _validate_params( def _validate_params(
self, self,
......
...@@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = ( ...@@ -41,7 +41,7 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
# Error message when the user tries to initialize vLLM with a speculative # Error message when the user tries to initialize vLLM with a speculative
# decoding enabled and custom logitsproces # decoding enabled and custom logitsproces
STR_SPEC_DEC_REJECTS_LOGITSPROCS = ( STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
"Custom logits processors are not supportedwhen speculative decoding is enabled." "Custom logits processors are not supported when speculative decoding is enabled."
) )
LOGITSPROCS_GROUP = "vllm.logits_processors" LOGITSPROCS_GROUP = "vllm.logits_processors"
......
...@@ -397,10 +397,13 @@ class EagleProposer: ...@@ -397,10 +397,13 @@ class EagleProposer:
positions += 1 positions += 1
exceeds_max_model_len = positions >= self.max_model_len exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions) clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# For data integrity when async scheduling, we shouldn't use in place
# operations in case they are modified in next step's `prepare_input`
# of main model.
# Increment the sequence lengths. # Increment the sequence lengths.
common_attn_metadata.seq_lens += 1 common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1 # This is an out-of-place operation to avoid modifying the original tensor.
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the # For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention. # sequence length to 1 to minimize their overheads in attention.
......
...@@ -46,6 +46,9 @@ class CachedRequestState: ...@@ -46,6 +46,9 @@ class CachedRequestState:
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None prompt_embeds: torch.Tensor | None = None
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds self.prompt_token_ids, self.prompt_embeds
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment