# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from typing import TYPE_CHECKING

6
7
from fastapi import FastAPI

8
9
10
11
12
13
14
if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient

15
16
17
18
19
20
21
22
23
24
25

def register_pooling_api_routers(app: FastAPI):
    """Attach the pooling endpoint routers to the FastAPI application.

    The router modules are imported at call time rather than at module
    import time. Routers are included in a fixed order: classify, embed,
    score, then pooling.

    Args:
        app: The application that receives the routers.
    """
    from vllm.entrypoints.pooling.classify.api_router import (
        router as classify_router,
    )
    from vllm.entrypoints.pooling.embed.api_router import (
        router as embed_router,
    )
    from vllm.entrypoints.pooling.pooling.api_router import (
        router as pooling_router,
    )
    from vllm.entrypoints.pooling.score.api_router import (
        router as score_router,
    )

    # Inclusion order is kept stable; it determines route registration order.
    ordered_routers = (classify_router, embed_router, score_router, pooling_router)
    for api_router in ordered_routers:
        app.include_router(api_router)
26
27
28
29
30


async def init_pooling_state(
    engine_client: "EngineClient", state: "State", args: "Namespace"
):
    """Build the pooling-family serving handlers and store them on ``state``.

    The engine is asked for its supported tasks. Each of the following
    attributes is then set to a handler instance, or to ``None`` when its
    condition is not met:

    - ``state.openai_serving_pooling``: any supported task is in
      ``POOLING_TASKS``.
    - ``state.openai_serving_embedding``: ``"embed"`` is supported.
    - ``state.openai_serving_classification``: ``"classify"`` is supported.
    - ``state.openai_serving_scores``: ``"embed"`` or ``"score"`` is
      supported.

    Args:
        engine_client: Engine client used to query supported tasks; it is
            also passed to every handler that gets constructed.
        state: Application state. ``state.openai_serving_models`` is read
            (it must already be set) and the four attributes above are
            written.
        args: Parsed server arguments. Reads ``chat_template``,
            ``chat_template_content_format``, ``trust_request_chat_template``,
            ``log_error_stack``, ``enable_log_requests`` and ``max_log_len``.
    """
    # Function-scope imports; NOTE(review): presumably to keep module import
    # light / avoid import cycles -- confirm.
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.logger import RequestLogger
    from vllm.entrypoints.pooling.classify.serving import ServingClassification
    from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.score.serving import ServingScores
    from vllm.tasks import POOLING_TASKS

    supported_tasks = await engine_client.get_supported_tasks()

    resolved_chat_template = load_chat_template(args.chat_template)

    # Request logging is opt-in; when disabled, handlers receive None.
    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    # The generic pooling endpoint is served if the engine supports any
    # pooling task; it is the only handler that is told the task list.
    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    # Scoring is enabled for either embedding or score-capable models. The
    # resolved chat template is passed in as the score template.
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks or "score" in supported_tasks
        else None
    )