# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum import time from collections import deque from collections.abc import Callable, Mapping from dataclasses import dataclass from functools import partial from typing import TYPE_CHECKING, Any, Optional import torch from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import ( EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason, ) from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.lora.request import LoRARequest from vllm.v1.core.kv_cache_utils import BlockHash @dataclass class StreamingUpdate: """Lightweight data for streaming session continuation. Contains only the fields needed to update an existing streaming session with new input data. """ mm_features: list[MultiModalFeatureSpec] | None prompt_token_ids: list[int] | None max_tokens: int arrival_time: float sampling_params: SamplingParams | None @classmethod def from_request(cls, request: "Request") -> "StreamingUpdate | None": if not request.resumable: return None return cls( mm_features=request.mm_features, prompt_token_ids=request.prompt_token_ids, max_tokens=request.max_tokens, arrival_time=request.arrival_time, sampling_params=request.sampling_params, ) class Request: def __init__( self, request_id: str, prompt_token_ids: list[int] | None, sampling_params: SamplingParams | None, pooling_params: PoolingParams | None, eos_token_id: int | None, client_index: int = 0, arrival_time: float | None = None, prompt_embeds: torch.Tensor | None = None, mm_features: list[MultiModalFeatureSpec] | None = None, lora_request: Optional["LoRARequest"] = None, cache_salt: str | None = None, priority: int = 0, trace_headers: Mapping[str, str] | None = None, block_hasher: Callable[["Request"], list["BlockHash"]] | None = None, resumable: bool = False, ) -> None: self.request_id = request_id self.client_index = client_index self.priority = priority self.sampling_params = sampling_params self.pooling_params = pooling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = StructuredOutputRequest.from_sampling_params( sampling_params ) self.arrival_time = arrival_time if arrival_time is not None else time.time() self.status = RequestStatus.WAITING self.events: list[EngineCoreEvent] = [] self.stop_reason: int | str | None = None # P/D: Connector-specific KV transfer parameters. self.kv_transfer_params: dict[str, Any] | None = None if pooling_params is not None: # Pooling models. self.max_tokens = 1 elif sampling_params is not None: # Generative models. assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens if self.structured_output_request is not None: self.status = RequestStatus.WAITING_FOR_FSM if sampling_params.extra_args is not None: self.kv_transfer_params = sampling_params.extra_args.get( "kv_transfer_params" ) else: raise ValueError("sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids self.prompt_embeds = prompt_embeds self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( prompt_token_ids, prompt_embeds ) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = ( self.prompt_token_ids.copy() if self.prompt_token_ids is not None else [0] * self.num_prompt_tokens ) # Used in async scheduling. self.num_output_placeholders = 0 # Used in forced preemption (reset_prefix_cache) with async scheduling. self.discard_latest_async_tokens = False self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 self.num_generated_token_ids = 0 self.cache_salt: str | None = cache_salt # Multi-modal related self.mm_features = mm_features or [] # Read-only views # Prevent directly appending to these lists since # they should also be updated simultaneously. self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) # trace_headers self.trace_headers = trace_headers # State # The number of tokens with prefix cache hits. self.num_cached_tokens = -1 # The number of NaNs in logits. A value greater than 0 # indicates that the output is corrupted self.num_nans_in_logits = 0 # The number of times this request has been preempted by the scheduler. self.num_preemptions = 0 # The number of tokens that have been computed remotely. self.num_external_computed_tokens = 0 self.block_hashes: list[BlockHash] = [] self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None if block_hasher is not None: self.get_hash_new_full_blocks = partial(block_hasher, self) self.block_hashes = self.get_hash_new_full_blocks() self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache() # Used for streaming self.resumable = resumable # None entry in the queue means finished. self.streaming_queue: deque[StreamingUpdate | None] | None = None @classmethod def from_engine_core_request( cls, request: EngineCoreRequest, block_hasher: Callable[["Request"], list["BlockHash"]] | None, ) -> "Request": return cls( request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, prompt_embeds=request.prompt_embeds, mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, arrival_time=request.arrival_time, lora_request=request.lora_request, cache_salt=request.cache_salt, priority=request.priority, trace_headers=request.trace_headers, block_hasher=block_hasher, resumable=request.resumable, ) def append_output_token_ids( self, token_ids: int | list[int], ) -> None: if isinstance(token_ids, int): self._output_token_ids.append(token_ids) self._all_token_ids.append(token_ids) else: self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) if self.get_hash_new_full_blocks is not None: self.block_hashes.extend(self.get_hash_new_full_blocks()) @property def use_structured_output(self) -> bool: return self.structured_output_request is not None @property def num_tokens(self) -> int: return len(self._all_token_ids) @property def num_tokens_with_spec(self) -> int: return len(self._all_token_ids) + len(self.spec_token_ids) @property def num_output_tokens(self) -> int: return len(self._output_token_ids) @property def num_encoder_inputs(self) -> int: return len(self.mm_features) @property def has_encoder_inputs(self) -> bool: return self.num_encoder_inputs > 0 def get_skip_reading_prefix_cache(self) -> bool: if ( self.sampling_params is not None and self.sampling_params.skip_reading_prefix_cache is not None ): return self.sampling_params.skip_reading_prefix_cache elif ( self.pooling_params is not None and self.pooling_params.skip_reading_prefix_cache is not None ): return self.pooling_params.skip_reading_prefix_cache return False def is_finished(self) -> bool: return RequestStatus.is_finished(self.status) def get_finished_reason(self) -> FinishReason | None: return RequestStatus.get_finished_reason(self.status) def get_num_encoder_embeds(self, input_id: int) -> int: assert input_id < len(self.mm_features) return self.mm_features[input_id].mm_position.get_num_embeds def record_event( self, event_type: EngineCoreEventType, timestamp: float | None = None, ) -> None: self.events.append(EngineCoreEvent.new_event(event_type, timestamp)) def take_events(self) -> list[EngineCoreEvent] | None: if not self.events: return None events, self.events = self.events, [] return events def __lt__(self, other: "Request") -> bool: """ Compare two requests based on priority, arrival time, and request ID. Used in priority scheduling. """ if self.priority != other.priority: return self.priority < other.priority if self.arrival_time != other.arrival_time: return self.arrival_time < other.arrival_time if self.request_id != other.request_id: return self.request_id < other.request_id return id(self) < id(other) class RequestStatus(enum.IntEnum): """Status of a request.""" WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() WAITING_FOR_REMOTE_KVS = enum.auto() WAITING_FOR_STREAMING_REQ = enum.auto() RUNNING = enum.auto() PREEMPTED = enum.auto() # Note: anything after PREEMPTED will be considered # as a finished status. FINISHED_STOPPED = enum.auto() FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() FINISHED_ERROR = enum.auto() def __str__(self) -> str: return self.name @staticmethod def is_finished(status: "RequestStatus") -> bool: return status > RequestStatus.PREEMPTED @staticmethod def get_finished_reason(status: "RequestStatus") -> FinishReason | None: return _FINISHED_REASON_MAP.get(status) # Mapping of finished statuses to their finish reasons. # NOTE: The ignored requests are the requests whose prompt lengths # are longer than the model's length cap. Therefore, the stop # reason should also be "length" as in OpenAI API. _FINISHED_REASON_MAP = { RequestStatus.FINISHED_STOPPED: FinishReason.STOP, RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, RequestStatus.FINISHED_ERROR: FinishReason.ERROR, RequestStatus.WAITING_FOR_STREAMING_REQ: FinishReason.STOP, }