Unverified commit f7008ce1, authored by Benjamin Chislett, committed by GitHub
Browse files

[Perf] Async Scheduling + Speculative Decoding + Structured Outputs (#29821)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 4e67a8f6
......@@ -30,8 +30,9 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict(
temperature=0.0, # greedy
max_tokens=23,
min_tokens=18,
max_tokens=30,
# spec decoding currently doesn't support min_tokens
# min_tokens=28,
)
......@@ -86,7 +87,7 @@ def test_without_spec_decoding(
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
......@@ -100,9 +101,16 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50}
struct_outputs = StructuredOutputsParams(json=sample_json_schema)
test_sampling_params = [
dict(),
dict(logprobs=2),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
logprobs=2,
),
]
# test_preemption, executor, async_scheduling,
......
......@@ -12,10 +12,12 @@ logger = init_logger(__name__)
class AsyncScheduler(Scheduler):
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output)
has_structured_output_requests = False
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
has_structured_output_requests |= request.use_structured_output
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
......@@ -33,6 +35,7 @@ class AsyncScheduler(Scheduler):
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.has_structured_output_requests = has_structured_output_requests
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
......
......@@ -86,7 +86,26 @@ class SchedulerInterface(ABC):
@abstractmethod
def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None:
"""Update the draft token ids for the scheduled requests."""
"""Update requests with newly generated draft token ids, applying
structured output grammar validation if needed.
Args:
draft_token_ids: The input draft token ids for each request.
"""
raise NotImplementedError
@abstractmethod
def update_draft_token_ids_in_output(
self, draft_token_ids: "DraftTokenIds", scheduler_output: "SchedulerOutput"
) -> None:
"""Update scheduler output with newly generated draft token ids, applying
structured output grammar validation if needed.
Args:
draft_token_ids: The input draft token ids for each request.
scheduler_output: Update the given scheduler_output
with the corresponding draft token ids.
"""
# Interface stub: this base version only raises; the behavior described
# above is supplied by the implementing scheduler class.
raise NotImplementedError
@abstractmethod
......
......@@ -181,10 +181,17 @@ class SchedulerOutput:
# Only used for v2 model runner.
preempted_req_ids: set[str] | None = None
# Whether any of the scheduled requests use structured output.
# Set only in async scheduling case.
has_structured_output_requests: bool = False
# Whether any scheduled structured-output request still has output
# tokens outstanding (placeholders not yet filled in), meaning the
# grammar bitmask cannot be computed until those tokens are available.
pending_structured_output_tokens: bool = False
# Used for adjusting acceptance rate calculation.
num_invalid_spec_tokens: dict[str, int] | None = None
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
......
......@@ -1130,6 +1130,8 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted,
num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens,
request_id=req_id,
)
stopped = False
......@@ -1168,7 +1170,13 @@ class Scheduler(SchedulerInterface):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if not ok:
logger.warning(
"Unexpected: grammar rejected tokens %s for request %s.",
new_token_ids,
req_id,
)
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
......@@ -1330,12 +1338,47 @@ class Scheduler(SchedulerInterface):
# Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids
)
else:
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr]
request.spec_token_ids = spec_token_ids
def update_draft_token_ids_in_output(
self, draft_token_ids: DraftTokenIds, scheduler_output: SchedulerOutput
) -> None:
# Overwrite the spec-token entries of scheduler_output with the draft
# tokens in draft_token_ids, and store on scheduler_output a per-request
# count of the entries that had to be padded with -1.
#
# Both arguments are modified in place: the per-request draft lists are
# truncated, and scheduler_output.scheduled_spec_decode_tokens /
# scheduler_output.num_invalid_spec_tokens are written.
# Maps request id -> number of -1 padding entries added for that request.
num_invalid_spec_tokens: dict[str, int] = {}
sched_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
# req_ids and draft_token_ids are walked in lockstep, one draft list per id.
for req_id, spec_token_ids in zip(
draft_token_ids.req_ids,
draft_token_ids.draft_token_ids,
):
request = self.requests.get(req_id)
if request is None or request.is_finished():
# The request may have been finished. Skip.
continue
placeholder_spec_tokens = sched_spec_tokens.get(req_id)
# A missing or empty entry means no spec tokens were scheduled for
# this request in this output, so there is nothing to replace.
if not placeholder_spec_tokens:
continue
# The scheduled placeholder count is the target length of the final list.
orig_num_spec_tokens = len(placeholder_spec_tokens)
# Trim drafts to scheduled number of spec tokens
# (needed for chunked prefill case for example).
# Note: this truncates the caller's list in place.
del spec_token_ids[orig_num_spec_tokens:]
# Filter out spec tokens which do not adhere to the grammar.
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None
# NOTE(review): presumably validate_tokens returns the accepted prefix
# as a new list - confirm against the grammar backend.
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids)
# Pad to original number of spec tokens.
num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids)
if num_invalid_tokens:
# -1 is the filler value for slots without a usable draft token.
spec_token_ids.extend([-1] * num_invalid_tokens)
num_invalid_spec_tokens[req_id] = num_invalid_tokens
# Replace the placeholders with the resulting token list.
sched_spec_tokens[req_id] = spec_token_ids
# Only requests that received padding have an entry in this dict.
scheduler_output.num_invalid_spec_tokens = num_invalid_spec_tokens
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
# The counts are the current sizes of self.running and self.waiting,
# in that order.
return len(self.running), len(self.waiting)
......@@ -1513,11 +1556,15 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int,
num_accepted_tokens: int,
num_invalid_spec_tokens: dict[str, int] | None,
request_id: str,
) -> SpecDecodingStats | None:
if not self.log_stats:
if not self.log_stats or not num_draft_tokens:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
if num_invalid_spec_tokens:
num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0)
spec_decoding_stats.observe_draft(
num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens
)
......
......@@ -466,6 +466,18 @@ class EngineCore:
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# If we are doing speculative decoding with structured output,
# we need to get the draft token ids from the prior step before
# we can compute the grammar bitmask for the deferred request.
if self.use_spec_decode:
draft_token_ids = self.model_executor.take_draft_token_ids()
assert draft_token_ids is not None
# Update the draft token ids in the scheduler output to
# filter out the invalid spec tokens, which will be padded
# with -1 and skipped by the grammar bitmask computation.
self.scheduler.update_draft_token_ids_in_output(
draft_token_ids, deferred_scheduler_output
)
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
......
......@@ -158,12 +158,11 @@ class InputProcessor:
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
or params.structured_outputs
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters."
"penalties or bad words in sampling parameters."
)
def _validate_params(
......
......@@ -626,6 +626,7 @@ class GPUModelRunner(
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self._draft_token_req_ids: list[str] | None = None
self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_num_reqs, 1),
......@@ -638,7 +639,22 @@ class GPUModelRunner(
# with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens:
# We also copy the drafted tokens to the CPU asynchronously,
# in case we need them for structured outputs.
self.draft_token_ids_event: torch.Event | None = None
self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None
self.valid_sampled_token_count_cpu: torch.Tensor | None = None
self.draft_token_ids_cpu: torch.Tensor | None = None
if self.num_spec_tokens:
self.draft_token_ids_event = torch.Event()
self.draft_token_ids_copy_stream = torch.cuda.Stream()
self.draft_token_ids_cpu = torch.empty(
(self.max_num_reqs, self.num_spec_tokens),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
if self.use_async_scheduling:
self.valid_sampled_token_count_event = torch.Event()
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
......@@ -1036,15 +1052,8 @@ class GPUModelRunner(
self.input_batch.spec_token_ids[req_index].clear()
self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
# there are no draft tokens with async scheduling,
# we clear the spec_decoding info in scheduler_output and
# use normal sampling but rejection_sampling.
if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens
if num_spec_tokens and self._draft_token_ids is None:
scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for request in reqs_to_add:
......@@ -1291,7 +1300,6 @@ class GPUModelRunner(
# because input_ids dtype is torch.int32,
# so convert draft_token_ids to torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self._draft_token_ids = None
self.input_ids.gpu.scatter_(
dim=0,
......@@ -3100,20 +3108,6 @@ class GPUModelRunner(
"after execute_model() returns None."
)
# self._draft_token_ids is None when `input_fits_in_drafter=False`
# and there is no draft tokens scheduled. so it need to update the
# spec_decoding info in scheduler_output with async_scheduling.
# use deepcopy to avoid the modification has influence on the
# scheduler_output in engine core process.
# TODO(Ronald1995): deepcopy is expensive when there is a large
# number of requests, optimize it later.
if (
self.use_async_scheduling
and self.num_spec_tokens
and self._draft_token_ids is None
):
scheduler_output = deepcopy(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with (
record_function_or_nullcontext("gpu_model_runner: preprocess"),
......@@ -3360,6 +3354,8 @@ class GPUModelRunner(
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
self._draft_token_ids = None
self._draft_token_req_ids = None
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
......@@ -3414,6 +3410,7 @@ class GPUModelRunner(
spec_decode_metadata,
spec_decode_common_attn_metadata,
)
self._copy_draft_token_ids_to_cpu(scheduler_output)
spec_config = self.speculative_config
use_padded_batch_for_eagle = (
......@@ -3458,6 +3455,12 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
......@@ -3529,19 +3532,50 @@ class GPUModelRunner(
return async_output
def take_draft_token_ids(self) -> DraftTokenIds | None:
if not self.num_spec_tokens:
if not self.num_spec_tokens or not self._draft_token_req_ids:
return None
req_ids = self._draft_token_req_ids
draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids))
return DraftTokenIds(req_ids, draft_token_ids)
req_ids = self.input_batch.req_ids
if self._draft_token_ids is None:
return DraftTokenIds(req_ids, [[] for _ in req_ids])
def _copy_draft_token_ids_to_cpu(
self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
) -> None:
struct_output = scheduler_output.has_structured_output_requests
if self.use_async_scheduling and not struct_output:
# Draft tokens don't need to be copied to the CPU if async
# scheduling is in use and there are no structured output reqs.
return
# We must also set the corresponding request ids.
self._draft_token_req_ids = self.input_batch.req_ids.copy()
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
draft_token_ids: torch.Tensor = self._draft_token_ids
if not torch.is_tensor(draft_token_ids):
return
assert self.draft_token_ids_event is not None
assert self.draft_token_ids_copy_stream is not None
assert self.draft_token_ids_cpu is not None
default_stream = torch.cuda.current_stream()
num_reqs = draft_token_ids.shape[0]
with torch.cuda.stream(self.draft_token_ids_copy_stream):
if not zeros_only:
# Trigger async copy of draft token ids to cpu.
self.draft_token_ids_copy_stream.wait_stream(default_stream)
self.draft_token_ids_cpu[:num_reqs].copy_(
draft_token_ids, non_blocking=True
)
else:
draft_token_ids = self._draft_token_ids
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)
# No copy needed, just zero-out cpu tensor.
self.draft_token_ids_cpu[:num_reqs] = 0
self.draft_token_ids_event.record()
def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]:
# Return the cached draft token ids as nested Python lists.
#
# num_reqs: how many leading rows of the CPU buffer to return when the
# drafts are not already held as a Python list.
# Drafts already stored as a Python list are handed back unchanged
# (num_reqs is not applied on this path).
if isinstance(self._draft_token_ids, list):
return self._draft_token_ids
assert self.draft_token_ids_event is not None
assert self.draft_token_ids_cpu is not None
# Wait for the event that _copy_draft_token_ids_to_cpu records after it
# fills the CPU buffer, so the buffer is not read before that completes.
self.draft_token_ids_event.synchronize()
# Convert only the first num_reqs rows of the buffer to Python lists.
return self.draft_token_ids_cpu[:num_reqs].tolist()
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
......@@ -3556,6 +3590,7 @@ class GPUModelRunner(
self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore
counts = valid_sampled_tokens_count
counts_cpu = self.valid_sampled_token_count_cpu
assert counts_cpu is not None
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment