# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import importlib.util
from functools import lru_cache
5
6
from http import HTTPStatus

7
from fastapi import APIRouter, Depends, Request
8
9
10
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never

11
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.embed.protocol import (
    EmbeddingBytesResponse,
    EmbeddingRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.utils import load_aware_call, with_cancellation
20
from vllm.logger import init_logger
21
22
23

router = APIRouter()

logger = init_logger(__name__)


@lru_cache(maxsize=1)
def _get_json_response_cls():
    """Choose the response class used to serialize embedding payloads.

    Prefers ``fastapi.responses.ORJSONResponse`` when the ``orjson`` package
    can be found. Otherwise it logs a one-time hint and falls back to the
    standard ``JSONResponse``. The result is memoized, so the probe runs
    only on the first call.
    """
    has_orjson = importlib.util.find_spec("orjson") is not None
    if not has_orjson:
        # Plain JSONResponse works, just slower for large float arrays.
        logger.warning_once(
            "To make v1/embeddings API fast, please install orjson by `pip install orjson`"
        )
        return JSONResponse

    # Deferred import: only touch ORJSONResponse once orjson is known present.
    from fastapi.responses import ORJSONResponse

    return ORJSONResponse


def embedding(request: Request) -> OpenAIServingEmbedding | None:
    """Return the embedding serving handler registered on the application.

    The handler is read from ``request.app.state.openai_serving_embedding``.
    Route code treats ``None`` as "this model does not serve embeddings".
    """
    app_state = request.app.state
    return app_state.openai_serving_embedding


@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    """Handle ``POST /v1/embeddings`` (OpenAI-compatible embeddings API).

    Args:
        request: The parsed embedding request body.
        raw_request: The incoming FastAPI request; used to reach the serving
            handlers stored on ``raw_request.app.state``.

    Returns:
        - A ``JSONResponse`` wrapping an ``ErrorResponse`` payload, with the
          HTTP status taken from ``error.code``. This happens when the model
          has no embedding handler or when the handler reports an error.
        - A JSON response for an ``EmbeddingResponse``. It is backed by
          orjson when that package is installed.
        - A ``StreamingResponse`` for an ``EmbeddingBytesResponse``.
    """
    handler = embedding(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        error = base_server.create_error_response(
            message="The model does not support Embeddings API"
        )
        # Returning the pydantic model as-is would let FastAPI serialize it
        # with the default 200 status. Wrap it so the client sees the real
        # error code, matching how handler errors are returned below.
        if isinstance(error, ErrorResponse):
            return JSONResponse(
                content=error.model_dump(), status_code=error.error.code
            )
        return error

    generator = await handler.create_embedding(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, EmbeddingResponse):
        return _get_json_response_cls()(content=generator.model_dump())
    elif isinstance(generator, EmbeddingBytesResponse):
        # Binary-encoded embeddings are streamed with handler-provided
        # headers and media type.
        return StreamingResponse(
            content=generator.content,
            headers=generator.headers,
            media_type=generator.media_type,
        )

    # Static exhaustiveness check; raises at runtime on an unexpected type.
    assert_never(generator)