# SPDX-License-Identifier: Apache-2.0

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Optional, Union, cast

import jinja2
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              ErrorResponse,
                                              RequestResponseMetadata,
                                              UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                    clamp_prompt_logprobs,
                                                    is_text_tokens_prompt)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                              is_tokens_prompt)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators

logger = init_logger(__name__)


class OpenAIServingCompletion(OpenAIServing):
    """Serve OpenAI-style text completion requests on top of an engine client.

    The class turns a ``CompletionRequest`` into one engine request per
    prompt, then renders the engine outputs either as a single
    ``CompletionResponse`` or as a stream of server-sent-event strings.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Initialise the serving layer and cache default sampling params.

        Args:
            engine_client: Client used to tokenize and generate.
            model_config: Model configuration; also the source of the
                default sampling parameters.
            models: Registry forwarded to the base class.
            request_logger: Optional logger forwarded to the base class.
            return_tokens_as_token_ids: Default for rendering tokens as
                ``token_id:<id>`` strings instead of decoded text.
        """
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
            logger.info("Using default completion sampling params from %s: %s",
                        source, self.default_sampling_params)

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
            suffix)

        Args:
            request: The parsed completion request.
            raw_request: The underlying HTTP request, if any. It is used for
                the request id, trace headers and to attach request metadata.

        Returns:
            An async generator of SSE strings when ``request.stream`` is set,
            a ``CompletionResponse`` otherwise, or an ``ErrorResponse`` when
            validation or preprocessing fails.

        Raises:
            Whatever ``engine_client.dead_error`` is, when the engine client
            reports that it has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response(
                "Echo is unsupported with prompt embeds.")

        request_id = f"cmpl-{self._base_request_id(raw_request)}"
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            request_prompts, engine_prompts = await self._preprocess_completion(
                request,
                tokenizer,
                request.prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        except (ValueError, TypeError, RuntimeError,
                jinja2.TemplateError) as e:
            # All preprocessing failures are reported to the client the same
            # way, so they share a single handler.
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Mypy does not infer that engine_prompt will have only one of
                # "prompt_token_ids" or "prompt_embeds" defined, and both of
                # these as Union[object, the expected type], where it infers
                # object if engine_prompt is a subclass of one of the
                # typeddicts that defines both keys. Worse, because of
                # https://github.com/python/mypy/issues/8586, mypy does not
                # infer the type of engine_prompt correctly because of the
                # enumerate. So we need an unnecessary cast here.
                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
                                     engine_prompt)
                if is_embeds_prompt(engine_prompt):
                    input_length = len(engine_prompt["prompt_embeds"])
                elif is_tokens_prompt(engine_prompt):
                    input_length = len(engine_prompt["prompt_token_ids"])
                else:
                    assert_never(engine_prompt)

                # Whatever context is left after the prompt is the default
                # generation budget.
                default_max_tokens = self.max_model_len - input_length

                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        default_max_tokens, self.default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        default_max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params)

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=sampling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from
                # above) but pre-commit in CI fails without it.
                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
                                     engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        prompt_adapter_request=prompt_adapter_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        model_name = self._get_model_name(request.model, lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. Noting that best_of is only supported in V0. In addition,
        # we do not stream the results when use beam search.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata)

        # Non-streaming response
        final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
        try:
            # Each yielded item carries its prompt index; keep only the most
            # recent output seen for every prompt.
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text
                # We did not pass it into vLLM engine to avoid being redundant
                # with the inputs token IDs
                if final_res.prompt is None:
                    request_prompt = request_prompts[i]
                    if is_text_tokens_prompt(request_prompt):
                        final_res.prompt = request_prompt["prompt"]
                    else:
                        final_res.prompt = None

            final_res_batch_checked = cast(list[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response

    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator[tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
        model_name: str,
        num_prompts: int,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> AsyncGenerator[str, None]:
        """Render engine outputs as server-sent-event ``data:`` lines.

        One chunk is yielded per output of every result. If
        ``stream_options.include_usage`` is set a final usage-only chunk
        follows. The stream always ends with ``data: [DONE]``. Any exception
        raised while streaming is converted into a streamed error payload.

        Args:
            request: The originating completion request.
            result_generator: Yields ``(prompt_index, RequestOutput)`` pairs.
            request_id: Id placed on every chunk.
            created_time: Creation timestamp placed on every chunk.
            model_name: Model name placed on every chunk.
            num_prompts: Number of prompts in the request.
            tokenizer: Tokenizer used to render logprob tokens.
            request_metadata: Receives the aggregated final usage info.

        Yields:
            SSE-formatted strings (``"data: ...\\n\\n"``).
        """
        num_choices = 1 if request.n is None else request.n
        # Per-choice running state, indexed by
        # ``output.index + prompt_idx * num_choices``.
        previous_text_lens = [0] * num_choices * num_prompts
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts

        stream_options = request.stream_options
        if stream_options:
            include_usage = stream_options.include_usage
            include_continuous_usage = include_usage and \
                                       stream_options.continuous_usage_stats
        else:
            include_usage, include_continuous_usage = False, False

        try:
            async for prompt_idx, res in result_generator:
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs
                prompt_text = res.prompt

                # Prompt details are excluded from later streamed outputs
                if prompt_token_ids is not None:
                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

                delta_token_ids: GenericSequence[int]
                out_logprobs: Optional[GenericSequence[Optional[dict[
                    int, Logprob]]]]

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices

                    assert request.max_tokens is not None
                    if request.echo and not has_echoed[i]:
                        assert prompt_token_ids is not None
                        assert prompt_text is not None
                        if request.max_tokens == 0:
                            # only return the prompt
                            delta_text = prompt_text
                            delta_token_ids = prompt_token_ids
                            out_logprobs = prompt_logprobs
                        else:
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
                                *prompt_token_ids, *output.token_ids
                            ]
                            # Same rule as the non-streaming path: prompt
                            # logprobs are only required when the client
                            # asked for logprobs. Asserting unconditionally
                            # broke echo requests that did not ask for them.
                            if request.logprobs is None:
                                out_logprobs = None
                            else:
                                assert prompt_logprobs is not None
                                out_logprobs = [
                                    *prompt_logprobs,
                                    *(output.logprobs or []),
                                ]
                        has_echoed[i] = True
                    else:
                        # return just the delta
                        delta_text = output.text
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

                        if not delta_text and not delta_token_ids \
                            and not previous_num_tokens[i]:
                            # Chunked prefill case, don't return empty chunks
                            continue

                    if request.logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_completion_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            tokenizer=tokenizer,
                            initial_text_offset=previous_text_lens[i],
                            return_as_token_id=request.
                            return_tokens_as_token_ids,
                        )
                    else:
                        logprobs = None

                    previous_text_lens[i] += len(output.text)
                    previous_num_tokens[i] += len(output.token_ids)
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=i,
                                text=delta_text,
                                logprobs=logprobs,
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
                        ])
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
                        chunk.usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=prompt_tokens + completion_tokens,
                        )

                    response_json = chunk.model_dump_json(exclude_unset=False)
                    yield f"data: {response_json}\n\n"

            total_prompt_tokens = sum(num_prompt_tokens)
            total_completion_tokens = sum(previous_num_tokens)
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[],
                    usage=final_usage_info,
                )
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=False, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = final_usage_info

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"

    def request_output_to_completion_response(
        self,
        final_res_batch: list[RequestOutput],
        request: CompletionRequest,
        request_id: str,
        created_time: int,
        model_name: str,
        tokenizer: AnyTokenizer,
        request_metadata: RequestResponseMetadata,
    ) -> CompletionResponse:
        """Build the non-streaming response from the final engine outputs.

        Every output of every result becomes one choice, numbered in the
        order encountered. Usage is summed over all prompts and outputs and
        is also stored on ``request_metadata.final_usage_info``.

        Args:
            final_res_batch: Final ``RequestOutput`` for each prompt.
            request: The originating completion request.
            request_id: Id placed on the response.
            created_time: Creation timestamp placed on the response.
            model_name: Model name placed on the response.
            tokenizer: Tokenizer used to render logprob tokens.
            request_metadata: Receives the aggregated usage info.

        Returns:
            The assembled ``CompletionResponse``.
        """
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
            prompt_text = final_res.prompt

            token_ids: GenericSequence[int]
            out_logprobs: Optional[GenericSequence[Optional[dict[int,
                                                                 Logprob]]]]

            for output in final_res.outputs:
                assert request.max_tokens is not None
                if request.echo:
                    assert prompt_text is not None
                    if request.max_tokens == 0:
                        # Nothing was generated: return the prompt only.
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

                        if request.logprobs is None:
                            out_logprobs = None
                        else:
                            assert prompt_logprobs is not None
                            assert output.logprobs is not None
                            out_logprobs = [
                                *prompt_logprobs,
                                *output.logprobs,
                            ]

                        output_text = prompt_text + output.text
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
                    output_text = output.text

                if request.logprobs is not None:
                    assert out_logprobs is not None, "Did not output logprobs"
                    logprobs = self._create_completion_logprobs(
                        token_ids=token_ids,
                        top_logprobs=out_logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.logprobs,
                        return_as_token_id=request.return_tokens_as_token_ids,
                    )
                else:
                    logprobs = None

                choice_data = CompletionResponseChoice(
                    index=len(choices),
                    text=output_text,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                    prompt_logprobs=final_res.prompt_logprobs,
                )
                choices.append(choice_data)

                num_generated_tokens += len(output.token_ids)

            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        request_metadata.final_usage_info = usage

        # Guard the empty batch so it does not raise IndexError.
        # NOTE(review): assumes ``kv_transfer_params`` accepts None - confirm
        # against the CompletionResponse schema.
        kv_transfer_params = (final_res_batch[0].kv_transfer_params
                              if final_res_batch else None)

        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            kv_transfer_params=kv_transfer_params)

    def _create_completion_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[dict[int, Logprob]]],
        num_output_top_logprobs: int,
        tokenizer: AnyTokenizer,
        initial_text_offset: int = 0,
        return_as_token_id: Optional[bool] = None,
    ) -> CompletionLogProbs:
        """Create logprobs for OpenAI Completion API.

        Args:
            token_ids: Token ids to report, in order.
            top_logprobs: Per-position mapping of token id to ``Logprob``,
                or ``None`` where no logprobs exist for that position.
                Indexed in lockstep with ``token_ids``.
            num_output_top_logprobs: Number of top alternatives requested;
                up to ``num_output_top_logprobs + 1`` entries are emitted
                per position.
            tokenizer: Used to decode tokens.
            initial_text_offset: Text offset assigned to the first token.
            return_as_token_id: When not ``None``, overrides
                ``self.return_tokens_as_token_ids``.

        Returns:
            A ``CompletionLogProbs`` with parallel lists of text offsets,
            token logprobs, tokens and top-logprob dictionaries.
        """
        out_text_offset: list[int] = []
        out_token_logprobs: list[Optional[float]] = []
        out_tokens: list[str] = []
        out_top_logprobs: list[Optional[dict[str, float]]] = []

        last_token_len = 0

        should_return_as_token_id = return_as_token_id if \
            return_as_token_id is not None else self.return_tokens_as_token_ids
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                # Only decode when the decoded text is actually used.
                if should_return_as_token_id:
                    token = f"token_id:{token_id}"
                else:
                    token = tokenizer.decode(token_id)

                out_tokens.append(token)
                out_token_logprobs.append(None)
                out_top_logprobs.append(None)
            else:
                step_token = step_top_logprobs[token_id]

                token = self._get_decoded_token(
                    step_token,
                    token_id,
                    tokenizer,
                    return_as_token_id=should_return_as_token_id,
                )
                # Clamp so float("-inf") stays JSON-serializable.
                token_logprob = max(step_token.logprob, -9999.0)

                out_tokens.append(token)
                out_token_logprobs.append(token_logprob)

                # makes sure to add the top num_output_top_logprobs + 1
                # logprobs, as defined in the openai API
                # (cf. https://github.com/openai/openai-openapi/blob/
                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
                    self._get_decoded_token(
                        top_lp,
                        top_token_id,
                        tokenizer,
                        return_as_token_id=should_return_as_token_id):
                    max(top_lp.logprob, -9999.0)
                    for rank, (top_token_id, top_lp) in enumerate(
                        step_top_logprobs.items())
                    if num_output_top_logprobs >= rank
                })

            if len(out_text_offset) == 0:
                out_text_offset.append(initial_text_offset)
            else:
                out_text_offset.append(out_text_offset[-1] + last_token_len)
            last_token_len = len(token)

        return CompletionLogProbs(
            text_offset=out_text_offset,
            token_logprobs=out_token_logprobs,
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )