# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from typing import Any

import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response

13
from vllm.config import ModelConfig
14
15
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
16
from vllm.entrypoints.openai.generate.factories import get_generate_invocation_types
17
from vllm.entrypoints.openai.utils import validate_json_request
18
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
19
from vllm.entrypoints.pooling.factories import get_pooling_invocation_types
20
from vllm.entrypoints.serve.instrumentator.basic import base
21
from vllm.entrypoints.serve.instrumentator.health import health
22
from vllm.tasks import SupportedTask
23
24
25
26

# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
# Callable that, given the raw request, returns the serving handler to use,
# or None when no such handler is available for that request.
GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServingBase | None]
# Async endpoint called with the validated request object and the raw request.
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]


31
32
33
34
35
def attach_router(
    app: FastAPI,
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> None:
    """Register the SageMaker ``/ping`` and ``/invocations`` routes on ``app``.

    Args:
        app: Application that receives the routes via ``include_router``.
        supported_tasks: Tasks forwarded to the generate and pooling
            invocation-type factories.
        model_config: Optional model configuration forwarded to the same
            factories.
    """
    router = APIRouter()

    INVOCATION_TYPES = get_generate_invocation_types(
        supported_tasks, model_config
    ) + get_pooling_invocation_types(supported_tasks, model_config)

    # NOTE: Construct the TypeAdapters only once
    # The request type is stored next to its adapter so the "no suitable
    # handler" message can name it without reading TypeAdapter internals.
    INVOCATION_VALIDATORS = [
        (request_type, pydantic.TypeAdapter(request_type), get_handler, endpoint)
        for request_type, (get_handler, endpoint) in INVOCATION_TYPES
    ]

    @router.post("/ping", response_class=Response)
    @router.get("/ping", response_class=Response)
    @sagemaker_standards.register_ping_handler
    async def ping(raw_request: Request) -> Response:
        """Ping check. Endpoint required for SageMaker"""
        return await health(raw_request)

    @router.post(
        "/invocations",
        dependencies=[Depends(validate_json_request)],
        responses={
            HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
            HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        },
    )
    @sagemaker_standards.register_invocation_handler
    @sagemaker_standards.stateful_session_manager()
    @sagemaker_standards.inject_adapter_id(adapter_path="model")
    async def invocations(raw_request: Request):
        """For SageMaker, routes requests based on the request type."""
        try:
            body = await raw_request.json()
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # A body that is not decodable text raises UnicodeDecodeError
            # rather than JSONDecodeError; both are client errors (400).
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e

        # Only consider request types whose handler is available for this
        # request (get_handler returns None otherwise).
        valid_endpoints = [
            (request_type, validator, endpoint)
            for request_type, validator, get_handler, endpoint in INVOCATION_VALIDATORS
            if get_handler(raw_request) is not None
        ]

        # The first request type that validates the body wins; order follows
        # INVOCATION_TYPES (generate types before pooling types).
        for _, request_validator, endpoint in valid_endpoints:
            try:
                request = request_validator.validate_python(body)
            except pydantic.ValidationError:
                continue

            return await endpoint(request, raw_request)

        # Nothing matched: report the request types that were tried.
        type_names = [
            t.__name__ if isinstance(t, type) else str(t)
            for t, _, _ in valid_endpoints
        ]
        msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
        res = base(raw_request).create_error_response(message=msg)
        return JSONResponse(content=res.model_dump(), status_code=res.error.code)

    app.include_router(router)
100
101
102
103


def sagemaker_standards_bootstrap(app: FastAPI) -> FastAPI:
    """Run the SageMaker container-standards bootstrap on ``app``.

    Delegates to ``sagemaker_standards.bootstrap`` and returns whatever it
    returns.
    """
    bootstrapped_app = sagemaker_standards.bootstrap(app)
    return bootstrapped_app