from collections.abc import Callable from typing import TYPE_CHECKING import numpy as np import torch from vllm.v1.request import Request if TYPE_CHECKING: from vllm.v1.core.kv_cache_utils import BlockHash from vllm_omni.engine import AdditionalInformationPayload, OmniEngineCoreRequest, PromptEmbedsPayload class OmniRequest(Request): """Request class for omni models, extending the base Request. This class extends the base vLLM Request with support for prompt embeddings and additional information payloads, enabling direct transfer of pre-computed embeddings between stages. Args: prompt_embeds: Optional serialized prompt embeddings payload. Used for direct transfer of embeddings between stages. additional_information: Optional additional information payload containing tensors or lists to be passed along with the request. """ def __init__( self, prompt_embeds: PromptEmbedsPayload | torch.Tensor | None = None, # Optional external request ID for tracking external_req_id: str | None = None, additional_information: AdditionalInformationPayload | None = None, *args, **kwargs, ): prompt_embeds_tensor = self._maybe_decode_prompt_embeds(prompt_embeds) super().__init__(prompt_embeds=prompt_embeds_tensor, *args, **kwargs) # Preserve serialized prompt embeddings payload (optional) self.prompt_embeds_payload: PromptEmbedsPayload | None = ( prompt_embeds if isinstance(prompt_embeds, PromptEmbedsPayload) else None ) # Optional external request ID for tracking self.external_req_id: str | None = external_req_id # Serialized additional information payload (optional) self.additional_information: AdditionalInformationPayload | None = additional_information @staticmethod def _maybe_decode_prompt_embeds( prompt_embeds: PromptEmbedsPayload | torch.Tensor | None, ) -> torch.Tensor | None: if isinstance(prompt_embeds, PromptEmbedsPayload): dtype = getattr(np, prompt_embeds.dtype) arr = np.frombuffer(prompt_embeds.data, dtype=dtype) arr = arr.reshape(prompt_embeds.shape) return torch.from_numpy(arr) return prompt_embeds @classmethod def from_engine_core_request( cls, request: OmniEngineCoreRequest, block_hasher: Callable[["Request"], list["BlockHash"]] | None, ) -> "Request": """Create an OmniRequest from an OmniEngineCoreRequest. Args: request: The OmniEngineCoreRequest to convert block_hasher: Optional function to compute block hashes for prefix caching Returns: OmniRequest instance created from the engine core request """ return cls( request_id=request.request_id, # Optional external request ID for tracking external_req_id=request.external_req_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, prompt_embeds=request.prompt_embeds, mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, arrival_time=request.arrival_time, lora_request=request.lora_request, cache_salt=request.cache_salt, priority=request.priority, trace_headers=request.trace_headers, block_hasher=block_hasher, additional_information=request.additional_information, resumable=request.resumable, )