Unverified Commit 4c1c501a authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Refactor] [10/N] to simplify the vLLM openai completion serving architecture (#32369)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent ae1eba6a
...@@ -12,7 +12,8 @@ from vllm.config.multimodal import MultiModalConfig ...@@ -12,7 +12,8 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
import pytest import pytest
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from ...utils import VLLM_PATH from ...utils import VLLM_PATH
......
...@@ -9,9 +9,11 @@ from unittest.mock import AsyncMock, MagicMock ...@@ -9,9 +9,11 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.engine.protocol import CompletionRequest, ErrorResponse from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
......
...@@ -8,13 +8,11 @@ from unittest.mock import Mock ...@@ -8,13 +8,11 @@ from unittest.mock import Mock
import pytest import pytest
from vllm.entrypoints.openai.engine.protocol import (
StructuredOutputsParams,
)
from vllm.entrypoints.tool_server import ToolServer from vllm.entrypoints.tool_server import ToolServer
from vllm.reasoning.gptoss_reasoning_parser import ( from vllm.reasoning.gptoss_reasoning_parser import (
GptOssReasoningParser, GptOssReasoningParser,
) )
from vllm.sampling_params import StructuredOutputsParams
class TestGptOssStructuralTagsIntegration: class TestGptOssStructuralTagsIntegration:
......
...@@ -9,9 +9,11 @@ from unittest.mock import AsyncMock, MagicMock ...@@ -9,9 +9,11 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.engine.protocol import CompletionRequest, ErrorResponse from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
......
...@@ -20,8 +20,8 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,8 +20,8 @@ from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
RequestResponseMetadata, RequestResponseMetadata,
) )
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers import ToolParserManager
......
...@@ -9,7 +9,7 @@ import pytest ...@@ -9,7 +9,7 @@ import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
......
...@@ -11,7 +11,8 @@ from vllm.engine.protocol import EngineClient ...@@ -11,7 +11,8 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
) )
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.lora.protocol import ( from vllm.entrypoints.serve.lora.protocol import (
LoadLoRAAdapterRequest, LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest, UnloadLoRAAdapterRequest,
......
...@@ -19,7 +19,8 @@ from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat ...@@ -19,7 +19,8 @@ from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
) )
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.anthropic.protocol import (
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
error = base_server.create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)
try:
generator = await handler.create_messages(request, raw_request)
except Exception as e:
logger.exception("Error in create_messages: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)
if isinstance(generator, ErrorResponse):
return translate_error_response(generator)
elif isinstance(generator, AnthropicMessagesResponse):
resp = generator.model_dump(exclude_none=True)
logger.debug("Anthropic Messages Response: %s", resp)
return JSONResponse(content=resp)
return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
app.include_router(router)
...@@ -37,7 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -37,7 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
StreamOptions, StreamOptions,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -22,10 +22,10 @@ from typing import Any ...@@ -22,10 +22,10 @@ from typing import Any
import model_hosting_container_standards.sagemaker as sagemaker_standards import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic import pydantic
import uvloop import uvloop
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.types import ASGIApp, Message, Receive, Scope, Send
...@@ -33,36 +33,26 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send ...@@ -33,36 +33,26 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
import vllm.envs as envs import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import ( from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
CompletionRequest,
CompletionResponse,
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
) )
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.orca_metrics import metrics_header from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses from vllm.entrypoints.openai.models.serving import (
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (
BaseModelPath,
OpenAIServingModels, OpenAIServingModels,
) )
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.translations.serving import ( from vllm.entrypoints.openai.translations.serving import (
OpenAIServingTranscription, OpenAIServingTranscription,
OpenAIServingTranslation, OpenAIServingTranslation,
) )
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
...@@ -75,12 +65,10 @@ from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization ...@@ -75,12 +65,10 @@ from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import ( from vllm.entrypoints.utils import (
cli_env_setup, cli_env_setup,
load_aware_call,
log_non_default_args, log_non_default_args,
process_chat_template, process_chat_template,
process_lora_modules, process_lora_modules,
sanitize_message, sanitize_message,
with_cancellation,
) )
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -99,7 +87,6 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory ...@@ -99,7 +87,6 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.openai.api_server") logger = init_logger("vllm.entrypoints.openai.api_server")
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
_running_tasks: set[asyncio.Task] = set() _running_tasks: set[asyncio.Task] = set()
...@@ -231,22 +218,6 @@ def base(request: Request) -> OpenAIServing: ...@@ -231,22 +218,6 @@ def base(request: Request) -> OpenAIServing:
return tokenization(request) return tokenization(request)
def models(request: Request) -> OpenAIServingModels:
return request.app.state.openai_serving_models
def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
def chat(request: Request) -> OpenAIServingChat | None:
return request.app.state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion | None:
return request.app.state.openai_serving_completion
def tokenization(request: Request) -> OpenAIServingTokenization: def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization return request.app.state.openai_serving_tokenization
...@@ -278,116 +249,12 @@ async def get_server_load_metrics(request: Request): ...@@ -278,116 +249,12 @@ async def get_server_load_metrics(request: Request):
return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = models(raw_request)
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@router.get("/version") @router.get("/version")
async def show_version(): async def show_version():
ver = {"version": VLLM_VERSION} ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver) return JSONResponse(content=ver)
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
handler = messages(raw_request)
if handler is None:
error = base(raw_request).create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)
try:
generator = await handler.create_messages(request, raw_request)
except Exception as e:
logger.exception("Error in create_messages: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)
if isinstance(generator, ErrorResponse):
return translate_error_response(generator)
elif isinstance(generator, AnthropicMessagesResponse):
resp = generator.model_dump(exclude_none=True)
logger.debug("Anthropic Messages Response: %s", resp)
return JSONResponse(content=resp)
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/completions",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
metrics_header_format = raw_request.headers.get(
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
)
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API"
)
try:
generator = await handler.create_completion(request, raw_request)
except Exception as e:
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, CompletionResponse):
return JSONResponse(
content=generator.model_dump(),
headers=metrics_header(metrics_header_format),
)
return StreamingResponse(content=generator, media_type="text/event-stream")
def load_log_config(log_config_file: str | None) -> dict | None: def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file: if not log_config_file:
return None return None
...@@ -486,7 +353,7 @@ def _extract_content_from_chunk(chunk_data: dict) -> str: ...@@ -486,7 +353,7 @@ def _extract_content_from_chunk(chunk_data: dict) -> str:
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionStreamResponse, ChatCompletionStreamResponse,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.completion.protocol import (
CompletionStreamResponse, CompletionStreamResponse,
) )
...@@ -646,6 +513,22 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -646,6 +513,22 @@ def build_app(args: Namespace) -> FastAPI:
) )
register_translations_api_router(app) register_translations_api_router(app)
from vllm.entrypoints.openai.completion.api_router import (
attach_router as register_completion_api_router,
)
register_completion_api_router(app)
from vllm.entrypoints.anthropic.api_router import (
attach_router as register_anthropic_api_router,
)
register_anthropic_api_router(app)
from vllm.entrypoints.openai.models.api_router import (
attach_router as register_models_api_router,
)
register_models_api_router(app)
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
register_sagemaker_routes(router) register_sagemaker_routes(router)
......
...@@ -54,6 +54,7 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -54,6 +54,7 @@ from vllm.entrypoints.openai.engine.serving import (
OpenAIServing, OpenAIServing,
clamp_prompt_logprobs, clamp_prompt_logprobs,
) )
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import ( from vllm.entrypoints.openai.parser.harmony_utils import (
get_developer_message, get_developer_message,
get_stop_tokens_for_assistant_actions, get_stop_tokens_for_assistant_actions,
...@@ -63,7 +64,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -63,7 +64,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
parse_chat_output, parse_chat_output,
render_for_completion, render_for_completion,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
......
...@@ -26,7 +26,7 @@ from vllm.entrypoints.constants import ( ...@@ -26,7 +26,7 @@ from vllm.entrypoints.constants import (
H11_MAX_HEADER_COUNT_DEFAULT, H11_MAX_HEADER_COUNT_DEFAULT,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT,
) )
from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers import ToolParserManager
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
CompletionResponse,
)
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
def completion(request: Request) -> OpenAIServingCompletion | None:
return request.app.state.openai_serving_completion
@router.post(
"/v1/completions",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
metrics_header_format = raw_request.headers.get(
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
)
handler = completion(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
return base_server.create_error_response(
message="The model does not support Completions API"
)
try:
generator = await handler.create_completion(request, raw_request)
except Exception as e:
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, CompletionResponse):
return JSONResponse(
content=generator.model_dump(),
headers=metrics_header(metrics_header_format),
)
return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Annotated, Any, Literal
import torch
from pydantic import (
Field,
model_validator,
)
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
LegacyStructuralTagResponseFormat,
LogitsProcessors,
OpenAIBaseModel,
StreamOptions,
StructuralTagResponseFormat,
UsageInfo,
get_logits_processors,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
)
from vllm.utils import random_uuid
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str | None = None
prompt: list[int] | list[list[int]] | str | list[str] | None = None
echo: bool | None = False
frequency_penalty: float | None = 0.0
logit_bias: dict[str, float] | None = None
logprobs: int | None = None
max_tokens: int | None = 16
n: int = 1
presence_penalty: float | None = 0.0
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
stream: bool | None = False
stream_options: StreamOptions | None = None
suffix: str | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
# --8<-- [start:completion-sampling-params]
use_beam_search: bool = False
top_k: int | None = None
min_p: float | None = None
repetition_penalty: float | None = None
length_penalty: float = 1.0
stop_token_ids: list[int] | None = []
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
None
)
allowed_token_ids: list[int] | None = None
prompt_logprobs: int | None = None
# --8<-- [end:completion-sampling-params]
# --8<-- [start:completion-extra-params]
prompt_embeds: bytes | list[bytes] | None = None
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
response_format: AnyResponseFormat | None = Field(
default=None,
description=(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
structured_outputs: StructuredOutputsParams | None = Field(
default=None,
description="Additional kwargs for structured outputs",
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
logits_processors: LogitsProcessors | None = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."
),
)
return_tokens_as_token_ids: bool | None = Field(
default=None,
description=(
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."
),
)
return_token_ids: bool | None = Field(
default=None,
description=(
"If specified, the result will include token IDs alongside the "
"generated text. In streaming mode, prompt_token_ids is included "
"only in the first chunk, and token_ids contains the delta tokens "
"for each chunk. This is useful for debugging or when you "
"need to map generated text back to input tokens."
),
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
vllm_xargs: dict[str, str | int | float] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:completion-extra-params]
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
"min_p": 0.0,
}
def to_beam_search_params(
self,
max_tokens: int,
default_sampling_params: dict | None = None,
) -> BeamSearchParams:
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
max_tokens: int,
logits_processor_pattern: str | None,
default_sampling_params: dict | None = None,
) -> SamplingParams:
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
)
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if self.structured_outputs is None:
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
if response_format.type == "json_object":
self.structured_outputs.json_object = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
logits_processors=get_logits_processors(
self.logits_processors, logits_processor_pattern
),
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):
if data.get("structured_outputs", None) is None:
return data
structured_outputs_kwargs = data["structured_outputs"]
count = sum(
structured_outputs_kwargs.get(k) is not None
for k in ("json", "regex", "choice")
)
if count > 1:
raise VLLMValidationError(
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice').",
parameter="structured_outputs",
)
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
)
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise VLLMValidationError(
"`logprobs` must be a positive value.",
parameter="logprobs",
value=logprobs,
)
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data
@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
prompt = data.get("prompt")
prompt_embeds = data.get("prompt_embeds")
prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
embeds_is_empty = prompt_embeds is None or (
isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
)
if prompt_is_empty and embeds_is_empty:
raise ValueError(
"Either prompt or prompt_embeds must be provided and non-empty."
)
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
class CompletionLogProbs(OpenAIBaseModel):
text_offset: list[int] = Field(default_factory=list)
token_logprobs: list[float | None] = Field(default_factory=list)
tokens: list[str] = Field(default_factory=list)
top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: CompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
token_ids: list[int] | None = None # For response
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
prompt_token_ids: list[int] | None = None # For prompt
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: Literal["text_completion"] = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseChoice]
service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
system_fingerprint: str | None = None
usage: UsageInfo
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: CompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
# not part of the OpenAI spec but for tracing the tokens
# prompt tokens is put into choice to align with CompletionResponseChoice
prompt_token_ids: list[int] | None = None
token_ids: list[int] | None = None
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
...@@ -12,27 +12,29 @@ from fastapi import Request ...@@ -12,27 +12,29 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.completion.protocol import (
CompletionLogProbs, CompletionLogProbs,
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse, ErrorResponse,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
UsageInfo, UsageInfo,
VLLMValidationError,
) )
from vllm.entrypoints.openai.engine.serving import ( from vllm.entrypoints.openai.engine.serving import (
GenerationError, GenerationError,
OpenAIServing, OpenAIServing,
clamp_prompt_logprobs, clamp_prompt_logprobs,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
......
...@@ -3,9 +3,8 @@ ...@@ -3,9 +3,8 @@
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time import time
from typing import Annotated, Any, ClassVar, Literal, TypeAlias from typing import Any, ClassVar, Literal, TypeAlias
import regex as re import regex as re
import torch import torch
...@@ -17,14 +16,9 @@ from pydantic import ( ...@@ -17,14 +16,9 @@ from pydantic import (
) )
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.sampling_params import ( from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
SamplingParams, SamplingParams,
StructuredOutputsParams,
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
...@@ -226,429 +220,6 @@ def get_logits_processors( ...@@ -226,429 +220,6 @@ def get_logits_processors(
return None return None
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str | None = None
prompt: list[int] | list[list[int]] | str | list[str] | None = None
echo: bool | None = False
frequency_penalty: float | None = 0.0
logit_bias: dict[str, float] | None = None
logprobs: int | None = None
max_tokens: int | None = 16
n: int = 1
presence_penalty: float | None = 0.0
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
stream: bool | None = False
stream_options: StreamOptions | None = None
suffix: str | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
# --8<-- [start:completion-sampling-params]
use_beam_search: bool = False
top_k: int | None = None
min_p: float | None = None
repetition_penalty: float | None = None
length_penalty: float = 1.0
stop_token_ids: list[int] | None = []
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = (
None
)
allowed_token_ids: list[int] | None = None
prompt_logprobs: int | None = None
# --8<-- [end:completion-sampling-params]
# --8<-- [start:completion-extra-params]
prompt_embeds: bytes | list[bytes] | None = None
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
response_format: AnyResponseFormat | None = Field(
default=None,
description=(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
structured_outputs: StructuredOutputsParams | None = Field(
default=None,
description="Additional kwargs for structured outputs",
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
logits_processors: LogitsProcessors | None = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."
),
)
return_tokens_as_token_ids: bool | None = Field(
default=None,
description=(
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."
),
)
return_token_ids: bool | None = Field(
default=None,
description=(
"If specified, the result will include token IDs alongside the "
"generated text. In streaming mode, prompt_token_ids is included "
"only in the first chunk, and token_ids contains the delta tokens "
"for each chunk. This is useful for debugging or when you "
"need to map generated text back to input tokens."
),
)
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
vllm_xargs: dict[str, str | int | float] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:completion-extra-params]
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
"min_p": 0.0,
}
def to_beam_search_params(
self,
max_tokens: int,
default_sampling_params: dict | None = None,
) -> BeamSearchParams:
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
max_tokens: int,
logits_processor_pattern: str | None,
default_sampling_params: dict | None = None,
) -> SamplingParams:
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
)
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
response_format = self.response_format
if response_format is not None:
# If structured outputs wasn't already enabled,
# we must enable it for these features to work
if self.structured_outputs is None:
self.structured_outputs = StructuredOutputsParams()
# Set structured output params for response format
if response_format.type == "json_object":
self.structured_outputs.json_object = True
elif response_format.type == "json_schema":
json_schema = response_format.json_schema
assert json_schema is not None
self.structured_outputs.json = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
StructuralTagResponseFormat,
),
)
s_tag_obj = structural_tag.model_dump(by_alias=True)
self.structured_outputs.structural_tag = json.dumps(s_tag_obj)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
logits_processors=get_logits_processors(
self.logits_processors, logits_processor_pattern
),
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):
if data.get("structured_outputs", None) is None:
return data
structured_outputs_kwargs = data["structured_outputs"]
count = sum(
structured_outputs_kwargs.get(k) is not None
for k in ("json", "regex", "choice")
)
if count > 1:
raise VLLMValidationError(
"You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice').",
parameter="structured_outputs",
)
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
)
if prompt_logprobs < 0 and prompt_logprobs != -1:
raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise VLLMValidationError(
"`logprobs` must be a positive value.",
parameter="logprobs",
value=logprobs,
)
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data
@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
prompt = data.get("prompt")
prompt_embeds = data.get("prompt_embeds")
prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
embeds_is_empty = prompt_embeds is None or (
isinstance(prompt_embeds, list) and len(prompt_embeds) == 0
)
if prompt_is_empty and embeds_is_empty:
raise ValueError(
"Either prompt or prompt_embeds must be provided and non-empty."
)
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
)
return data
class CompletionLogProbs(OpenAIBaseModel):
text_offset: list[int] = Field(default_factory=list)
token_logprobs: list[float | None] = Field(default_factory=list)
tokens: list[str] = Field(default_factory=list)
top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: CompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
token_ids: list[int] | None = None # For response
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
prompt_token_ids: list[int] | None = None # For prompt
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: Literal["text_completion"] = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseChoice]
service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
system_fingerprint: str | None = None
usage: UsageInfo
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: CompletionLogProbs | None = None
finish_reason: str | None = None
stop_reason: int | str | None = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
# not part of the OpenAI spec but for tracing the tokens
# prompt tokens is put into choice to align with CompletionResponseChoice
prompt_token_ids: list[int] | None = None
token_ids: list[int] | None = None
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[CompletionResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
class FunctionCall(OpenAIBaseModel): class FunctionCall(OpenAIBaseModel):
name: str name: str
arguments: str arguments: str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment