# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse

from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationRequest,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.utils import (
    create_error_response,
    load_aware_call,
    with_cancellation,
)

# Router holding the classification endpoint defined below; presumably
# included into the main FastAPI app elsewhere (not shown in this file).
router = APIRouter()


def classify(request: Request) -> ServingClassification | None:
    """Fetch the classification serving handler stored on the application.

    Args:
        request: The incoming request. Its ``app.state`` is read for the
            ``openai_serving_classification`` attribute.

    Returns:
        The object stored at ``app.state.openai_serving_classification``.
        Callers treat ``None`` as "classification is not supported".
    """
    app_state = request.app.state
    return app_state.openai_serving_classification


@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(
    request: ClassificationRequest, raw_request: Request
) -> JSONResponse:
    """Handle ``POST /classify`` by delegating to the classification handler.

    Args:
        request: The parsed classification request body.
        raw_request: The raw incoming request. It is used to look up the
            serving handler on ``app.state`` and is forwarded to that handler.

    Returns:
        Whatever the serving handler returns (presumably a response object;
        its exact type is not visible here). If no classification handler
        is configured, returns a ``JSONResponse`` carrying the error payload
        and the status code from ``create_error_response``.
    """
    handler = classify(raw_request)
    if handler is None:
        # No handler was registered for this model. Report a structured
        # error instead of falling through and attempting to call None.
        error_response = create_error_response(
            message="The model does not support Classification API"
        )
        return JSONResponse(
            content=error_response.model_dump(),
            status_code=error_response.error.code,
        )

    return await handler(request, raw_request)