Commit 76139d08 (unverified), authored by wang.yuqi, committed by GitHub
Browse files

[Frontend] Only attach the entrypoints that correspond to supported tasks. (#33139)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent da8d0c44
......@@ -112,18 +112,14 @@ async def test_long_audio_request(mary_had_lamb, whisper_client):
@pytest.mark.asyncio
async def test_completion_endpoints(whisper_client):
# text to text model
res = await whisper_client.chat.completions.create(
with pytest.raises(openai.NotFoundError):
await whisper_client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "system", "content": "You are a helpful assistant."}],
)
err = res.error
assert err["code"] == 400
assert err["message"] == "The model does not support Chat Completions API"
res = await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
err = res.error
assert err["code"] == 400
assert err["message"] == "The model does not support Completions API"
with pytest.raises(openai.NotFoundError):
await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
@pytest.mark.asyncio
......
......@@ -9,6 +9,7 @@ import json
import httpx
import librosa
import numpy as np
import openai
import pytest
import pytest_asyncio
import soundfile as sf
......@@ -52,12 +53,11 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
model_name, _get_server_args(rocm_aiter_fa_attention)
) as remote_server:
client = remote_server.get_async_client()
res = await client.audio.translations.create(
with pytest.raises(openai.NotFoundError):
await client.audio.translations.create(
model=model_name, file=foscolo, temperature=0.0
)
err = res.error
assert err["code"] == 400 and not res.text
assert err["message"] == "The model does not support Translations API"
@pytest.mark.asyncio
......
......@@ -401,7 +401,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str):
"documents": "pong",
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["detail"] == "Not Found"
@pytest.mark.asyncio
......@@ -416,7 +416,7 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str):
"documents": ["pong"],
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["detail"] == "Not Found"
@pytest.mark.asyncio
......
......@@ -33,14 +33,10 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
......@@ -50,12 +46,6 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import (
OpenAIServingModels,
)
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.translations.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware,
)
......@@ -70,6 +60,7 @@ from vllm.entrypoints.utils import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -513,7 +504,7 @@ def _log_non_streaming_response(response_body: list) -> None:
logger.info("response_body={<binary_data>}")
def build_app(args: Namespace) -> FastAPI:
def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(
openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
......@@ -523,52 +514,44 @@ def build_app(args: Namespace) -> FastAPI:
else:
app = FastAPI(lifespan=lifespan)
app.state.args = args
app.include_router(router)
from vllm.entrypoints.serve import register_vllm_serve_api_routers
register_vllm_serve_api_routers(app)
from vllm.entrypoints.openai.chat_completion.api_router import (
attach_router as register_chat_api_router,
)
register_chat_api_router(app)
from vllm.entrypoints.openai.responses.api_router import (
attach_router as register_responses_api_router,
)
register_responses_api_router(app)
from vllm.entrypoints.openai.translations.api_router import (
attach_router as register_translations_api_router,
from vllm.entrypoints.openai.models.api_router import (
attach_router as register_models_api_router,
)
register_translations_api_router(app)
register_models_api_router(app)
from vllm.entrypoints.openai.completion.api_router import (
attach_router as register_completion_api_router,
from vllm.entrypoints.sagemaker.api_router import (
attach_router as register_sagemaker_api_router,
)
register_completion_api_router(app)
from vllm.entrypoints.anthropic.api_router import (
attach_router as register_anthropic_api_router,
)
register_sagemaker_api_router(app, supported_tasks)
register_anthropic_api_router(app)
from vllm.entrypoints.openai.models.api_router import (
attach_router as register_models_api_router,
if "generate" in supported_tasks:
from vllm.entrypoints.openai.generate.api_router import (
register_generate_api_routers,
)
register_models_api_router(app)
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
register_generate_api_routers(app)
register_sagemaker_routes(router)
app.include_router(router)
if "transcription" in supported_tasks:
from vllm.entrypoints.openai.translations.api_router import (
attach_router as register_translations_api_router,
)
app.root_path = args.root_path
register_translations_api_router(app)
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling import register_pooling_api_routers
register_pooling_api_routers(app)
register_pooling_api_routers(app, supported_tasks)
app.root_path = args.root_path
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
......@@ -673,6 +656,7 @@ async def init_app_state(
engine_client: EngineClient,
state: State,
args: Namespace,
supported_tasks: tuple["SupportedTask", ...],
) -> None:
vllm_config = engine_client.vllm_config
......@@ -694,28 +678,9 @@ async def init_app_state(
state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config
state.args = args
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
resolved_chat_template = load_chat_template(args.chat_template)
if args.tool_server == "demo":
tool_server: ToolServer | None = DemoToolServer()
assert isinstance(tool_server, DemoToolServer)
await tool_server.init_and_validate()
elif args.tool_server:
tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None
# Merge default_mm_loras into the static lora_modules
default_mm_loras = (
vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None
else {}
)
default_mm_loras = (
vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None
......@@ -729,66 +694,6 @@ async def init_app_state(
lora_modules=lora_modules,
)
await state.openai_serving_models.init_static_loras()
state.openai_serving_responses = (
OpenAIServingResponses(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
tool_server=tool_server,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.openai_serving_chat = (
OpenAIServingChat(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
# Warm up chat template processing to avoid first-request latency
if state.openai_serving_chat is not None:
await state.openai_serving_chat.warmup()
state.openai_serving_completion = (
OpenAIServingCompletion(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
state.openai_serving_models,
......@@ -798,64 +703,27 @@ async def init_app_state(
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
state.openai_serving_transcription = (
OpenAIServingTranscription(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
state.openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
state.anthropic_serving_messages = (
AnthropicServingMessages(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None
if "generate" in supported_tasks:
from vllm.entrypoints.openai.generate.api_router import init_generate_state
await init_generate_state(
engine_client, state, args, request_logger, supported_tasks
)
state.serving_tokens = (
ServingTokens(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
if "transcription" in supported_tasks:
from vllm.entrypoints.openai.translations.api_router import (
init_transcription_state,
)
if "generate" in supported_tasks
else None
init_transcription_state(
engine_client, state, args, request_logger, supported_tasks
)
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling import init_pooling_state
await init_pooling_state(engine_client, state, args)
init_pooling_state(engine_client, state, args, request_logger, supported_tasks)
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
......@@ -972,9 +840,11 @@ async def run_server_worker(
args,
client_config=client_config,
) as engine_client:
app = build_app(args)
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
await init_app_state(engine_client, app.state, args)
app = build_app(args, supported_tasks)
await init_app_state(engine_client, app.state, args, supported_tasks)
logger.info(
"Starting vLLM API server %d on %s",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import FastAPI
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
def register_generate_api_routers(app: FastAPI):
    """Attach the routers used for the "generate" task to ``app``.

    The routers are attached in a fixed order: chat completions, responses,
    completions, and finally the Anthropic router. Each router module is
    imported inside this function, right before its router is attached, so
    none of them is imported unless this function is called.

    Args:
        app: The FastAPI application that receives the routers.
    """
    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as attach_chat_router,
    )

    attach_chat_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as attach_responses_router,
    )

    attach_responses_router(app)

    from vllm.entrypoints.openai.completion.api_router import (
        attach_router as attach_completion_router,
    )

    attach_completion_router(app)

    from vllm.entrypoints.anthropic.api_router import (
        attach_router as attach_anthropic_router,
    )

    attach_anthropic_router(app)
async def init_generate_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    """Populate ``state`` with the serving handlers for the "generate" task.

    Sets ``openai_serving_responses``, ``openai_serving_chat``,
    ``openai_serving_completion``, ``anthropic_serving_messages`` and
    ``serving_tokens`` on ``state``. All five are set to ``None`` when
    "generate" is not among ``supported_tasks``.

    Args:
        engine_client: Engine client handed to every serving handler.
        state: Application state; ``state.openai_serving_models`` is read
            and the handler attributes listed above are written.
        args: Parsed server arguments that configure the handlers.
        request_logger: Optional request logger forwarded to the handlers.
        supported_tasks: Tasks supported by the engine.
    """
    from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.mcp.tool_server import (
        DemoToolServer,
        MCPToolServer,
        ToolServer,
    )
    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
    from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
    from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
    from vllm.entrypoints.serve.disagg.serving import ServingTokens

    # Select the tool server backend requested through ``args.tool_server``.
    tool_backend: ToolServer | None = None
    if args.tool_server == "demo":
        demo_backend = DemoToolServer()
        await demo_backend.init_and_validate()
        tool_backend = demo_backend
    elif args.tool_server:
        mcp_backend = MCPToolServer()
        await mcp_backend.add_tool_server(args.tool_server)
        tool_backend = mcp_backend

    chat_template = load_chat_template(args.chat_template)

    # Without the "generate" task none of the handlers is constructed.
    if "generate" not in supported_tasks:
        state.openai_serving_responses = None
        state.openai_serving_chat = None
        state.openai_serving_completion = None
        state.anthropic_serving_messages = None
        state.serving_tokens = None
        return

    models = state.openai_serving_models
    reasoning_parser = args.structured_outputs_config.reasoning_parser

    state.openai_serving_responses = OpenAIServingResponses(
        engine_client,
        models,
        request_logger=request_logger,
        chat_template=chat_template,
        chat_template_content_format=args.chat_template_content_format,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        tool_server=tool_backend,
        reasoning_parser=reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
        enable_log_outputs=args.enable_log_outputs,
        log_error_stack=args.log_error_stack,
    )

    chat_handler = OpenAIServingChat(
        engine_client,
        models,
        args.response_role,
        request_logger=request_logger,
        chat_template=chat_template,
        chat_template_content_format=args.chat_template_content_format,
        default_chat_template_kwargs=args.default_chat_template_kwargs,
        trust_request_chat_template=args.trust_request_chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
        tool_parser=args.tool_call_parser,
        reasoning_parser=reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
        enable_log_outputs=args.enable_log_outputs,
        enable_log_deltas=args.enable_log_deltas,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_chat = chat_handler
    # Warm up chat template processing to avoid first-request latency
    await chat_handler.warmup()

    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        models,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
        log_error_stack=args.log_error_stack,
    )

    state.anthropic_serving_messages = AnthropicServingMessages(
        engine_client,
        models,
        args.response_role,
        request_logger=request_logger,
        chat_template=chat_template,
        chat_template_content_format=args.chat_template_content_format,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        reasoning_parser=reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    )

    state.serving_tokens = ServingTokens(
        engine_client,
        models,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        log_error_stack=args.log_error_stack,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_log_outputs=args.enable_log_outputs,
        force_no_detokenize=args.tokens_only,
    )
......@@ -3,7 +3,7 @@
from http import HTTPStatus
from typing import Annotated
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, FastAPI, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
......@@ -25,6 +25,17 @@ from vllm.entrypoints.utils import (
)
from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
logger = init_logger(__name__)
router = APIRouter()
......@@ -115,3 +126,34 @@ async def create_translations(
def attach_router(app: FastAPI):
app.include_router(router)
def init_transcription_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    """Populate ``state`` with the transcription and translation handlers.

    Sets ``state.openai_serving_transcription`` and
    ``state.openai_serving_translation``. Both are set to ``None`` when
    "transcription" is not among ``supported_tasks``.

    Args:
        engine_client: Engine client handed to both handlers.
        state: Application state; ``state.openai_serving_models`` is read
            and the two handler attributes are written.
        args: Parsed server arguments (``log_error_stack`` and
            ``enable_force_include_usage`` are read).
        request_logger: Optional request logger forwarded to the handlers.
        supported_tasks: Tasks supported by the engine.
    """
    if "transcription" not in supported_tasks:
        state.openai_serving_transcription = None
        state.openai_serving_translation = None
        return

    # Both handlers take the same keyword configuration.
    handler_kwargs = {
        "request_logger": request_logger,
        "log_error_stack": args.log_error_stack,
        "enable_force_include_usage": args.enable_force_include_usage,
    }
    state.openai_serving_transcription = OpenAIServingTranscription(
        engine_client,
        state.openai_serving_models,
        **handler_kwargs,
    )
    state.openai_serving_translation = OpenAIServingTranslation(
        engine_client,
        state.openai_serving_models,
        **handler_kwargs,
    )
......@@ -11,40 +11,54 @@ if TYPE_CHECKING:
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
SupportedTask = object
def register_pooling_api_routers(app: FastAPI):
from vllm.entrypoints.pooling.classify.api_router import router as classify_router
from vllm.entrypoints.pooling.embed.api_router import router as embed_router
def register_pooling_api_routers(
app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
):
from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(pooling_router)
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.api_router import (
router as classify_router,
)
app.include_router(classify_router)
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.api_router import router as embed_router
app.include_router(embed_router)
if "score" in supported_tasks or "embed" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router)
app.include_router(pooling_router)
async def init_pooling_state(
engine_client: "EngineClient", state: "State", args: "Namespace"
def init_pooling_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.tasks import POOLING_TASKS
supported_tasks = await engine_client.get_supported_tasks()
resolved_chat_template = load_chat_template(args.chat_template)
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
state.openai_serving_pooling = (
(
OpenAIServingPooling(
......
......@@ -7,45 +7,15 @@ from typing import Any
import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.api_server import (
base,
)
from vllm.entrypoints.openai.chat_completion.api_router import (
chat,
create_chat_completion,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.api_router import (
completion,
create_completion,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.api_server import base
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.score.api_router import (
create_score,
do_rerank,
rerank,
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
......@@ -53,25 +23,89 @@ RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
if "generate" in supported_tasks:
from vllm.entrypoints.openai.chat_completion.api_router import (
chat,
create_chat_completion,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.api_router import (
completion,
create_completion,
)
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
INVOCATION_TYPES += [
(ChatCompletionRequest, (chat, create_chat_completion)),
(CompletionRequest, (completion, create_completion)),
]
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.api_router import (
create_embedding,
embedding,
)
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
INVOCATION_TYPES += [
(EmbeddingRequest, (embedding, create_embedding)),
]
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.api_router import (
classify,
create_classify,
)
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
INVOCATION_TYPES += [
(ClassificationRequest, (classify, create_classify)),
(ScoreRequest, (score, create_score)),
]
if "score" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.score.protocol import RerankRequest
INVOCATION_TYPES += [
(RerankRequest, (rerank, do_rerank)),
]
if "score" in supported_tasks or "embed" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import create_score, score
from vllm.entrypoints.pooling.score.protocol import ScoreRequest
INVOCATION_TYPES += [
(ScoreRequest, (score, create_score)),
]
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
INVOCATION_TYPES += [
(PoolingRequest, (pooling, create_pooling)),
]
]
# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
return INVOCATION_TYPES
def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
router = APIRouter()
# NOTE: Construct the TypeAdapters only once
INVOCATION_TYPES = get_invocation_types(supported_tasks)
INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
]
]
def register_sagemaker_routes(router: APIRouter):
@router.post("/ping", response_class=Response)
@router.get("/ping", response_class=Response)
@sagemaker_standards.register_ping_handler
......@@ -123,4 +157,4 @@ def register_sagemaker_routes(router: APIRouter):
res = base(raw_request).create_error_response(message=msg)
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
return router
app.include_router(router)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment