Unverified Commit f7008ce1 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Perf] Async Scheduling + Speculative Decoding + Structured Outputs (#29821)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
parent 4e67a8f6
...@@ -30,8 +30,9 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [ ...@@ -30,8 +30,9 @@ example_prompts = [first_prompt, "In one word, the capital of France is "] + [
default_params = dict( default_params = dict(
temperature=0.0, # greedy temperature=0.0, # greedy
max_tokens=23, max_tokens=30,
min_tokens=18, # spec decoding currently doesn't support min_tokens
# min_tokens=28,
) )
...@@ -86,7 +87,7 @@ def test_without_spec_decoding( ...@@ -86,7 +87,7 @@ def test_without_spec_decoding(
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of """Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking, preemption, executor, async scheduling, prefill chunking,
spec decoding model length. spec decoding model length.
...@@ -100,9 +101,16 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): ...@@ -100,9 +101,16 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case. # Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50} spec_config_short = spec_config | {"max_model_len": 50}
struct_outputs = StructuredOutputsParams(json=sample_json_schema)
test_sampling_params = [ test_sampling_params = [
dict(), dict(),
dict(logprobs=2), dict(logprobs=2),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
logprobs=2,
),
] ]
# test_preemption, executor, async_scheduling, # test_preemption, executor, async_scheduling,
......
...@@ -12,10 +12,12 @@ logger = init_logger(__name__) ...@@ -12,10 +12,12 @@ logger = init_logger(__name__)
class AsyncScheduler(Scheduler): class AsyncScheduler(Scheduler):
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
has_structured_output_requests = False
pending_structured_output_tokens = False pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens: for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id] request = self.requests[req_id]
has_structured_output_requests |= request.use_structured_output
pending_structured_output_tokens |= ( pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0 request.use_structured_output and request.num_output_placeholders > 0
) )
...@@ -33,6 +35,7 @@ class AsyncScheduler(Scheduler): ...@@ -33,6 +35,7 @@ class AsyncScheduler(Scheduler):
# We will update the actual spec token ids in the worker process. # We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.has_structured_output_requests = has_structured_output_requests
scheduler_output.pending_structured_output_tokens = ( scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens pending_structured_output_tokens
) )
......
...@@ -86,7 +86,26 @@ class SchedulerInterface(ABC): ...@@ -86,7 +86,26 @@ class SchedulerInterface(ABC):
@abstractmethod @abstractmethod
def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None: def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None:
"""Update the draft token ids for the scheduled requests.""" """Update requests with newly generated draft token ids, applying
structured output grammar validation if needed.
Args:
draft_token_ids: The input draft token ids for each request.
"""
raise NotImplementedError
@abstractmethod
def update_draft_token_ids_in_output(
self, draft_token_ids: "DraftTokenIds", scheduler_output: "SchedulerOutput"
) -> None:
"""Update scheduler output with newly generated draft token ids, applying
structured output grammar validation if needed.
Args:
draft_token_ids: The input draft token ids for each request.
scheduler_output: Update the given scheduler_output
with the corresponding draft token ids.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
......
...@@ -181,10 +181,17 @@ class SchedulerOutput: ...@@ -181,10 +181,17 @@ class SchedulerOutput:
# Only used for v2 model runner. # Only used for v2 model runner.
preempted_req_ids: set[str] | None = None preempted_req_ids: set[str] | None = None
# Whether any of the scheduled requests use structured output.
# Set only in async scheduling case.
has_structured_output_requests: bool = False
# Whether the scheduled requests have all the output tokens they # Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation. # need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False pending_structured_output_tokens: bool = False
# Used for adjusting acceptance rate calculation.
num_invalid_spec_tokens: dict[str, int] | None = None
# KV Cache Connector metadata. # KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None kv_connector_metadata: KVConnectorMetadata | None = None
......
...@@ -1130,6 +1130,8 @@ class Scheduler(SchedulerInterface): ...@@ -1130,6 +1130,8 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted, num_accepted_tokens=num_accepted,
num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens,
request_id=req_id,
) )
stopped = False stopped = False
...@@ -1168,7 +1170,13 @@ class Scheduler(SchedulerInterface): ...@@ -1168,7 +1170,13 @@ class Scheduler(SchedulerInterface):
struct_output_request = request.structured_output_request struct_output_request = request.structured_output_request
assert struct_output_request is not None assert struct_output_request is not None
assert struct_output_request.grammar is not None assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(req_id, new_token_ids) ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if not ok:
logger.warning(
"Unexpected: grammar rejected tokens %s for request %s.",
new_token_ids,
req_id,
)
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
...@@ -1330,11 +1338,46 @@ class Scheduler(SchedulerInterface): ...@@ -1330,11 +1338,46 @@ class Scheduler(SchedulerInterface):
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr]
spec_token_ids request.spec_token_ids = spec_token_ids
)
else: def update_draft_token_ids_in_output(
request.spec_token_ids = spec_token_ids self, draft_token_ids: DraftTokenIds, scheduler_output: SchedulerOutput
) -> None:
num_invalid_spec_tokens: dict[str, int] = {}
sched_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, spec_token_ids in zip(
draft_token_ids.req_ids,
draft_token_ids.draft_token_ids,
):
request = self.requests.get(req_id)
if request is None or request.is_finished():
# The request may have been finished. Skip.
continue
placeholder_spec_tokens = sched_spec_tokens.get(req_id)
if not placeholder_spec_tokens:
continue
orig_num_spec_tokens = len(placeholder_spec_tokens)
# Trim drafts to scheduled number of spec tokens
# (needed for chunked prefill case for example).
del spec_token_ids[orig_num_spec_tokens:]
# Filter out spec tokens which do not adhere to the grammar.
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None
spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids)
# Pad to original number of spec tokens.
num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids)
if num_invalid_tokens:
spec_token_ids.extend([-1] * num_invalid_tokens)
num_invalid_spec_tokens[req_id] = num_invalid_tokens
sched_spec_tokens[req_id] = spec_token_ids
scheduler_output.num_invalid_spec_tokens = num_invalid_spec_tokens
def get_request_counts(self) -> tuple[int, int]: def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs).""" """Returns (num_running_reqs, num_waiting_reqs)."""
...@@ -1513,11 +1556,15 @@ class Scheduler(SchedulerInterface): ...@@ -1513,11 +1556,15 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats: SpecDecodingStats | None, spec_decoding_stats: SpecDecodingStats | None,
num_draft_tokens: int, num_draft_tokens: int,
num_accepted_tokens: int, num_accepted_tokens: int,
num_invalid_spec_tokens: dict[str, int] | None,
request_id: str,
) -> SpecDecodingStats | None: ) -> SpecDecodingStats | None:
if not self.log_stats: if not self.log_stats or not num_draft_tokens:
return None return None
if spec_decoding_stats is None: if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
if num_invalid_spec_tokens:
num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0)
spec_decoding_stats.observe_draft( spec_decoding_stats.observe_draft(
num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens
) )
......
...@@ -466,6 +466,18 @@ class EngineCore: ...@@ -466,6 +466,18 @@ class EngineCore:
# in a field and do it immediately once step_with_batch_queue is # in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput. # re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output: if deferred_scheduler_output:
# If we are doing speculative decoding with structured output,
# we need to get the draft token ids from the prior step before
# we can compute the grammar bitmask for the deferred request.
if self.use_spec_decode:
draft_token_ids = self.model_executor.take_draft_token_ids()
assert draft_token_ids is not None
# Update the draft token ids in the scheduler output to
# filter out the invalid spec tokens, which will be padded
# with -1 and skipped by the grammar bitmask computation.
self.scheduler.update_draft_token_ids_in_output(
draft_token_ids, deferred_scheduler_output
)
# We now have the tokens needed to compute the bitmask for the # We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens. # deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask( grammar_output = self.scheduler.get_grammar_bitmask(
......
...@@ -158,12 +158,11 @@ class InputProcessor: ...@@ -158,12 +158,11 @@ class InputProcessor:
or params.presence_penalty != 0.0 or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0 or params.repetition_penalty != 1.0
or params.bad_words_token_ids or params.bad_words_token_ids
or params.structured_outputs
) )
): ):
raise ValueError( raise ValueError(
"async scheduling with spec decoding doesn't yet support " "async scheduling with spec decoding doesn't yet support "
"penalties, bad words or structured outputs in sampling parameters." "penalties or bad words in sampling parameters."
) )
def _validate_params( def _validate_params(
......
...@@ -626,6 +626,7 @@ class GPUModelRunner( ...@@ -626,6 +626,7 @@ class GPUModelRunner(
# Cached outputs. # Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self._draft_token_req_ids: list[str] | None = None
self.transfer_event = torch.Event() self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty( self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_num_reqs, 1), (self.max_num_reqs, 1),
...@@ -638,15 +639,30 @@ class GPUModelRunner( ...@@ -638,15 +639,30 @@ class GPUModelRunner(
# with dedicated stream for overlapping and event for coordination. # with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.Event | None = None self.valid_sampled_token_count_event: torch.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens: # We also copy the drafted tokens to the CPU asynchronously,
self.valid_sampled_token_count_event = torch.Event() # in case we need them for structured outputs.
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() self.draft_token_ids_event: torch.Event | None = None
self.valid_sampled_token_count_cpu = torch.empty( self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None
self.max_num_reqs, self.valid_sampled_token_count_cpu: torch.Tensor | None = None
dtype=torch.int64, self.draft_token_ids_cpu: torch.Tensor | None = None
device="cpu", if self.num_spec_tokens:
pin_memory=self.pin_memory, self.draft_token_ids_event = torch.Event()
) self.draft_token_ids_copy_stream = torch.cuda.Stream()
self.draft_token_ids_cpu = torch.empty(
(self.max_num_reqs, self.num_spec_tokens),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
if self.use_async_scheduling:
self.valid_sampled_token_count_event = torch.Event()
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
# Ephemeral state transferred between execute_model() and sample_tokens(). # Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None self.execute_model_state: ExecuteModelState | None = None
...@@ -1036,15 +1052,8 @@ class GPUModelRunner( ...@@ -1036,15 +1052,8 @@ class GPUModelRunner(
self.input_batch.spec_token_ids[req_index].clear() self.input_batch.spec_token_ids[req_index].clear()
self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
# there are no draft tokens with async scheduling,
# we clear the spec_decoding info in scheduler_output and
# use normal sampling but rejection_sampling.
if self.use_async_scheduling: if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens req_state.prev_num_draft_len = num_spec_tokens
if num_spec_tokens and self._draft_token_ids is None:
scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
for request in reqs_to_add: for request in reqs_to_add:
...@@ -1291,7 +1300,6 @@ class GPUModelRunner( ...@@ -1291,7 +1300,6 @@ class GPUModelRunner(
# because input_ids dtype is torch.int32, # because input_ids dtype is torch.int32,
# so convert draft_token_ids to torch.int32 here. # so convert draft_token_ids to torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self._draft_token_ids = None
self.input_ids.gpu.scatter_( self.input_ids.gpu.scatter_(
dim=0, dim=0,
...@@ -3100,20 +3108,6 @@ class GPUModelRunner( ...@@ -3100,20 +3108,6 @@ class GPUModelRunner(
"after execute_model() returns None." "after execute_model() returns None."
) )
# self._draft_token_ids is None when `input_fits_in_drafter=False`
# and there is no draft tokens scheduled. so it need to update the
# spec_decoding info in scheduler_output with async_scheduling.
# use deepcopy to avoid the modification has influence on the
# scheduler_output in engine core process.
# TODO(Ronald1995): deepcopy is expensive when there is a large
# number of requests, optimize it later.
if (
self.use_async_scheduling
and self.num_spec_tokens
and self._draft_token_ids is None
):
scheduler_output = deepcopy(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with ( with (
record_function_or_nullcontext("gpu_model_runner: preprocess"), record_function_or_nullcontext("gpu_model_runner: preprocess"),
...@@ -3360,6 +3354,8 @@ class GPUModelRunner( ...@@ -3360,6 +3354,8 @@ class GPUModelRunner(
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
kv_connector_output = self.kv_connector_output kv_connector_output = self.kv_connector_output
self.kv_connector_output = None self.kv_connector_output = None
self._draft_token_ids = None
self._draft_token_req_ids = None
if self.execute_model_state is None: if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used. # Nothing to do (PP non-final rank case), output isn't used.
...@@ -3414,6 +3410,7 @@ class GPUModelRunner( ...@@ -3414,6 +3410,7 @@ class GPUModelRunner(
spec_decode_metadata, spec_decode_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
) )
self._copy_draft_token_ids_to_cpu(scheduler_output)
spec_config = self.speculative_config spec_config = self.speculative_config
use_padded_batch_for_eagle = ( use_padded_batch_for_eagle = (
...@@ -3458,6 +3455,12 @@ class GPUModelRunner( ...@@ -3458,6 +3455,12 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count( self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count next_token_ids, valid_sampled_tokens_count
) )
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
with record_function_or_nullcontext("gpu_model_runner: bookkeep"): with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
( (
...@@ -3529,19 +3532,50 @@ class GPUModelRunner( ...@@ -3529,19 +3532,50 @@ class GPUModelRunner(
return async_output return async_output
def take_draft_token_ids(self) -> DraftTokenIds | None: def take_draft_token_ids(self) -> DraftTokenIds | None:
if not self.num_spec_tokens: if not self.num_spec_tokens or not self._draft_token_req_ids:
return None return None
req_ids = self._draft_token_req_ids
draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids))
return DraftTokenIds(req_ids, draft_token_ids)
req_ids = self.input_batch.req_ids def _copy_draft_token_ids_to_cpu(
if self._draft_token_ids is None: self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
return DraftTokenIds(req_ids, [[] for _ in req_ids]) ) -> None:
struct_output = scheduler_output.has_structured_output_requests
if self.use_async_scheduling and not struct_output:
# Draft tokens don't need to be copied to the CPU if async
# scheduling is in use and there are no structured output reqs.
return
# We must also set the corresponding request ids.
self._draft_token_req_ids = self.input_batch.req_ids.copy()
if isinstance(self._draft_token_ids, torch.Tensor): draft_token_ids: torch.Tensor = self._draft_token_ids
draft_token_ids = self._draft_token_ids.tolist() if not torch.is_tensor(draft_token_ids):
else: return
draft_token_ids = self._draft_token_ids assert self.draft_token_ids_event is not None
self._draft_token_ids = None assert self.draft_token_ids_copy_stream is not None
return DraftTokenIds(req_ids, draft_token_ids) assert self.draft_token_ids_cpu is not None
default_stream = torch.cuda.current_stream()
num_reqs = draft_token_ids.shape[0]
with torch.cuda.stream(self.draft_token_ids_copy_stream):
if not zeros_only:
# Trigger async copy of draft token ids to cpu.
self.draft_token_ids_copy_stream.wait_stream(default_stream)
self.draft_token_ids_cpu[:num_reqs].copy_(
draft_token_ids, non_blocking=True
)
else:
# No copy needed, just zero-out cpu tensor.
self.draft_token_ids_cpu[:num_reqs] = 0
self.draft_token_ids_event.record()
def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]:
if isinstance(self._draft_token_ids, list):
return self._draft_token_ids
assert self.draft_token_ids_event is not None
assert self.draft_token_ids_cpu is not None
self.draft_token_ids_event.synchronize()
return self.draft_token_ids_cpu[:num_reqs].tolist()
def _copy_valid_sampled_token_count( def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
...@@ -3556,6 +3590,7 @@ class GPUModelRunner( ...@@ -3556,6 +3590,7 @@ class GPUModelRunner(
self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore
counts = valid_sampled_tokens_count counts = valid_sampled_tokens_count
counts_cpu = self.valid_sampled_token_count_cpu counts_cpu = self.valid_sampled_token_count_cpu
assert counts_cpu is not None
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record() self.valid_sampled_token_count_event.record()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment