import time
from typing import AsyncIterator, List, Optional, Tuple

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)

TypeTokenIDs = List[int]


def request_output_to_embedding_response(
    final_res_batch: List[EmbeddingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
) -> EmbeddingResponse:
    """Assemble an ``EmbeddingResponse`` from per-prompt engine outputs.

    Args:
        final_res_batch: One finished output per prompt. The position in the
            list becomes the ``index`` of the corresponding response entry.
        request_id: Value stored in the response ``id`` field.
        created_time: Value stored in the response ``created`` field.
        model_name: Value stored in the response ``model`` field.

    Returns:
        The response holding one ``EmbeddingResponseData`` per input and a
        ``UsageInfo`` whose prompt and total token counts are both the sum of
        the prompt token counts (no completion tokens are counted here).
    """
    data = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        # Internal invariant: the caller fills every slot before converting.
        assert final_res is not None
        data.append(
            EmbeddingResponseData(index=idx,
                                  embedding=final_res.outputs.embedding))
        num_prompt_tokens += len(final_res.prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


class OpenAIServingEmbedding(OpenAIServing):
    """Serving handler for the OpenAI-style embeddings endpoint."""

    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                 served_model_names: List[str]):
        """Initialise the base serving class (without LoRA modules) and log
        whether the model config has embedding mode enabled."""
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None)
        self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The parsed embedding request.
            raw_request: The underlying HTTP request, polled to detect a
                client disconnect while results are being collected.

        Returns:
            An ``EmbeddingResponse`` on success. Otherwise an error response:
            whatever ``_check_model`` returned if it was not ``None``, or the
            result of ``create_error_response`` when ``encoding_format`` is
            ``"base64"``, ``dimensions`` is set, prompt parsing/validation
            raises ``ValueError``, or the client disconnects.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.encoding_format == "base64":
            return self.create_error_response(
                "base64 encoding is not currently supported")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        # Wall-clock Unix timestamp. time.monotonic() has an undefined
        # reference point and must not be reported as a creation time.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators = []
        try:
            prompt_is_tokens, prompts = parse_prompt_format(request.input)
            pooling_params = request.to_pooling_params()

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt_ids=prompt)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt=prompt)

                prompt_ids, prompt_text = prompt_formats

                # Each prompt gets its own sub-request id "<request_id>-<i>".
                generators.append(
                    self.engine.generate(prompt_text,
                                         pooling_params,
                                         f"{request_id}-{i}",
                                         prompt_token_ids=prompt_ids))
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response: slot i holds the latest output of prompt i.
        final_res_batch: List[Optional[EmbeddingRequestOutput]] = (
            [None] * len(prompts))
        async for i, res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                # NOTE(review): only the sub-request that just yielded is
                # aborted here; confirm whether the remaining sub-requests
                # should be aborted as well.
                await self.engine.abort(f"{request_id}-{i}")
                # TODO: Use a vllm-specific Validation Error
                return self.create_error_response("Client disconnected")
            final_res_batch[i] = res

        response = request_output_to_embedding_response(
            final_res_batch, request_id, created_time, model_name)

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        """Log a warning when embedding mode is off, an info line when on."""
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")