Unverified commit 0cdbe7b7, authored by Nick Hill, committed by GitHub
Browse files

[Core] Async scheduling + structured outputs compatibility (#26866)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent df334868
...@@ -6,6 +6,9 @@ from copy import deepcopy ...@@ -6,6 +6,9 @@ from copy import deepcopy
from tblib import pickling_support from tblib import pickling_support
# Import fixture
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
# ruff: noqa # ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate # Install support for pickling exceptions so that we can nicely propagate
......
...@@ -337,8 +337,6 @@ def test_stop_via_update_from_output(): ...@@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
...@@ -385,8 +383,6 @@ def test_stop_via_update_from_output(): ...@@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
...@@ -431,8 +427,6 @@ def test_stop_via_update_from_output(): ...@@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
...@@ -472,8 +466,6 @@ def test_stop_via_update_from_output(): ...@@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
...@@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init(): ...@@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler.add_request(request) scheduler.add_request(request)
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request(): def test_schedule_skip_tokenizer_init_structured_output_request():
......
...@@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config ...@@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams from vllm import SamplingParams
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.sampling_params import StructuredOutputsParams
from ...conftest import VllmRunner from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal from ...models.utils import check_outputs_equal
...@@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B" ...@@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): def test_preempt_and_async_scheduling_e2e(
sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters.""" uni/multiproc executor, and various sampling parameters
including structured outputs."""
first_prompt = ( first_prompt = (
"The following numbers of the sequence " "The following numbers of the sequence "
...@@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): ...@@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict(bad_words=["the", " the"]), dict(bad_words=["the", " the"]),
dict(logprobs=2), dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0), dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
dict(
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
logprobs=2,
presence_penalty=-1.0,
),
] ]
default_params = dict( default_params = dict(
......
...@@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches(): ...@@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
self, self,
scheduler_output, scheduler_output,
non_block=False, non_block=False,
) -> Future[ModelRunnerOutput]: ) -> Future[ModelRunnerOutput | None]:
"""Make execute_model non-blocking.""" """Make execute_model non-blocking."""
# DummyExecutor used only for testing async case. # DummyExecutor used only for testing async case.
...@@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches(): ...@@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
# Use the thread pool instead of creating a new thread # Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute) return self.thread_pool.submit(_execute)
def sample_tokens(
self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
"""Make sample_tokens non-blocking."""
# DummyExecutor used only for testing async case.
assert non_block
def _execute():
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
@property @property
def max_concurrent_batches(self) -> int: def max_concurrent_batches(self) -> int:
return 2 return 2
......
...@@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor): ...@@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
# Drop marker to show that this was run # Drop marker to show that this was run
with open(".marker", "w"): with open(".marker", "w"):
... ...
return super().collective_rpc(method, timeout, args, kwargs) return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank
)
CustomMultiprocExecutorAsync = CustomMultiprocExecutor CustomMultiprocExecutorAsync = CustomMultiprocExecutor
......
...@@ -26,8 +26,6 @@ def _make_empty_scheduler_output(): ...@@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(), kv_connector_metadata=SharedStorageConnectorMetadata(),
) )
......
...@@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation(): ...@@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
num_common_prefix_blocks=[0], num_common_prefix_blocks=[0],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=set(), free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output) engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
......
...@@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
...@@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner): ...@@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
...@@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner): ...@@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
...@@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner): ...@@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
...@@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner): ...@@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
...@@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
......
...@@ -152,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -152,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
...@@ -269,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init): ...@@ -269,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
...@@ -301,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init): ...@@ -301,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
...@@ -330,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init): ...@@ -330,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
...@@ -423,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init): ...@@ -423,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
...@@ -460,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init): ...@@ -460,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks=[], num_common_prefix_blocks=[],
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_mm_hashes=[], free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
) )
metadata_before = model_runner._update_states(scheduler_output) metadata_before = model_runner._update_states(scheduler_output)
......
...@@ -6,7 +6,7 @@ KV cache helper for store. ...@@ -6,7 +6,7 @@ KV cache helper for store.
from collections.abc import Sequence from collections.abc import Sequence
from concurrent.futures import CancelledError, Future from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal, cast from typing import TYPE_CHECKING, Literal
import torch import torch
...@@ -138,8 +138,11 @@ class KVOutputAggregator: ...@@ -138,8 +138,11 @@ class KVOutputAggregator:
return cls(connector.get_finished_count() or world_size) return cls(connector.get_finished_count() or world_size)
def aggregate( def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0 self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput: ) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None
# Aggregate kv_connector_output from all workers # Aggregate kv_connector_output from all workers
def update_finished_set( def update_finished_set(
...@@ -161,6 +164,7 @@ class KVOutputAggregator: ...@@ -161,6 +164,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = None aggregated_kv_connector_stats = None
invalid_block_ids = set[int]() invalid_block_ids = set[int]()
for model_runner_output in outputs: for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output kv_output = model_runner_output.kv_connector_output
if not kv_output: if not kv_output:
continue continue
...@@ -204,6 +208,7 @@ class KVOutputAggregator: ...@@ -204,6 +208,7 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank # select output of the worker specified by output_rank
output = outputs[output_rank] output = outputs[output_rank]
assert output is not None
output.kv_connector_output = KVConnectorOutput( output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None, finished_sending=finished_sending or None,
finished_recving=finished_recving or None, finished_recving=finished_recving or None,
...@@ -215,13 +220,16 @@ class KVOutputAggregator: ...@@ -215,13 +220,16 @@ class KVOutputAggregator:
return output return output
def async_aggregate( def async_aggregate(
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0 self,
) -> Future[ModelRunnerOutput]: output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves """Takes a list of futures and returns a single future which resolves
to the respective list of outputs.""" to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future() result_future: Future[ModelRunnerOutput | None] = Future()
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures) outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)
def make_callback(idx): def make_callback(idx):
def callback(fut): def callback(fut):
...@@ -236,12 +244,10 @@ class KVOutputAggregator: ...@@ -236,12 +244,10 @@ class KVOutputAggregator:
result_future.set_exception(e) result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread # this check assumes io_thread_pool uses a single thread
if all(outputs): nonlocal remaining
result_future.set_result( remaining -= 1
self.aggregate( if not remaining:
cast(list[ModelRunnerOutput], outputs), output_rank result_future.set_result(self.aggregate(outputs, output_rank))
)
)
return callback return callback
......
...@@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler): ...@@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
) -> None: ) -> None:
super()._update_after_schedule(scheduler_output) super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
for req_id in scheduler_output.num_scheduled_tokens: for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id] request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
if ( if (
request.num_computed_tokens request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders == request.num_tokens + request.num_output_placeholders
...@@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler): ...@@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
# TODO(woosuk): Support speculative decoding. # TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1 request.num_output_placeholders += 1
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output( def _update_request_with_output(
self, self,
request: Request, request: Request,
......
...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
...@@ -40,6 +40,12 @@ class SchedulerInterface(ABC): ...@@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_grammar_bitmask(
self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
raise NotImplementedError
@abstractmethod @abstractmethod
def update_from_output( def update_from_output(
self, self,
......
...@@ -181,12 +181,17 @@ class SchedulerOutput: ...@@ -181,12 +181,17 @@ class SchedulerOutput:
# freed from the encoder cache. # freed from the encoder cache.
free_encoder_mm_hashes: list[str] free_encoder_mm_hashes: list[str]
# ids of structured outputs requests included in the bitmask, in the # Whether any scheduled structured-output request is still waiting on
# same order as the corresponding stacked rows of the bitmask. # output tokens from a prior step that are needed to compute its grammar bitmask.
# There may be more than one row per request in the case of speculative decoding. pending_structured_output_tokens: bool = False
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"
# KV Cache Connector metadata. # KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None kv_connector_metadata: KVConnectorMetadata | None = None
@dataclass
class GrammarOutput:
# ids of structured output requests.
structured_output_request_ids: list[str]
# Bitmask ordered as structured_output_request_ids.
grammar_bitmask: "npt.NDArray[np.int32]"
...@@ -5,7 +5,7 @@ import itertools ...@@ -5,7 +5,7 @@ import itertools
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any from typing import Any
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
...@@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import ( ...@@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
) )
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
...@@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus ...@@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface): ...@@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens, scheduled_spec_decode_tokens,
req_to_new_blocks, req_to_new_blocks,
) )
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
# Record the request ids that were scheduled in this step. # Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.clear()
...@@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface): ...@@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
) )
# NOTE(Kuntai): this function is designed for multiple purposes: # NOTE(Kuntai): this function is designed for multiple purposes:
...@@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface): ...@@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask( def get_grammar_bitmask(
self, self,
scheduled_request_ids: Iterable[str], scheduler_output: SchedulerOutput,
scheduled_spec_decode_tokens: dict[str, list[int]], ) -> GrammarOutput | None:
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
# Collect list of scheduled request ids that use structured output. # Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order. # The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill, # PERF: in case of chunked prefill,
...@@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface): ...@@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [ structured_output_request_ids = [
req_id req_id
for req_id in scheduled_request_ids for req_id in scheduler_output.num_scheduled_tokens
if (req := self.requests.get(req_id)) and req.use_structured_output if (req := self.requests.get(req_id)) and req.use_structured_output
] ]
if not structured_output_request_ids: if not structured_output_request_ids:
return structured_output_request_ids, None return None
bitmask = self.structured_output_manager.grammar_bitmask( bitmask = self.structured_output_manager.grammar_bitmask(
self.requests, self.requests,
structured_output_request_ids, structured_output_request_ids,
scheduled_spec_decode_tokens, scheduler_output.scheduled_spec_decode_tokens,
) )
return structured_output_request_ids, bitmask return GrammarOutput(structured_output_request_ids, bitmask)
def update_from_output( def update_from_output(
self, self,
......
...@@ -12,7 +12,7 @@ from concurrent.futures import Future ...@@ -12,7 +12,7 @@ from concurrent.futures import Future
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
from typing import Any, TypeVar from typing import Any, TypeVar, cast
import msgspec import msgspec
import zmq import zmq
...@@ -334,9 +334,12 @@ class EngineCore: ...@@ -334,9 +334,12 @@ class EngineCore:
if not self.scheduler.has_requests(): if not self.scheduler.has_requests():
return {}, False return {}, False
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output): with self.log_error_detail(scheduler_output):
model_output = self.model_executor.execute_model(scheduler_output) model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output scheduler_output, model_output
...@@ -376,20 +379,47 @@ class EngineCore: ...@@ -376,20 +379,47 @@ class EngineCore:
assert len(batch_queue) < self.batch_queue_size assert len(batch_queue) < self.batch_queue_size
model_executed = False model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests(): if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True) exec_future = self.model_executor.execute_model(
batch_queue.appendleft((future, scheduler_output)) scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0 model_executed = scheduler_output.total_num_scheduled_tokens > 0
if (
model_executed if scheduler_output.pending_structured_output_tokens:
and len(batch_queue) < self.batch_queue_size # We need to defer sampling until we have processed the model output
and not batch_queue[-1][0].done() # from the prior step.
): deferred_scheduler_output = scheduler_output
# Don't block on next worker response unless the queue is full # Block-wait for execute to return (continues running async on the GPU).
# or there are no more requests to schedule. with self.log_error_detail(scheduler_output):
return None, True exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
elif not batch_queue: elif not batch_queue:
# Queue is empty. We should not reach here since this method should # Queue is empty. We should not reach here since this method should
...@@ -405,6 +435,19 @@ class EngineCore: ...@@ -405,6 +435,19 @@ class EngineCore:
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output scheduler_output, model_output
) )
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed return engine_core_outputs, model_executed
def shutdown(self): def shutdown(self):
......
...@@ -16,7 +16,7 @@ from vllm.logger import init_logger ...@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
...@@ -187,28 +187,44 @@ class Executor(ABC): ...@@ -187,28 +187,44 @@ class Executor(ABC):
@overload @overload
def execute_model( def execute_model(
self, self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput | None:
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
pass pass
@overload @overload
def execute_model( def execute_model(
self, self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
scheduler_output: SchedulerOutput, ) -> Future[ModelRunnerOutput | None]:
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
pass pass
def execute_model( def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
output = self.collective_rpc( # type: ignore[call-overload] output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block "execute_model", args=(scheduler_output,), non_block=non_block
) )
return output[0] return output[0]
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
) -> ModelRunnerOutput:
pass
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput]:
pass
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc( # type: ignore[call-overload]
"sample_tokens", args=(grammar_output,), non_block=non_block
)
return output[0]
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch") self.collective_rpc("execute_dummy_batch")
......
...@@ -46,7 +46,7 @@ from vllm.utils.system_utils import ( ...@@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
get_mp_context, get_mp_context,
set_process_title, set_process_title,
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
...@@ -132,15 +132,12 @@ class MultiprocExecutor(Executor): ...@@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
uw.death_writer.close() uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers]) self._ensure_worker_termination([uw.proc for uw in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous # Note: must use only 1 IO thread to keep dequeue sequence
# execute_model. # from the response queue.
if self.max_concurrent_batches > 1: # _async_aggregate_workers_output also assumes a single IO thread.
# Note: must use only 1 IO thread to keep dequeue sequence self.io_thread_pool = ThreadPoolExecutor(
# from the response queue max_workers=1, thread_name_prefix="mp_exec_io"
# _async_aggregate_workers_output also assumes a single IO thread )
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
self.output_rank = self._get_output_rank() self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None self.has_connector = self.vllm_config.kv_transfer_config is not None
...@@ -180,15 +177,27 @@ class MultiprocExecutor(Executor): ...@@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
self.failure_callback = callback self.failure_callback = callback
def execute_model( # type: ignore[override] def execute_model( # type: ignore[override]
self, self, scheduler_output: SchedulerOutput, non_block: bool = False
scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
non_block: bool = False, return self._execute_with_aggregation(
"execute_model", scheduler_output, non_block=non_block
)
def sample_tokens( # type: ignore[override]
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
return self._execute_with_aggregation( # type: ignore[return-value]
"sample_tokens", grammar_output, non_block=non_block
)
def _execute_with_aggregation(
self, method: str, *args, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector: if not self.has_connector:
# get output only from a single worker (output_rank) # get output only from a single worker (output_rank)
(output,) = self.collective_rpc( (output,) = self.collective_rpc(
"execute_model", method,
args=(scheduler_output,), args=args,
unique_reply_rank=self.output_rank, unique_reply_rank=self.output_rank,
non_block=non_block, non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
...@@ -197,8 +206,8 @@ class MultiprocExecutor(Executor): ...@@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
# get output from all workers # get output from all workers
outputs = self.collective_rpc( outputs = self.collective_rpc(
"execute_model", method,
args=(scheduler_output,), args=args,
non_block=non_block, non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
) )
......
...@@ -19,7 +19,7 @@ from vllm.utils.network_utils import ( ...@@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
get_ip, get_ip,
get_open_port, get_open_port,
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import ( from vllm.v1.executor.ray_utils import (
...@@ -41,6 +41,9 @@ if TYPE_CHECKING: ...@@ -41,6 +41,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)
@dataclass @dataclass
class RayWorkerMetaData: class RayWorkerMetaData:
...@@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor): ...@@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
# KV connector setup # KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None self.has_connector = self.vllm_config.kv_transfer_config is not None
self.scheduler_output: SchedulerOutput | None = None
@property @property
def max_concurrent_batches(self) -> int: def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism, """Ray distributed executor supports pipeline parallelism,
...@@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor): ...@@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
self.shutdown() self.shutdown()
def execute_model( # type: ignore[override] def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None
def sample_tokens( # type: ignore[override]
self,
grammar_output: "GrammarOutput | None",
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]: ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers. """Execute the model on the Ray workers.
The scheduler output to use should have been provided in
a prior call to execute_model().
Args: Args:
scheduler_output: The scheduler output to execute. grammar_output: The structured outputs grammar bitmask, if applicable.
non_block: If True, the method will return a Future. non_block: If True, the method will return a Future.
Returns: Returns:
The model runner output. The model runner output.
""" """
scheduler_output = self.scheduler_output
if scheduler_output is None:
return None # noqa
self.scheduler_output = None
# Build the compiled DAG for the first time. # Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if not self.has_connector: if not self.has_connector:
# Get output only from a single worker (output_rank) # Get output only from a single worker (output_rank)
......
...@@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput ...@@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -82,36 +82,41 @@ try: ...@@ -82,36 +82,41 @@ try:
def execute_model_ray( def execute_model_ray(
self, self,
scheduler_output: Union[ execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"] | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
],
) -> Union[ ) -> Union[
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"] "ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]: ]:
# This method is used by Ray Compiled Graph to execute the model, # This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary() # and it needs a special logic of self.setup_device_if_necessary()
self.setup_device_if_necessary() self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized" assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple): if len(execute_model_input) == 3:
scheduler_output, intermediate_tensors = scheduler_output scheduler_output, grammar_output, intermediate_tensors = (
execute_model_input
)
else: else:
scheduler_output, intermediate_tensors = scheduler_output, None scheduler_output, grammar_output = execute_model_input
intermediate_tensors = None
assert self.worker.model_runner is not None assert self.worker.model_runner is not None
output = self.worker.model_runner.execute_model( output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors scheduler_output, intermediate_tensors
) )
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
output = scheduler_output, output output = scheduler_output, grammar_output, output
elif not get_pp_group().is_last_rank: elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests # Case where there are no scheduled requests
# but may still be finished requests. # but may still be finished requests.
assert not output or not output.req_ids assert not output or not output.req_ids
output = scheduler_output, None output = scheduler_output, grammar_output, None
# Ensure outputs crossing Ray compiled DAG are serializable. elif output is None:
# AsyncModelRunnerOutput holds CUDA events and cannot be output = self.worker.model_runner.sample_tokens(grammar_output)
# pickled. # Ensure outputs crossing Ray compiled DAG are serializable.
if isinstance(output, AsyncModelRunnerOutput): # AsyncModelRunnerOutput holds CUDA events and cannot be
output = output.get_output() # pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
return output return output
def override_env_vars(self, vars: dict[str, str]): def override_env_vars(self, vars: dict[str, str]):
......
...@@ -16,6 +16,7 @@ from diskcache import Cache ...@@ -16,6 +16,7 @@ from diskcache import Cache
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
import outlines_core as oc import outlines_core as oc
...@@ -24,7 +25,6 @@ if TYPE_CHECKING: ...@@ -24,7 +25,6 @@ if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
else: else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
...@@ -47,6 +47,7 @@ CACHE = None ...@@ -47,6 +47,7 @@ CACHE = None
def apply_grammar_bitmask( def apply_grammar_bitmask(
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
grammar_output: GrammarOutput,
input_batch: InputBatch, input_batch: InputBatch,
logits: torch.Tensor, logits: torch.Tensor,
) -> None: ) -> None:
...@@ -58,9 +59,9 @@ def apply_grammar_bitmask( ...@@ -58,9 +59,9 @@ def apply_grammar_bitmask(
input_batch (InputBatch): The input of model runner. input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward. logits (torch.Tensor): The output logits of model forward.
""" """
grammar_bitmask = scheduler_output.grammar_bitmask # Serialization of np.ndarray is much more efficient than a tensor,
if grammar_bitmask is None: # so we receive it in that format.
return grammar_bitmask = grammar_output.grammar_bitmask
# We receive the structured output bitmask from the scheduler, # We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests. # compacted to contain bitmasks only for structured output requests.
...@@ -79,7 +80,7 @@ def apply_grammar_bitmask( ...@@ -79,7 +80,7 @@ def apply_grammar_bitmask(
cumulative_offset += len( cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
) )
if req_id in scheduler_output.structured_output_request_ids: if req_id in grammar_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index struct_out_req_batch_indices[req_id] = logit_index
out_indices = [] out_indices = []
...@@ -91,7 +92,7 @@ def apply_grammar_bitmask( ...@@ -91,7 +92,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype, dtype=grammar_bitmask.dtype,
) )
cumulative_index = 0 cumulative_index = 0
for req_id in scheduler_output.structured_output_request_ids: for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len( num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
) )
...@@ -101,22 +102,28 @@ def apply_grammar_bitmask( ...@@ -101,22 +102,28 @@ def apply_grammar_bitmask(
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i) out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
logits.device, non_blocking=True
)
# If the length of out indices and the logits have the same shape # If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel, # we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits. # since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0] skip_out_indices = len(out_indices) == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor, index_tensor = None
# so we receive it in that format. if not skip_out_indices:
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() # xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)
xgr.apply_token_bitmask_inplace( xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
logits,
grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)
class OutlinesVocabulary: class OutlinesVocabulary:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment