Unverified Commit 76139d08 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Frontend] Frontend will only attach supported tasks corresponding entrypoints. (#33139)


Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: default avatarwang.yuqi <noooop@126.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent da8d0c44
...@@ -112,18 +112,14 @@ async def test_long_audio_request(mary_had_lamb, whisper_client): ...@@ -112,18 +112,14 @@ async def test_long_audio_request(mary_had_lamb, whisper_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completion_endpoints(whisper_client): async def test_completion_endpoints(whisper_client):
# text to text model # text to text model
res = await whisper_client.chat.completions.create( with pytest.raises(openai.NotFoundError):
model=MODEL_NAME, await whisper_client.chat.completions.create(
messages=[{"role": "system", "content": "You are a helpful assistant."}], model=MODEL_NAME,
) messages=[{"role": "system", "content": "You are a helpful assistant."}],
err = res.error )
assert err["code"] == 400
assert err["message"] == "The model does not support Chat Completions API" with pytest.raises(openai.NotFoundError):
await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
res = await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
err = res.error
assert err["code"] == 400
assert err["message"] == "The model does not support Completions API"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -9,6 +9,7 @@ import json ...@@ -9,6 +9,7 @@ import json
import httpx import httpx
import librosa import librosa
import numpy as np import numpy as np
import openai
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import soundfile as sf import soundfile as sf
...@@ -52,12 +53,11 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention): ...@@ -52,12 +53,11 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
model_name, _get_server_args(rocm_aiter_fa_attention) model_name, _get_server_args(rocm_aiter_fa_attention)
) as remote_server: ) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
res = await client.audio.translations.create(
model=model_name, file=foscolo, temperature=0.0 with pytest.raises(openai.NotFoundError):
) await client.audio.translations.create(
err = res.error model=model_name, file=foscolo, temperature=0.0
assert err["code"] == 400 and not res.text )
assert err["message"] == "The model does not support Translations API"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -401,7 +401,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str): ...@@ -401,7 +401,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str):
"documents": "pong", "documents": "pong",
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["detail"] == "Not Found"
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -416,7 +416,7 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str): ...@@ -416,7 +416,7 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str):
"documents": ["pong"], "documents": ["pong"],
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" assert response.json()["detail"] == "Not Found"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -33,14 +33,10 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send ...@@ -33,14 +33,10 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
import vllm.envs as envs import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
...@@ -50,12 +46,6 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath ...@@ -50,12 +46,6 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import ( from vllm.entrypoints.openai.models.serving import (
OpenAIServingModels, OpenAIServingModels,
) )
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.translations.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import ( from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware, ScalingMiddleware,
) )
...@@ -70,6 +60,7 @@ from vllm.entrypoints.utils import ( ...@@ -70,6 +60,7 @@ from vllm.entrypoints.utils import (
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -513,7 +504,7 @@ def _log_non_streaming_response(response_body: list) -> None: ...@@ -513,7 +504,7 @@ def _log_non_streaming_response(response_body: list) -> None:
logger.info("response_body={<binary_data>}") logger.info("response_body={<binary_data>}")
def build_app(args: Namespace) -> FastAPI: def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
if args.disable_fastapi_docs: if args.disable_fastapi_docs:
app = FastAPI( app = FastAPI(
openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
...@@ -523,52 +514,44 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -523,52 +514,44 @@ def build_app(args: Namespace) -> FastAPI:
else: else:
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.state.args = args app.state.args = args
app.include_router(router)
from vllm.entrypoints.serve import register_vllm_serve_api_routers from vllm.entrypoints.serve import register_vllm_serve_api_routers
register_vllm_serve_api_routers(app) register_vllm_serve_api_routers(app)
from vllm.entrypoints.openai.chat_completion.api_router import (
attach_router as register_chat_api_router,
)
register_chat_api_router(app)
from vllm.entrypoints.openai.responses.api_router import ( from vllm.entrypoints.openai.models.api_router import (
attach_router as register_responses_api_router, attach_router as register_models_api_router,
)
register_responses_api_router(app)
from vllm.entrypoints.openai.translations.api_router import (
attach_router as register_translations_api_router,
) )
register_translations_api_router(app) register_models_api_router(app)
from vllm.entrypoints.openai.completion.api_router import ( from vllm.entrypoints.sagemaker.api_router import (
attach_router as register_completion_api_router, attach_router as register_sagemaker_api_router,
) )
register_completion_api_router(app) register_sagemaker_api_router(app, supported_tasks)
from vllm.entrypoints.anthropic.api_router import (
attach_router as register_anthropic_api_router,
)
register_anthropic_api_router(app) if "generate" in supported_tasks:
from vllm.entrypoints.openai.models.api_router import ( from vllm.entrypoints.openai.generate.api_router import (
attach_router as register_models_api_router, register_generate_api_routers,
) )
register_models_api_router(app) register_generate_api_routers(app)
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
register_sagemaker_routes(router) if "transcription" in supported_tasks:
app.include_router(router) from vllm.entrypoints.openai.translations.api_router import (
attach_router as register_translations_api_router,
)
app.root_path = args.root_path register_translations_api_router(app)
from vllm.entrypoints.pooling import register_pooling_api_routers if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling import register_pooling_api_routers
register_pooling_api_routers(app) register_pooling_api_routers(app, supported_tasks)
app.root_path = args.root_path
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=args.allowed_origins, allow_origins=args.allowed_origins,
...@@ -673,6 +656,7 @@ async def init_app_state( ...@@ -673,6 +656,7 @@ async def init_app_state(
engine_client: EngineClient, engine_client: EngineClient,
state: State, state: State,
args: Namespace, args: Namespace,
supported_tasks: tuple["SupportedTask", ...],
) -> None: ) -> None:
vllm_config = engine_client.vllm_config vllm_config = engine_client.vllm_config
...@@ -694,28 +678,9 @@ async def init_app_state( ...@@ -694,28 +678,9 @@ async def init_app_state(
state.log_stats = not args.disable_log_stats state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config state.vllm_config = vllm_config
state.args = args state.args = args
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
if args.tool_server == "demo":
tool_server: ToolServer | None = DemoToolServer()
assert isinstance(tool_server, DemoToolServer)
await tool_server.init_and_validate()
elif args.tool_server:
tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None
# Merge default_mm_loras into the static lora_modules # Merge default_mm_loras into the static lora_modules
default_mm_loras = (
vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None
else {}
)
default_mm_loras = ( default_mm_loras = (
vllm_config.lora_config.default_mm_loras vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None if vllm_config.lora_config is not None
...@@ -729,66 +694,6 @@ async def init_app_state( ...@@ -729,66 +694,6 @@ async def init_app_state(
lora_modules=lora_modules, lora_modules=lora_modules,
) )
await state.openai_serving_models.init_static_loras() await state.openai_serving_models.init_static_loras()
state.openai_serving_responses = (
OpenAIServingResponses(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
tool_server=tool_server,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.openai_serving_chat = (
OpenAIServingChat(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
# Warm up chat template processing to avoid first-request latency
if state.openai_serving_chat is not None:
await state.openai_serving_chat.warmup()
state.openai_serving_completion = (
OpenAIServingCompletion(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
...@@ -798,64 +703,27 @@ async def init_app_state( ...@@ -798,64 +703,27 @@ async def init_app_state(
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
state.openai_serving_transcription = (
OpenAIServingTranscription( if "generate" in supported_tasks:
engine_client, from vllm.entrypoints.openai.generate.api_router import init_generate_state
state.openai_serving_models,
request_logger=request_logger, await init_generate_state(
log_error_stack=args.log_error_stack, engine_client, state, args, request_logger, supported_tasks
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
state.openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
) )
if "transcription" in supported_tasks
else None if "transcription" in supported_tasks:
) from vllm.entrypoints.openai.translations.api_router import (
state.anthropic_serving_messages = ( init_transcription_state,
AnthropicServingMessages(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) )
if "generate" in supported_tasks
else None init_transcription_state(
) engine_client, state, args, request_logger, supported_tasks
state.serving_tokens = (
ServingTokens(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
) )
if "generate" in supported_tasks
else None
)
from vllm.entrypoints.pooling import init_pooling_state if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling import init_pooling_state
await init_pooling_state(engine_client, state, args) init_pooling_state(engine_client, state, args, request_logger, supported_tasks)
state.enable_server_load_tracking = args.enable_server_load_tracking state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0 state.server_load_metrics = 0
...@@ -972,9 +840,11 @@ async def run_server_worker( ...@@ -972,9 +840,11 @@ async def run_server_worker(
args, args,
client_config=client_config, client_config=client_config,
) as engine_client: ) as engine_client:
app = build_app(args) supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
await init_app_state(engine_client, app.state, args) app = build_app(args, supported_tasks)
await init_app_state(engine_client, app.state, args, supported_tasks)
logger.info( logger.info(
"Starting vLLM API server %d on %s", "Starting vLLM API server %d on %s",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import FastAPI
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
def register_generate_api_routers(app: FastAPI):
from vllm.entrypoints.openai.chat_completion.api_router import (
attach_router as register_chat_api_router,
)
register_chat_api_router(app)
from vllm.entrypoints.openai.responses.api_router import (
attach_router as register_responses_api_router,
)
register_responses_api_router(app)
from vllm.entrypoints.openai.completion.api_router import (
attach_router as register_completion_api_router,
)
register_completion_api_router(app)
from vllm.entrypoints.anthropic.api_router import (
attach_router as register_anthropic_api_router,
)
register_anthropic_api_router(app)
async def init_generate_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.mcp.tool_server import (
DemoToolServer,
MCPToolServer,
ToolServer,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.serve.disagg.serving import ServingTokens
if args.tool_server == "demo":
tool_server: ToolServer | None = DemoToolServer()
assert isinstance(tool_server, DemoToolServer)
await tool_server.init_and_validate()
elif args.tool_server:
tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None
resolved_chat_template = load_chat_template(args.chat_template)
state.openai_serving_responses = (
OpenAIServingResponses(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
tool_server=tool_server,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.openai_serving_chat = (
OpenAIServingChat(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
# Warm up chat template processing to avoid first-request latency
if state.openai_serving_chat is not None:
await state.openai_serving_chat.warmup()
state.openai_serving_completion = (
OpenAIServingCompletion(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
)
state.anthropic_serving_messages = (
AnthropicServingMessages(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None
)
state.serving_tokens = (
ServingTokens(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
)
if "generate" in supported_tasks
else None
)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Annotated from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, FastAPI, Form, Request from fastapi import APIRouter, FastAPI, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
...@@ -25,6 +25,17 @@ from vllm.entrypoints.utils import ( ...@@ -25,6 +25,17 @@ from vllm.entrypoints.utils import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
logger = init_logger(__name__) logger = init_logger(__name__)
router = APIRouter() router = APIRouter()
...@@ -115,3 +126,34 @@ async def create_translations( ...@@ -115,3 +126,34 @@ async def create_translations(
def attach_router(app: FastAPI): def attach_router(app: FastAPI):
app.include_router(router) app.include_router(router)
def init_transcription_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
state.openai_serving_transcription = (
OpenAIServingTranscription(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
state.openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
...@@ -11,40 +11,54 @@ if TYPE_CHECKING: ...@@ -11,40 +11,54 @@ if TYPE_CHECKING:
from starlette.datastructures import State from starlette.datastructures import State
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
SupportedTask = object
def register_pooling_api_routers(app: FastAPI): def register_pooling_api_routers(
from vllm.entrypoints.pooling.classify.api_router import router as classify_router app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
from vllm.entrypoints.pooling.embed.api_router import router as embed_router ):
from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(classify_router)
app.include_router(embed_router)
app.include_router(score_router)
app.include_router(pooling_router) app.include_router(pooling_router)
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.api_router import (
router as classify_router,
)
app.include_router(classify_router)
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.api_router import router as embed_router
app.include_router(embed_router)
async def init_pooling_state( if "score" in supported_tasks or "embed" in supported_tasks:
engine_client: "EngineClient", state: "State", args: "Namespace" from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router)
def init_pooling_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
): ):
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
supported_tasks = await engine_client.get_supported_tasks()
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
state.openai_serving_pooling = ( state.openai_serving_pooling = (
( (
OpenAIServingPooling( OpenAIServingPooling(
......
...@@ -7,45 +7,15 @@ from typing import Any ...@@ -7,45 +7,15 @@ from typing import Any
import model_hosting_container_standards.sagemaker as sagemaker_standards import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic import pydantic
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import base
base, from vllm.entrypoints.openai.engine.protocol import ErrorResponse
)
from vllm.entrypoints.openai.chat_completion.api_router import (
chat,
create_chat_completion,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.api_router import (
completion,
create_completion,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.score.api_router import (
create_score,
do_rerank,
rerank,
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
from vllm.entrypoints.serve.instrumentator.health import health from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13) # (requires typing_extensions >= 4.13)
...@@ -53,25 +23,89 @@ RequestType = Any ...@@ -53,25 +23,89 @@ RequestType = Any
GetHandlerFn = Callable[[Request], OpenAIServing | None] GetHandlerFn = Callable[[Request], OpenAIServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
(ChatCompletionRequest, (chat, create_chat_completion)), # NOTE: Items defined earlier take higher priority
(CompletionRequest, (completion, create_completion)), INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
(EmbeddingRequest, (embedding, create_embedding)),
(ClassificationRequest, (classify, create_classify)), if "generate" in supported_tasks:
(ScoreRequest, (score, create_score)), from vllm.entrypoints.openai.chat_completion.api_router import (
(RerankRequest, (rerank, do_rerank)), chat,
(PoolingRequest, (pooling, create_pooling)), create_chat_completion,
] )
from vllm.entrypoints.openai.chat_completion.protocol import (
# NOTE: Construct the TypeAdapters only once ChatCompletionRequest,
INVOCATION_VALIDATORS = [ )
(pydantic.TypeAdapter(request_type), (get_handler, endpoint)) from vllm.entrypoints.openai.completion.api_router import (
for request_type, (get_handler, endpoint) in INVOCATION_TYPES completion,
] create_completion,
)
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
def register_sagemaker_routes(router: APIRouter):
INVOCATION_TYPES += [
(ChatCompletionRequest, (chat, create_chat_completion)),
(CompletionRequest, (completion, create_completion)),
]
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.api_router import (
create_embedding,
embedding,
)
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
INVOCATION_TYPES += [
(EmbeddingRequest, (embedding, create_embedding)),
]
if "classify" in supported_tasks:
from vllm.entrypoints.pooling.classify.api_router import (
classify,
create_classify,
)
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
INVOCATION_TYPES += [
(ClassificationRequest, (classify, create_classify)),
]
if "score" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.score.protocol import RerankRequest
INVOCATION_TYPES += [
(RerankRequest, (rerank, do_rerank)),
]
if "score" in supported_tasks or "embed" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import create_score, score
from vllm.entrypoints.pooling.score.protocol import ScoreRequest
INVOCATION_TYPES += [
(ScoreRequest, (score, create_score)),
]
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
INVOCATION_TYPES += [
(PoolingRequest, (pooling, create_pooling)),
]
return INVOCATION_TYPES
def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
router = APIRouter()
# NOTE: Construct the TypeAdapters only once
INVOCATION_TYPES = get_invocation_types(supported_tasks)
INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
]
@router.post("/ping", response_class=Response) @router.post("/ping", response_class=Response)
@router.get("/ping", response_class=Response) @router.get("/ping", response_class=Response)
@sagemaker_standards.register_ping_handler @sagemaker_standards.register_ping_handler
...@@ -123,4 +157,4 @@ def register_sagemaker_routes(router: APIRouter): ...@@ -123,4 +157,4 @@ def register_sagemaker_routes(router: APIRouter):
res = base(raw_request).create_error_response(message=msg) res = base(raw_request).create_error_response(message=msg)
return JSONResponse(content=res.model_dump(), status_code=res.error.code) return JSONResponse(content=res.model_dump(), status_code=res.error.code)
return router app.include_router(router)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment