Unverified Commit 0cdbe7b7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Core] Async scheduling + structured outputs compatibility (#26866)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent df334868
......@@ -6,6 +6,9 @@ from copy import deepcopy
from tblib import pickling_support
# Import fixture
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate
......
......@@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
......@@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
......@@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
......@@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
......@@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():
......
......@@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.sampling_params import StructuredOutputsParams
from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
......@@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
def test_preempt_and_async_scheduling_e2e(
sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters."""
uni/multiproc executor, and various sampling parameters
including structured outputs."""
first_prompt = (
"The following numbers of the sequence "
......@@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
dict(
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
logprobs=2,
presence_penalty=-1.0,
),
]
default_params = dict(
......
......@@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
self,
scheduler_output,
non_block=False,
) -> Future[ModelRunnerOutput]:
) -> Future[ModelRunnerOutput | None]:
"""Make execute_model non-blocking."""
# DummyExecutor used only for testing async case.
......@@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
def sample_tokens(
self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
"""Make sample_tokens non-blocking."""
# DummyExecutor used only for testing async case.
assert non_block
def _execute():
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
@property
def max_concurrent_batches(self) -> int:
return 2
......
......@@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
# Drop marker to show that this was run
with open(".marker", "w"):
...
return super().collective_rpc(method, timeout, args, kwargs)
return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank
)
CustomMultiprocExecutorAsync = CustomMultiprocExecutor
......
......@@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)
......
......@@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_mm_hashes=set(),
structured_output_request_ids={},
grammar_bitmask=None,
free_encoder_mm_hashes=[],
)
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
......
......@@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
......@@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......@@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......@@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......@@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......@@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......
......@@ -152,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
......@@ -269,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
......@@ -301,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......@@ -330,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
......@@ -423,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
......@@ -460,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner._update_states(scheduler_output)
......
......@@ -6,7 +6,7 @@ KV cache helper for store.
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal
import torch
......@@ -138,8 +138,11 @@ class KVOutputAggregator:
return cls(connector.get_finished_count() or world_size)
def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
) -> ModelRunnerOutput:
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None
# Aggregate kv_connector_output from all workers
def update_finished_set(
......@@ -161,6 +164,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
......@@ -204,6 +208,7 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank
output = outputs[output_rank]
assert output is not None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
......@@ -215,13 +220,16 @@ class KVOutputAggregator:
return output
def async_aggregate(
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0
) -> Future[ModelRunnerOutput]:
self,
output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
result_future: Future[ModelRunnerOutput | None] = Future()
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)
def make_callback(idx):
def callback(fut):
......@@ -236,12 +244,10 @@ class KVOutputAggregator:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self.aggregate(
cast(list[ModelRunnerOutput], outputs), output_rank
)
)
nonlocal remaining
remaining -= 1
if not remaining:
result_future.set_result(self.aggregate(outputs, output_rank))
return callback
......
......@@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
if (
request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders
......@@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output(
self,
request: Request,
......
......@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
......@@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def get_grammar_bitmask(
self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
raise NotImplementedError
@abstractmethod
def update_from_output(
self,
......
......@@ -181,12 +181,17 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# ids of structured outputs requests included in the bitmask, in the
# same order as the corresponding stacked rows of the bitmask.
# There may be more than one row per request in the case of speculative decoding.
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
@dataclass
class GrammarOutput:
# ids of structured output requests.
structured_output_request_ids: list[str]
# Bitmask ordered as structured_output_request_ids.
grammar_bitmask: "npt.NDArray[np.int32]"
......@@ -5,7 +5,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
......@@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
......@@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__)
......@@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
......@@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
......@@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask(
self,
scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
scheduler_output: SchedulerOutput,
) -> GrammarOutput | None:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
......@@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
for req_id in scheduler_output.num_scheduled_tokens
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids:
return structured_output_request_ids, None
return None
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
scheduler_output.scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
return GrammarOutput(structured_output_request_ids, bitmask)
def update_from_output(
self,
......
......@@ -12,7 +12,7 @@ from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import msgspec
import zmq
......@@ -334,9 +334,12 @@ class EngineCore:
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = self.model_executor.execute_model(scheduler_output)
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
......@@ -376,12 +379,39 @@ class EngineCore:
assert len(batch_queue) < self.batch_queue_size
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
batch_queue.appendleft((future, scheduler_output))
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
model_executed
and len(batch_queue) < self.batch_queue_size
......@@ -405,6 +435,19 @@ class EngineCore:
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed
def shutdown(self):
......
......@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
......@@ -187,28 +187,44 @@ class Executor(ABC):
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
) -> ModelRunnerOutput | None:
pass
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput | None]:
pass
def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block
)
return output[0]
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
) -> ModelRunnerOutput:
pass
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput]:
pass
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc( # type: ignore[call-overload]
"sample_tokens", args=(grammar_output,), non_block=non_block
)
return output[0]
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch")
......
......@@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
get_mp_context,
set_process_title,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
......@@ -132,12 +132,9 @@ class MultiprocExecutor(Executor):
uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
......@@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
self.failure_callback = callback
def execute_model( # type: ignore[override]
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
return self._execute_with_aggregation(
"execute_model", scheduler_output, non_block=non_block
)
def sample_tokens( # type: ignore[override]
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
return self._execute_with_aggregation( # type: ignore[return-value]
"sample_tokens", grammar_output, non_block=non_block
)
def _execute_with_aggregation(
self, method: str, *args, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector:
# get output only from a single worker (output_rank)
(output,) = self.collective_rpc(
"execute_model",
args=(scheduler_output,),
method,
args=args,
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
......@@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output,),
method,
args=args,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)
......
......@@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
get_ip,
get_open_port,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
......@@ -41,6 +41,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)
@dataclass
class RayWorkerMetaData:
......@@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.scheduler_output: SchedulerOutput | None = None
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
......@@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
self.shutdown()
def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None
def sample_tokens( # type: ignore[override]
self,
grammar_output: "GrammarOutput | None",
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers.
The scheduler output to use should have been provided in
a prior call to execute_model().
Args:
scheduler_output: The scheduler output to execute.
grammar_output: The structured outputs grammar bitmask, if applicable.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
scheduler_output = self.scheduler_output
if scheduler_output is None:
return None # noqa
self.scheduler_output = None
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if not self.has_connector:
# Get output only from a single worker (output_rank)
......
......@@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
......@@ -82,31 +82,36 @@ try:
def execute_model_ray(
self,
scheduler_output: Union[
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
],
execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
) -> Union[
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
"ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]:
# This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary()
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
if len(execute_model_input) == 3:
scheduler_output, grammar_output, intermediate_tensors = (
execute_model_input
)
else:
scheduler_output, intermediate_tensors = scheduler_output, None
scheduler_output, grammar_output = execute_model_input
intermediate_tensors = None
assert self.worker.model_runner is not None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
output = scheduler_output, grammar_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
output = scheduler_output, grammar_output, None
elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output)
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled.
......
......@@ -16,6 +16,7 @@ from diskcache import Cache
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING:
import outlines_core as oc
......@@ -24,7 +25,6 @@ if TYPE_CHECKING:
import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
......@@ -47,6 +47,7 @@ CACHE = None
def apply_grammar_bitmask(
scheduler_output: SchedulerOutput,
grammar_output: GrammarOutput,
input_batch: InputBatch,
logits: torch.Tensor,
) -> None:
......@@ -58,9 +59,9 @@ def apply_grammar_bitmask(
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
"""
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = grammar_output.grammar_bitmask
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
......@@ -79,7 +80,7 @@ def apply_grammar_bitmask(
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
if req_id in scheduler_output.structured_output_request_ids:
if req_id in grammar_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index
out_indices = []
......@@ -91,7 +92,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype,
)
cumulative_index = 0
for req_id in scheduler_output.structured_output_request_ids:
for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
......@@ -101,22 +102,28 @@ def apply_grammar_bitmask(
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
logits.device, non_blocking=True
)
# If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
index_tensor = None
if not skip_out_indices:
# xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
class OutlinesVocabulary:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment