Unverified Commit 0cdbe7b7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Core] Async scheduling + structured outputs compatibility (#26866)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent df334868
......@@ -109,6 +109,7 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
KVConnectorOutput,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
......@@ -150,7 +151,7 @@ from .utils import (
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
......@@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return output
class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
......@@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
)
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
......@@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_input_tokens: int, # Padded
intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[
int,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor,
......@@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs.update(encoder_inputs)
return (
num_scheduled_tokens,
input_ids,
inputs_embeds,
positions,
......@@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
) -> ModelRunnerOutput | IntermediateTensors | None:
if self.execute_model_state is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
......@@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
(
num_scheduled_tokens,
input_ids,
inputs_embeds,
positions,
......@@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Rare case.
assert not self.is_pooling_model
sample_hidden_states = hidden_states[logits_indices]
if not get_pp_group().is_last_rank:
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
......@@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
model_output_broadcast_data = {}
......@@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.structured_output_request_ids:
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
self.execute_model_state = ExecuteModelState(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
)
return None
@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
# Unpack ephemeral state.
(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
# Apply structured output bitmasks if present.
if grammar_output is not None:
apply_grammar_bitmask(
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
......@@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output,
logits,
hidden_states,
num_scheduled_tokens,
scheduler_output.total_num_scheduled_tokens,
spec_decode_metadata,
)
......
......@@ -6,6 +6,7 @@ import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any
import torch
......@@ -37,6 +38,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
......@@ -508,11 +510,16 @@ class Worker(WorkerBase):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
@torch.inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
......@@ -531,13 +538,13 @@ class Worker(WorkerBase):
)
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
if isinstance(output, (ModelRunnerOutput, NoneType)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert (
parallel_config.distributed_executor_backend != ("external_launcher")
parallel_config.distributed_executor_backend != "external_launcher"
and not get_pp_group().is_last_rank
)
......
......@@ -92,7 +92,7 @@ from .utils import (
)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
......@@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.sample_from_logits_func = self.sample_from_logits
# For passing scheduler_output between successive
# execute_model() and sample_tokens() calls.
self.scheduler_output: SchedulerOutput | None = None
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
......@@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput:
) -> ModelRunnerOutput | None:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
......@@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
mm_embed_inputs = None
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embed_inputs = None
torch_xla.sync(wait=False)
self.scheduler_output = scheduler_output
self.mm_embed_inputs = mm_embed_inputs
return None
@torch.no_grad()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
self.mm_embed_inputs = None
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
start_index = 0
......@@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch, padded_num_reqs, self.device
)
if scheduler_output.grammar_bitmask is not None:
if grammar_output is not None:
require_struct_decoding, grammar_bitmask_padded, arange = (
self.prepare_structured_decoding_input(logits, scheduler_output)
self.prepare_structured_decoding_input(logits, grammar_output)
)
logits = self.structured_decode(
require_struct_decoding, grammar_bitmask_padded, logits, arange
......@@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.model.get_input_embeddings(*args, **kwargs)
def prepare_structured_decoding_input(
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
self, logits: torch.Tensor, grammar_output: "GrammarOutput"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grammar_bitmask = scheduler_output.grammar_bitmask
assert grammar_bitmask is not None
grammar_bitmask = grammar_output.grammar_bitmask
num_reqs, _ = logits.shape
# Reset pre-allocated tensors
......@@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.require_structured_out_cpu.zero_()
cumulative_mask_idx = 0
for req_id in scheduler_output.structured_output_request_ids:
for req_id in grammar_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index:
continue
batch_index = self.input_batch.req_id_to_index[req_id]
......
......@@ -17,7 +17,6 @@ from vllm.distributed import (
)
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
has_kv_transfer_group,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
......@@ -255,13 +254,13 @@ class TPUWorker:
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model(
self,
scheduler_output: "SchedulerOutput",
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
output = self.model_runner.execute_model(scheduler_output)
# every worker's output is needed when kv_transfer_group is set up
return output if self.is_driver_worker or has_kv_transfer_group() else None
return self.model_runner.execute_model(scheduler_output)
def profile(self, is_start: bool = True):
if self.rank < 1:
......
......@@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
SchedulerOutput = object
GrammarOutput = object
AsyncModelRunnerOutput = object
ModelRunnerOutput = object
logger = init_logger(__name__)
......@@ -122,7 +124,21 @@ class WorkerBase:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | None:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise NotImplementedError
def sample_tokens(
self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
"""Should be called immediately after execute_model iff it returned None."""
raise NotImplementedError
def get_cache_block_size_bytes(self) -> int:
......@@ -344,7 +360,7 @@ class WorkerWrapperBase:
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput:
) -> ModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment