Unverified Commit a6be75db authored by PatchyTIS's avatar PatchyTIS Committed by GitHub
Browse files

[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)

parent ee54f9cd
...@@ -98,7 +98,7 @@ def test_without_spec_decoding( ...@@ -98,7 +98,7 @@ def test_without_spec_decoding(
@single_gpu_only @single_gpu_only
@large_gpu_mark(min_gb=16) @large_gpu_mark(min_gb=16)
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch): def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of """Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking, preemption, executor, async scheduling, prefill chunking,
spec decoding model length. spec decoding model length.
...@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch) ...@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
) )
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
"""Test ngram_gpu speculative decoding with different configurations.
This test specifically validates ngram_gpu behavior with various:
- Number of speculative tokens (2-6)
- Prompt lookup window sizes (min/max)
- Async scheduling enabled (as in production)
- Different executors and chunking settings
"""
# Variant with larger speculation window
ngram_gpu_config = {
"method": "ngram_gpu",
"num_speculative_tokens": 3,
"prompt_lookup_max": 3,
"prompt_lookup_min": 2,
}
# Test configurations covering various scenarios
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs = [
(False, "mp", False, None, False),
(False, "mp", False, ngram_gpu_config, False),
(True, "mp", False, ngram_gpu_config, True),
(False, "mp", True, ngram_gpu_config, False),
(True, "mp", True, ngram_gpu_config, False),
(True, "uni", True, ngram_gpu_config, False),
(True, "mp", True, ngram_gpu_config, True),
]
# Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight
# and ngram_gpu doesn't require a specific draft model
run_tests(monkeypatch, MODEL, test_configs, [{}])
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)
def run_tests( def run_tests(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
...@@ -282,11 +318,12 @@ def run_test( ...@@ -282,11 +318,12 @@ def run_test(
else dict(gpu_memory_utilization=0.9) else dict(gpu_memory_utilization=0.9)
) )
spec_mml = (spec_config or {}).get("max_model_len") spec_mml = (spec_config or {}).get("max_model_len")
spec_method = (spec_config or {}).get("method", "none")
test_config = ( test_config = (
f"executor={executor}, preemption={test_preemption}, " f"executor={executor}, preemption={test_preemption}, "
f"async_sched={async_scheduling}, " f"async_sched={async_scheduling}, "
f"chunk_prefill={test_prefill_chunking}, " f"chunk_prefill={test_prefill_chunking}, "
f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
) )
print("-" * 80) print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}") print(f"---- TESTING {test_str}: {test_config}")
...@@ -294,7 +331,7 @@ def run_test( ...@@ -294,7 +331,7 @@ def run_test(
with VllmRunner( with VllmRunner(
model, model,
max_model_len=512, max_model_len=4096,
enable_chunked_prefill=test_prefill_chunking, enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking # Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None, max_num_batched_tokens=48 if test_prefill_chunking else None,
......
...@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness( ...@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.parametrize("async_scheduling", [True], ids=["async"])
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_ngram_gpu_default_with_async_scheduling(
async_scheduling: bool,
):
"""
Test ngram_gpu speculative decoding (k=3) correctness with and without
async scheduling, validated via GSM8K accuracy.
Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%).
"""
qwen3_model = "Qwen/Qwen3-8B"
spec_llm = LLM(
model=qwen3_model,
speculative_config={
"method": "ngram_gpu",
"prompt_lookup_max": 3,
"prompt_lookup_min": 2,
"num_speculative_tokens": 2,
},
max_model_len=4096,
async_scheduling=async_scheduling,
)
evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
del spec_llm
cleanup_dist_env_and_memory()
@single_gpu_only @single_gpu_only
@large_gpu_mark(min_gb=20) @large_gpu_mark(min_gb=20)
def test_suffix_decoding_acceptance( def test_suffix_decoding_acceptance(
......
...@@ -907,6 +907,13 @@ class VllmBackend: ...@@ -907,6 +907,13 @@ class VllmBackend:
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE. # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(self.inductor_config) disable_cache = not is_compile_cache_enabled(self.inductor_config)
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
is_ngram_gpu_enabled = (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.use_ngram_gpu()
)
disable_cache = disable_cache or is_ngram_gpu_enabled
if disable_cache: if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
else: else:
......
...@@ -47,6 +47,7 @@ MTPModelTypes = Literal[ ...@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"step3p5_mtp", "step3p5_mtp",
] ]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes] EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]
SpeculativeMethod = Literal[ SpeculativeMethod = Literal[
"ngram", "ngram",
"medusa", "medusa",
...@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[ ...@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
"draft_model", "draft_model",
"suffix", "suffix",
EagleModelTypes, EagleModelTypes,
NgramGPUTypes,
] ]
...@@ -364,6 +366,8 @@ class SpeculativeConfig: ...@@ -364,6 +366,8 @@ class SpeculativeConfig:
self.quantization = self.target_model_config.quantization self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"): elif self.method in ("ngram", "[ngram]"):
self.model = "ngram" self.model = "ngram"
elif self.method == "ngram_gpu":
self.model = "ngram_gpu"
elif self.method == "suffix": elif self.method == "suffix":
self.model = "suffix" self.model = "suffix"
elif self.method == "extract_hidden_states": elif self.method == "extract_hidden_states":
...@@ -374,8 +378,9 @@ class SpeculativeConfig: ...@@ -374,8 +378,9 @@ class SpeculativeConfig:
) )
if self.method in ("ngram", "[ngram]"): if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram" self.method = "ngram"
if self.method in ("ngram", "ngram_gpu"):
# Set default values if not provided # Set default values if not provided
if self.prompt_lookup_min is None and self.prompt_lookup_max is None: if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
# TODO(woosuk): Tune these values. They are arbitrarily chosen. # TODO(woosuk): Tune these values. They are arbitrarily chosen.
...@@ -832,6 +837,9 @@ class SpeculativeConfig: ...@@ -832,6 +837,9 @@ class SpeculativeConfig:
def uses_extract_hidden_states(self) -> bool: def uses_extract_hidden_states(self) -> bool:
return self.method == "extract_hidden_states" return self.method == "extract_hidden_states"
def use_ngram_gpu(self) -> bool:
return self.method == "ngram_gpu"
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method
model = ( model = (
......
...@@ -41,7 +41,7 @@ from .offload import OffloadConfig ...@@ -41,7 +41,7 @@ from .offload import OffloadConfig
from .parallel import ParallelConfig from .parallel import ParallelConfig
from .profiler import ProfilerConfig from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, SpeculativeConfig from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config, replace from .utils import SupportsHash, config, replace
from .weight_transfer import WeightTransferConfig from .weight_transfer import WeightTransferConfig
...@@ -696,11 +696,13 @@ class VllmConfig: ...@@ -696,11 +696,13 @@ class VllmConfig:
if self.speculative_config is not None: if self.speculative_config is not None:
if ( if (
self.speculative_config.method not in get_args(EagleModelTypes) self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
and self.speculative_config.method != "draft_model" and self.speculative_config.method != "draft_model"
): ):
raise ValueError( raise ValueError(
"Currently, async scheduling is only supported " "Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model kind of speculative decoding." "with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
) )
if self.speculative_config.disable_padded_drafter_batch: if self.speculative_config.disable_padded_drafter_batch:
raise ValueError( raise ValueError(
...@@ -718,6 +720,7 @@ class VllmConfig: ...@@ -718,6 +720,7 @@ class VllmConfig:
if ( if (
self.speculative_config is not None self.speculative_config is not None
and self.speculative_config.method not in get_args(EagleModelTypes) and self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
): ):
logger.warning_once( logger.warning_once(
"Async scheduling not supported with %s-based " "Async scheduling not supported with %s-based "
......
...@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser): ...@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments" "arguments"
) )
assert current_tool_call is not None
cur_arguments = current_tool_call.get("arguments") cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments) logger.debug("diffing old arguments: %s", prev_arguments)
...@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser): ...@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
# handle saving the state for the current tool into # handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration # the "prev" list for use in diffing for the next iteration
assert isinstance(current_tool_call, dict)
if self.current_tool_id == len(self.prev_tool_call_arr) - 1: if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else: else:
......
This diff is collapsed.
...@@ -127,7 +127,13 @@ class InputBatch: ...@@ -127,7 +127,13 @@ class InputBatch:
# allocation if max_model_len is big. # allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros( self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,), (max_num_reqs,),
......
...@@ -10,7 +10,7 @@ from collections import defaultdict ...@@ -10,7 +10,7 @@ from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass from dataclasses import dataclass, replace
from functools import reduce from functools import reduce
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
...@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer ...@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer_gpu import (
NgramProposerGPU,
copy_num_valid_draft_tokens,
update_ngram_gpu_tensors_incremental,
update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
...@@ -424,7 +430,7 @@ class GPUModelRunner( ...@@ -424,7 +430,7 @@ class GPUModelRunner(
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks # to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches # TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019 # https://github.com/vllm-project/vllm/issues/18019
self.broadcast_pp_output = ( self.broadcast_pp_output = (
self.parallel_config.distributed_executor_backend == "external_launcher" self.parallel_config.distributed_executor_backend == "external_launcher"
...@@ -493,6 +499,7 @@ class GPUModelRunner( ...@@ -493,6 +499,7 @@ class GPUModelRunner(
if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: ( self.drafter: (
NgramProposer # noqa: F823 NgramProposer # noqa: F823
| NgramProposerGPU
| SuffixDecodingProposer | SuffixDecodingProposer
| EagleProposer | EagleProposer
| DraftModelProposer | DraftModelProposer
...@@ -509,6 +516,23 @@ class GPUModelRunner( ...@@ -509,6 +516,23 @@ class GPUModelRunner(
device=self.device, device=self.device,
runner=self, runner=self,
) )
elif self.speculative_config.use_ngram_gpu():
self.drafter = NgramProposerGPU(self.vllm_config, self.device, self)
self.num_tokens_no_spec_gpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
self.token_ids_gpu_tensor = torch.zeros(
self.max_num_reqs,
self.max_model_len,
dtype=torch.int32,
device=device,
)
self._ngram_pinned_idx_buf = torch.zeros(
self.max_num_reqs, dtype=torch.long, pin_memory=True
)
self._ngram_pinned_val_buf = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=True
)
elif self.speculative_config.method == "suffix": elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config) self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
...@@ -564,7 +588,7 @@ class GPUModelRunner( ...@@ -564,7 +588,7 @@ class GPUModelRunner(
) )
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoer # We need to use the encoder length for encoder-decoder
# because of KV cache for cross-attention. # because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len), max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
...@@ -721,6 +745,21 @@ class GPUModelRunner( ...@@ -721,6 +745,21 @@ class GPUModelRunner(
# Cached outputs. # Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
self._num_valid_draft_tokens: torch.Tensor | None = None
self._num_valid_draft_tokens_cpu: torch.Tensor | None = None
self._num_valid_draft_tokens_event: torch.cuda.Event | None = None
self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
self._num_valid_draft_tokens_cpu = torch.empty(
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
)
self._num_valid_draft_tokens_event = torch.cuda.Event()
self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream()
self._draft_token_req_ids: list[str] | None = None self._draft_token_req_ids: list[str] | None = None
self.transfer_event = torch.Event() self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty( self.sampled_token_ids_pinned_cpu = torch.empty(
...@@ -992,6 +1031,13 @@ class GPUModelRunner( ...@@ -992,6 +1031,13 @@ class GPUModelRunner(
for req_id in unscheduled_req_ids: for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
is_ngram_gpu = (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
)
if is_ngram_gpu:
ngram_gpu_new_reqs: list[CachedRequestState] = []
reqs_to_add: list[CachedRequestState] = [] reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
...@@ -1054,12 +1100,31 @@ class GPUModelRunner( ...@@ -1054,12 +1100,31 @@ class GPUModelRunner(
self._init_xdrope_positions(req_state) self._init_xdrope_positions(req_state)
reqs_to_add.append(req_state) reqs_to_add.append(req_state)
# Track new requests for ngram_gpu full tensor copy
if is_ngram_gpu:
ngram_gpu_new_reqs.append(req_state)
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs req_data = scheduler_output.scheduled_cached_reqs
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
# Save scheduler-allocated spec lengths before trimming so
# prev_num_draft_len keeps the optimistic count for rejection correction.
original_num_spec_per_req: dict[str, int] = {}
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
for req_id, toks in scheduled_spec_tokens.items():
original_num_spec_per_req[req_id] = len(toks)
update_scheduler_for_invalid_drafts(
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens_cpu,
scheduler_output,
self.input_batch.req_id_to_index,
)
# Wait until valid_sampled_tokens_count is copied to cpu, # Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request. # then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count() valid_sampled_token_count = self._get_valid_sampled_token_count()
...@@ -1076,13 +1141,13 @@ class GPUModelRunner( ...@@ -1076,13 +1141,13 @@ class GPUModelRunner(
# prev_num_draft_len is used in async scheduling mode with # prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens # spec decode. it indicates if need to update num_computed_tokens
# of the request. for example: # of the request. for example:
# fist step: num_computed_tokens = 0, spec_tokens = [], # first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0. # prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length), # second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0. # spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2. # prev_num_draft_len = 2.
# num_computed_tokens in first step and second step does't contain # num_computed_tokens in first step and second step doesn't contain
# the spec tokens length, but in third step it contains the # the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens # spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0. # when prev_num_draft_len > 0.
...@@ -1096,6 +1161,9 @@ class GPUModelRunner( ...@@ -1096,6 +1161,9 @@ class GPUModelRunner(
num_computed_tokens -= num_rejected num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted) req_state.output_token_ids.extend([-1] * num_accepted)
if is_ngram_gpu and num_accepted > 0 and req_index is not None:
self.input_batch.num_tokens_no_spec[req_index] += num_accepted
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
...@@ -1156,6 +1224,9 @@ class GPUModelRunner( ...@@ -1156,6 +1224,9 @@ class GPUModelRunner(
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
reqs_to_add.append(req_state) reqs_to_add.append(req_state)
# Track resumed requests for ngram_gpu full tensor copy
if is_ngram_gpu:
ngram_gpu_new_reqs.append(req_state)
continue continue
# Update the persistent batch. # Update the persistent batch.
...@@ -1176,6 +1247,11 @@ class GPUModelRunner( ...@@ -1176,6 +1247,11 @@ class GPUModelRunner(
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
# Restore scheduler-side draft count after ngram trimming.
if original_num_spec_per_req:
orig = original_num_spec_per_req.get(req_id, 0)
if orig != req_state.prev_num_draft_len:
req_state.prev_num_draft_len = orig
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
...@@ -1190,6 +1266,18 @@ class GPUModelRunner( ...@@ -1190,6 +1266,18 @@ class GPUModelRunner(
# Refresh batch metadata with any pending updates. # Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata() self.input_batch.refresh_metadata()
# Incrementally update ngram_gpu tensors after batch is stable
if is_ngram_gpu:
update_ngram_gpu_tensors_incremental(
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
ngram_gpu_new_reqs,
self.device,
_pinned_idx_buf=self._ngram_pinned_idx_buf,
_pinned_val_buf=self._ngram_pinned_val_buf,
)
def _update_states_after_model_execute( def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None: ) -> None:
...@@ -3412,6 +3500,23 @@ class GPUModelRunner( ...@@ -3412,6 +3500,23 @@ class GPUModelRunner(
else: else:
logger.error("RoutedExpertsCapturer not initialized.") logger.error("RoutedExpertsCapturer not initialized.")
# If ngram_gpu is used, we need to copy the scheduler_output to avoid
# the modification has influence on the scheduler_output in engine core process.
# The replace is much faster than deepcopy.
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy()
spec_decode_tokens_copy = (
scheduler_output.scheduled_spec_decode_tokens.copy()
)
scheduler_output = replace(
scheduler_output,
num_scheduled_tokens=num_scheduled_tokens_copy,
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
)
if scheduler_output.preempted_req_ids and has_kv_transfer_group(): if scheduler_output.preempted_req_ids and has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions( get_kv_transfer_group().handle_preemptions(
scheduler_output.preempted_req_ids scheduler_output.preempted_req_ids
...@@ -3825,6 +3930,32 @@ class GPUModelRunner( ...@@ -3825,6 +3930,32 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count( self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
elif (
spec_config.use_ngram_gpu()
and not spec_config.disable_padded_drafter_batch
):
assert isinstance(self.drafter, NgramProposerGPU)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count, _ = (
self.drafter.update_token_ids_ngram(
sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter, # Since we couldn't run the drafter,
# just use zeros for the draft tokens. # just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros( self._draft_token_ids = torch.zeros(
...@@ -4064,6 +4195,52 @@ class GPUModelRunner( ...@@ -4064,6 +4195,52 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu, self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
) )
if isinstance(self.drafter, NgramProposer):
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids = self.drafter.propose(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
)
elif spec_config.use_ngram_gpu():
assert isinstance(self.drafter, NgramProposerGPU)
(
next_token_ids,
valid_sampled_tokens_count,
valid_sampled_token_ids_gpu,
) = self.drafter.update_token_ids_ngram(
sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
batch_size = next_token_ids.shape[0]
draft_token_ids, num_valid_draft_tokens = self.drafter.propose(
self.num_tokens_no_spec_gpu[:batch_size],
self.token_ids_gpu_tensor[:batch_size],
valid_sampled_token_ids_gpu,
valid_sampled_tokens_count,
)
# Cache valid draft counts for scheduler-side trimming.
self._num_valid_draft_tokens = num_valid_draft_tokens
# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens(
self._num_valid_draft_tokens_cpu,
self._num_valid_draft_tokens_copy_stream,
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens,
self.input_batch.num_reqs,
)
elif spec_config.method == "suffix": elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer) assert isinstance(self.drafter, SuffixDecodingProposer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment