Unverified Commit a6be75db authored by PatchyTIS's avatar PatchyTIS Committed by GitHub
Browse files

[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)

parent ee54f9cd
......@@ -98,7 +98,7 @@ def test_without_spec_decoding(
@single_gpu_only
@large_gpu_mark(min_gb=16)
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
......@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
)
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Test ngram_gpu speculative decoding across scheduling configurations.

    A single ngram_gpu speculative config (3 speculative tokens, prompt
    lookup window of 2-3 tokens) is combined with different settings for:
    - Preemption (on/off)
    - Executor ("mp" / "uni")
    - Async scheduling (on/off)
    - Prefill chunking (on/off)

    The first configuration passes no speculative config at all.
    """
    # The one ngram_gpu speculative config shared by every spec-decoding run.
    ngram_gpu_config = {
        "method": "ngram_gpu",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 3,
        "prompt_lookup_min": 2,
    }

    # Test configurations covering various scenarios. Each tuple is:
    # (test_preemption, executor, async_scheduling,
    #  spec_config, test_prefill_chunking)
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, ngram_gpu_config, False),
        (True, "mp", False, ngram_gpu_config, True),
        (False, "mp", True, ngram_gpu_config, False),
        (True, "mp", True, ngram_gpu_config, False),
        (True, "uni", True, ngram_gpu_config, False),
        (True, "mp", True, ngram_gpu_config, True),
    ]

    # Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight
    # and ngram_gpu doesn't require a specific draft model
    run_tests(monkeypatch, MODEL, test_configs, [{}])
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
monkeypatch: pytest.MonkeyPatch,
......@@ -282,11 +318,12 @@ def run_test(
else dict(gpu_memory_utilization=0.9)
)
spec_mml = (spec_config or {}).get("max_model_len")
spec_method = (spec_config or {}).get("method", "none")
test_config = (
f"executor={executor}, preemption={test_preemption}, "
f"async_sched={async_scheduling}, "
f"chunk_prefill={test_prefill_chunking}, "
f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
)
print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}")
......@@ -294,7 +331,7 @@ def run_test(
with VllmRunner(
model,
max_model_len=512,
max_model_len=4096,
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
......
......@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("async_scheduling", [True], ids=["async"])
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_ngram_gpu_default_with_async_scheduling(
    async_scheduling: bool,
):
    """
    Test ngram_gpu speculative decoding (k=2) correctness with async
    scheduling enabled, validated via GSM8K accuracy.
    Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%).
    """
    qwen3_model = "Qwen/Qwen3-8B"
    spec_llm = LLM(
        model=qwen3_model,
        speculative_config={
            "method": "ngram_gpu",
            "prompt_lookup_max": 3,
            "prompt_lookup_min": 2,
            "num_speculative_tokens": 2,
        },
        max_model_len=4096,
        async_scheduling=async_scheduling,
    )
    try:
        evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
    finally:
        # Drop the engine and run the cleanup even when the accuracy check
        # raises, so a failure here does not skip the teardown.
        del spec_llm
        cleanup_dist_env_and_memory()
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_suffix_decoding_acceptance(
......
......@@ -907,6 +907,13 @@ class VllmBackend:
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(self.inductor_config)
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
is_ngram_gpu_enabled = (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.use_ngram_gpu()
)
disable_cache = disable_cache or is_ngram_gpu_enabled
if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
else:
......
......@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]
SpeculativeMethod = Literal[
"ngram",
"medusa",
......@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
"draft_model",
"suffix",
EagleModelTypes,
NgramGPUTypes,
]
......@@ -364,6 +366,8 @@ class SpeculativeConfig:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "ngram_gpu":
self.model = "ngram_gpu"
elif self.method == "suffix":
self.model = "suffix"
elif self.method == "extract_hidden_states":
......@@ -374,8 +378,9 @@ class SpeculativeConfig:
)
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
if self.method in ("ngram", "ngram_gpu"):
# Set default values if not provided
if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
......@@ -832,6 +837,9 @@ class SpeculativeConfig:
def uses_extract_hidden_states(self) -> bool:
    """Return True when the configured method is "extract_hidden_states"."""
    return self.method == "extract_hidden_states"
def use_ngram_gpu(self) -> bool:
    """Return True when the configured method is "ngram_gpu"."""
    return self.method == "ngram_gpu"
def __repr__(self) -> str:
method = self.method
model = (
......
......@@ -41,7 +41,7 @@ from .offload import OffloadConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, SpeculativeConfig
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config, replace
from .weight_transfer import WeightTransferConfig
......@@ -696,11 +696,13 @@ class VllmConfig:
if self.speculative_config is not None:
if (
self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
and self.speculative_config.method != "draft_model"
):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model kind of speculative decoding."
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
......@@ -718,6 +720,7 @@ class VllmConfig:
if (
self.speculative_config is not None
and self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
):
logger.warning_once(
"Async scheduling not supported with %s-based "
......
......@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
assert current_tool_call is not None
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
......@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
assert isinstance(current_tool_call, dict)
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
......
This diff is collapsed.
......@@ -127,7 +127,13 @@ class InputBatch:
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),
......
......@@ -10,7 +10,7 @@ from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import reduce
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
......@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer_gpu import (
NgramProposerGPU,
copy_num_valid_draft_tokens,
update_ngram_gpu_tensors_incremental,
update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
......@@ -424,7 +430,7 @@ class GPUModelRunner(
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
self.broadcast_pp_output = (
self.parallel_config.distributed_executor_backend == "external_launcher"
......@@ -493,6 +499,7 @@ class GPUModelRunner(
if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: (
NgramProposer # noqa: F823
| NgramProposerGPU
| SuffixDecodingProposer
| EagleProposer
| DraftModelProposer
......@@ -509,6 +516,23 @@ class GPUModelRunner(
device=self.device,
runner=self,
)
elif self.speculative_config.use_ngram_gpu():
self.drafter = NgramProposerGPU(self.vllm_config, self.device, self)
self.num_tokens_no_spec_gpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
self.token_ids_gpu_tensor = torch.zeros(
self.max_num_reqs,
self.max_model_len,
dtype=torch.int32,
device=device,
)
self._ngram_pinned_idx_buf = torch.zeros(
self.max_num_reqs, dtype=torch.long, pin_memory=True
)
self._ngram_pinned_val_buf = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=True
)
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
......@@ -564,7 +588,7 @@ class GPUModelRunner(
)
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoer
# We need to use the encoder length for encoder-decoder
# because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
......@@ -721,6 +745,21 @@ class GPUModelRunner(
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
self._num_valid_draft_tokens: torch.Tensor | None = None
self._num_valid_draft_tokens_cpu: torch.Tensor | None = None
self._num_valid_draft_tokens_event: torch.cuda.Event | None = None
self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
self._num_valid_draft_tokens_cpu = torch.empty(
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
)
self._num_valid_draft_tokens_event = torch.cuda.Event()
self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream()
self._draft_token_req_ids: list[str] | None = None
self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
......@@ -992,6 +1031,13 @@ class GPUModelRunner(
for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id)
is_ngram_gpu = (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
)
if is_ngram_gpu:
ngram_gpu_new_reqs: list[CachedRequestState] = []
reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
......@@ -1054,12 +1100,31 @@ class GPUModelRunner(
self._init_xdrope_positions(req_state)
reqs_to_add.append(req_state)
# Track new requests for ngram_gpu full tensor copy
if is_ngram_gpu:
ngram_gpu_new_reqs.append(req_state)
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
# Save scheduler-allocated spec lengths before trimming so
# prev_num_draft_len keeps the optimistic count for rejection correction.
original_num_spec_per_req: dict[str, int] = {}
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
for req_id, toks in scheduled_spec_tokens.items():
original_num_spec_per_req[req_id] = len(toks)
update_scheduler_for_invalid_drafts(
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens_cpu,
scheduler_output,
self.input_batch.req_id_to_index,
)
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
......@@ -1076,13 +1141,13 @@ class GPUModelRunner(
# prev_num_draft_len is used in async scheduling mode with
# spec decode. It indicates whether num_computed_tokens of the
# request needs to be updated. For example:
# fist step: num_computed_tokens = 0, spec_tokens = [],
# first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
# num_computed_tokens in first step and second step does't contain
# num_computed_tokens in first step and second step doesn't contain
# the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0.
......@@ -1096,6 +1161,9 @@ class GPUModelRunner(
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
if is_ngram_gpu and num_accepted > 0 and req_index is not None:
self.input_batch.num_tokens_no_spec[req_index] += num_accepted
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
......@@ -1156,6 +1224,9 @@ class GPUModelRunner(
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
reqs_to_add.append(req_state)
# Track resumed requests for ngram_gpu full tensor copy
if is_ngram_gpu:
ngram_gpu_new_reqs.append(req_state)
continue
# Update the persistent batch.
......@@ -1176,6 +1247,11 @@ class GPUModelRunner(
# Add spec_token_ids to token_ids_cpu.
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
# Restore scheduler-side draft count after ngram trimming.
if original_num_spec_per_req:
orig = original_num_spec_per_req.get(req_id, 0)
if orig != req_state.prev_num_draft_len:
req_state.prev_num_draft_len = orig
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
......@@ -1190,6 +1266,18 @@ class GPUModelRunner(
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
# Incrementally update ngram_gpu tensors after batch is stable
if is_ngram_gpu:
update_ngram_gpu_tensors_incremental(
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
ngram_gpu_new_reqs,
self.device,
_pinned_idx_buf=self._ngram_pinned_idx_buf,
_pinned_val_buf=self._ngram_pinned_val_buf,
)
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None:
......@@ -3412,6 +3500,23 @@ class GPUModelRunner(
else:
logger.error("RoutedExpertsCapturer not initialized.")
# If ngram_gpu is used, we need to copy the scheduler_output so that
# modifications made to it do not affect the scheduler_output in the
# engine core process. Using replace() is much faster than deepcopy.
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy()
spec_decode_tokens_copy = (
scheduler_output.scheduled_spec_decode_tokens.copy()
)
scheduler_output = replace(
scheduler_output,
num_scheduled_tokens=num_scheduled_tokens_copy,
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
)
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions(
scheduler_output.preempted_req_ids
......@@ -3825,6 +3930,32 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
elif (
spec_config.use_ngram_gpu()
and not spec_config.disable_padded_drafter_batch
):
assert isinstance(self.drafter, NgramProposerGPU)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count, _ = (
self.drafter.update_token_ids_ngram(
sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
......@@ -4064,6 +4195,52 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
)
if isinstance(self.drafter, NgramProposer):
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids = self.drafter.propose(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
)
elif spec_config.use_ngram_gpu():
assert isinstance(self.drafter, NgramProposerGPU)
(
next_token_ids,
valid_sampled_tokens_count,
valid_sampled_token_ids_gpu,
) = self.drafter.update_token_ids_ngram(
sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
batch_size = next_token_ids.shape[0]
draft_token_ids, num_valid_draft_tokens = self.drafter.propose(
self.num_tokens_no_spec_gpu[:batch_size],
self.token_ids_gpu_tensor[:batch_size],
valid_sampled_token_ids_gpu,
valid_sampled_tokens_count,
)
# Cache valid draft counts for scheduler-side trimming.
self._num_valid_draft_tokens = num_valid_draft_tokens
# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens(
self._num_valid_draft_tokens_cpu,
self._num_valid_draft_tokens_copy_stream,
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens,
self.input_batch.num_reqs,
)
elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment