# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from http import HTTPStatus
from typing import Callable, Final, Optional, Union

import jinja2
from fastapi import Request
from openai.types.responses import ResponseOutputMessage, ResponseOutputText

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              PromptTokenUsageInfo,
                                              RequestResponseMetadata,
                                              ResponseReasoningItem,
                                              ResponsesRequest,
                                              ResponsesResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)


class OpenAIServingResponses(OpenAIServing):
    """Serving handler for the Responses API endpoints.

    Supports creating a response (synchronously or as a background task),
    retrieving a stored response, and cancelling a background response.
    Responses and their input messages are kept in in-process dicts on
    this object (see the HACK/FIXME notes in ``__init__``).
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        return_tokens_as_token_ids: bool = False,
        reasoning_parser: str = "",
        enable_auto_tools: bool = False,
        tool_parser: Optional[str] = None,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
    ) -> None:
        """Set up the handler and its in-memory stores.

        Args:
            engine_client: Engine used for tokenizer lookup and generation.
            model_config: Model configuration; also the source of the
                default sampling parameters.
            models: Served-model registry, forwarded to the base class.
            request_logger: Optional request logger, forwarded to the
                base class.
            chat_template: Chat template passed to chat preprocessing.
            chat_template_content_format: Content format passed to chat
                preprocessing.
            return_tokens_as_token_ids: Forwarded to the base class.
            reasoning_parser: Name of a registered reasoning parser; an
                empty string disables reasoning extraction.
            enable_auto_tools: Accepted but not used by this class.
            tool_parser: Accepted but not used by this class.
            enable_prompt_tokens_details: When true, cached-token counts
                are attached to the usage info of a response.
            enable_force_include_usage: Forwarded to the base class and
                also kept on this instance.

        Raises:
            TypeError: If ``reasoning_parser`` is non-empty and looking it
                up fails.
        """
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            enable_force_include_usage=enable_force_include_usage,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

        # Factory that builds a ReasoningParser for a given tokenizer, or
        # None when reasoning extraction is disabled.
        self.reasoning_parser: Optional[Callable[[AnyTokenizer],
                                                 ReasoningParser]] = None
        if reasoning_parser:
            try:
                self.reasoning_parser = (
                    ReasoningParserManager.get_reasoning_parser(
                        reasoning_parser))
                assert self.reasoning_parser is not None
            except Exception as e:
                raise TypeError(
                    f"{reasoning_parser=} has not been registered") from e

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_force_include_usage = enable_force_include_usage
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
            logger.info("Using default chat sampling params from %s: %s",
                        source, self.default_sampling_params)

        # HACK(woosuk): This is a hack. We should use a better store.
        # FIXME: This causes a memory leak since we never remove responses
        # from the store.
        self.response_store: dict[str, ResponsesResponse] = {}
        self.response_store_lock = asyncio.Lock()

        # HACK(woosuk): This is a hack. We should use a better store.
        # FIXME: This causes a memory leak since we never remove messages
        # from the store.
        self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}

        # Running background tasks keyed by response ID; each entry is
        # removed by a done-callback registered in create_responses.
        self.background_tasks: dict[str, asyncio.Task] = {}

    async def create_responses(
        self,
        request: ResponsesRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]:
        """Handle a "create response" request.

        Builds the chat messages (optionally continuing from a previously
        stored response), preprocesses them into an engine prompt, and
        either runs generation to completion or, for background requests,
        stores a "queued" response and runs generation in an asyncio task.

        Args:
            request: The parsed Responses API request.
            raw_request: The underlying HTTP request, if any. Used to
                attach request metadata and to read trace headers.

        Returns:
            The finished response, the "queued" response for background
            requests, or an ``ErrorResponse`` describing what went wrong.

        Raises:
            NotImplementedError: If ``request.stream`` is set on a
                non-background request.
            Exception: The engine's ``dead_error`` when the engine client
                reports that it has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Handle the previous response ID.
        prev_response_id = request.previous_response_id
        if prev_response_id is not None:
            if not prev_response_id.startswith("resp_"):
                return self._make_invalid_id_error(prev_response_id)
            async with self.response_store_lock:
                prev_response = self.response_store.get(prev_response_id)
            if prev_response is None:
                return self._make_not_found_error(prev_response_id)
        else:
            prev_response = None
        # Construct the input messages.
        messages = self._construct_input_messages(request, prev_response)

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)
            model_name = self._get_model_name(request.model, lora_request)
            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            _, request_prompts, engine_prompts = await self._preprocess_chat(
                request,
                tokenizer,
                messages,
                chat_template=self.chat_template,
                chat_template_content_format=self.chat_template_content_format,
            )
        except (ValueError, TypeError, RuntimeError,
                jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            # Only append the cause when there is one; formatting a missing
            # cause would make the client-facing message end in "None".
            cause = e.__cause__
            message = str(e) if cause is None else f"{e} {cause}"
            return self.create_error_response(message)

        request_metadata = RequestResponseMetadata(
            request_id=request.request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                # By default, allow generation up to the remaining context.
                default_max_tokens = self.max_model_len - len(
                    engine_prompt["prompt_token_ids"])
                sampling_params = request.to_sampling_params(
                    default_max_tokens, self.default_sampling_params)

                self._log_inputs(request.request_id,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.generate(
                    engine_prompt,
                    sampling_params,
                    request.request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    priority=request.priority,
                )
                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert len(generators) == 1
        result_generator, = generators

        # Reject unsupported streaming before anything is stored, so that a
        # request we cannot serve does not leave an unreachable entry in
        # msg_store. Background requests take precedence over `stream`
        # (they return from the branch below), so they are excluded here.
        if request.stream and not request.background:
            raise NotImplementedError("Streaming responses are not supported")

        # Store the input messages.
        if request.store:
            self.msg_store[request.request_id] = messages

        if request.background:
            created_time = int(time.time())
            response = ResponsesResponse.from_request(
                request,
                sampling_params,
                model_name=model_name,
                created_time=created_time,
                output=[],
                status="queued",
                usage=None,
            )
            async with self.response_store_lock:
                self.response_store[response.id] = response

            # Run the request in the background.
            task = asyncio.create_task(
                self._run_background_request(
                    request,
                    sampling_params,
                    result_generator,
                    model_name,
                    tokenizer,
                    request_metadata,
                    created_time,
                ),
                name=f"create_{response.id}",
            )

            # For cleanup.
            response_id = response.id
            self.background_tasks[response_id] = task
            task.add_done_callback(
                lambda _: self.background_tasks.pop(response_id, None))
            return response

        try:
            return await self.responses_full_generator(
                request,
                sampling_params,
                result_generator,
                model_name,
                tokenizer,
                request_metadata,
            )
        except Exception as e:
            return self.create_error_response(str(e))

    async def responses_full_generator(
        self,
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        result_generator: AsyncIterator[RequestOutput],
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
        created_time: Optional[int] = None,
    ) -> Union[ErrorResponse, ResponsesResponse]:
        """Drain the result generator and build the completed response.

        Args:
            request: The originating request.
            sampling_params: Sampling parameters used for generation.
            result_generator: Async iterator of engine outputs; only the
                last item is used.
            model_name: Model name to report in the response.
            tokenizer: Tokenizer handed to the reasoning parser factory.
            request_metadata: Receives the final usage info.
            created_time: Creation timestamp in seconds; defaults to the
                current time.

        Returns:
            The completed response, or an ``ErrorResponse`` if generation
            was cancelled, raised ``ValueError``, or the reasoning parser
            could not be created.
        """
        if created_time is None:
            created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            # Only the last output is needed for a non-streaming response.
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        assert final_res is not None
        assert len(final_res.outputs) == 1
        final_output = final_res.outputs[0]

        if self.reasoning_parser:
            try:
                reasoning_parser = self.reasoning_parser(tokenizer)
            except RuntimeError as e:
                logger.exception("Error in reasoning parser creation.")
                return self.create_error_response(str(e))

            reasoning_content, content = (
                reasoning_parser.extract_reasoning_content(final_output.text,
                                                           request=request))
        else:
            reasoning_content = None
            content = final_output.text

        # Empty reasoning or content produces no output item.
        output = []
        if reasoning_content:
            reasoning_item = ResponseReasoningItem(
                text=reasoning_content,
                status=None,  # NOTE: Only the last output item has status.
            )
            output.append(reasoning_item)
        if content:
            output_text = ResponseOutputText(
                text=content,
                annotations=[],  # TODO
                type="output_text",
                logprobs=None,  # TODO
            )
            message = ResponseOutputMessage(
                id=f"msg_{random_uuid()}",
                content=[output_text],
                role="assistant",
                status="completed",
                type="message",
            )
            output.append(message)

        # Calculate usage.
        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = len(final_output.token_ids)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)
        request_metadata.final_usage_info = usage

        response = ResponsesResponse.from_request(
            request,
            sampling_params,
            model_name=model_name,
            created_time=created_time,
            output=output,
            status="completed",
            usage=usage,
        )

        if request.store:
            async with self.response_store_lock:
                stored_response = self.response_store.get(response.id)
                # If the response is already cancelled, don't update it.
                if (stored_response is None
                        or stored_response.status != "cancelled"):
                    self.response_store[response.id] = response
        return response

    def _construct_input_messages(
        self,
        request: ResponsesRequest,
        prev_response: Optional[ResponsesResponse] = None,
    ) -> list[ChatCompletionMessageParam]:
        """Build the chat message list for a request.

        The order is: a system message from ``request.instructions`` (if
        set), the stored messages and assistant outputs of
        ``prev_response`` (if given), then the new input.

        Args:
            request: The incoming request.
            prev_response: The stored response being continued, if any.

        Returns:
            The chat messages to feed into chat preprocessing.

        Raises:
            KeyError: If ``prev_response`` is given but no messages were
                stored under its ID.
        """
        messages: list[ChatCompletionMessageParam] = []
        if request.instructions:
            messages.append({
                "role": "system",
                "content": request.instructions,
            })

        # Prepend the conversation history.
        if prev_response is not None:
            # Add the previous messages.
            # NOTE(review): the stored list is the full message list of the
            # earlier request, including any system message built from its
            # `instructions` -- confirm that replaying it here is intended.
            prev_msg = self.msg_store[prev_response.id]
            messages.extend(prev_msg)

            # Add the previous output.
            for output_item in prev_response.output:
                # NOTE: We skip the reasoning output.
                if isinstance(output_item, ResponseOutputMessage):
                    for content in output_item.content:
                        messages.append({
                            "role": "assistant",
                            "content": content.text,
                        })

        # Append the new input.
        # Responses API supports simple text inputs without chat format.
        if isinstance(request.input, str):
            messages.append({"role": "user", "content": request.input})
        else:
            messages.extend(request.input)  # type: ignore
        return messages

    async def _run_background_request(
        self,
        request: ResponsesRequest,
        *args,
        **kwargs,
    ):
        """Run ``responses_full_generator`` for a background request.

        Any failure is logged and recorded by marking the stored response
        as "failed", unless it is already "completed" or "cancelled".

        Args:
            request: The originating request.
            *args: Positional arguments forwarded to
                ``responses_full_generator`` after ``request``.
            **kwargs: Keyword arguments forwarded likewise.
        """
        try:
            response = await self.responses_full_generator(
                request, *args, **kwargs)
        except Exception as e:
            logger.exception("Background request failed for %s",
                             request.request_id)
            response = self.create_error_response(str(e))

        if isinstance(response, ErrorResponse):
            # If the request has failed, update the status to "failed".
            # NOTE(review): the store is written under `response.id` in
            # create_responses but read here via `request.request_id`;
            # presumably the two are equal -- confirm in from_request.
            response_id = request.request_id
            async with self.response_store_lock:
                stored_response = self.response_store.get(response_id)
                assert stored_response is not None
                if stored_response.status not in ("completed", "cancelled"):
                    stored_response.status = "failed"

    async def retrieve_responses(
        self,
        response_id: str,
    ) -> Union[ErrorResponse, ResponsesResponse]:
        """Look up a stored response by ID.

        Args:
            response_id: ID of the response; must start with ``"resp_"``.

        Returns:
            The stored response, or an ``ErrorResponse`` if the ID is
            malformed or unknown.
        """
        if not response_id.startswith("resp_"):
            return self._make_invalid_id_error(response_id)

        async with self.response_store_lock:
            response = self.response_store.get(response_id)

        if response is None:
            return self._make_not_found_error(response_id)
        return response

    async def cancel_responses(
        self,
        response_id: str,
    ) -> Union[ErrorResponse, ResponsesResponse]:
        """Cancel a queued or in-progress response.

        Marks the stored response as "cancelled", then cancels its
        background task (if one is still registered) and waits for the
        task to finish.

        Args:
            response_id: ID of the response; must start with ``"resp_"``.

        Returns:
            The response with status "cancelled", or an ``ErrorResponse``
            if the ID is malformed, unknown, or the response is not in a
            cancellable state.
        """
        if not response_id.startswith("resp_"):
            return self._make_invalid_id_error(response_id)

        async with self.response_store_lock:
            response = self.response_store.get(response_id)
            if response is None:
                return self._make_not_found_error(response_id)

            prev_status = response.status
            if prev_status not in ("queued", "in_progress"):
                return self.create_error_response(
                    err_type="invalid_request_error",
                    message="Cannot cancel a synchronous response.",
                )

            # Update the status to "cancelled".
            response.status = "cancelled"

        # Abort the request. This happens after the lock is released
        # because the background task takes the same lock when it records
        # its final status.
        if (task := self.background_tasks.get(response_id)):
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # This is the expected outcome of the cancel() above, not
                # a failure, so log it without an error-level traceback.
                logger.info("Background task for %s was cancelled",
                            response_id)
        return response

    def _make_invalid_id_error(self, response_id: str) -> ErrorResponse:
        """Build the error returned for an ID without the expected prefix.

        Args:
            response_id: The offending ID, echoed in the message.

        Returns:
            An ``invalid_request_error`` response.
        """
        return self.create_error_response(
            err_type="invalid_request_error",
            message=(f"Invalid 'response_id': '{response_id}'. "
                     "Expected an ID that begins with 'resp'."),
        )

    def _make_not_found_error(self, response_id: str) -> ErrorResponse:
        """Build the error returned when no response is stored for an ID.

        Args:
            response_id: The unknown ID, echoed in the message.

        Returns:
            An ``invalid_request_error`` response with HTTP status 404.
        """
        return self.create_error_response(
            err_type="invalid_request_error",
            message=f"Response with id '{response_id}' not found.",
            status_code=HTTPStatus.NOT_FOUND,
        )