# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Registration of pooling API routers and their serving handlers."""

from typing import TYPE_CHECKING

from fastapi import FastAPI

from vllm.config import ModelConfig
from vllm.entrypoints.pooling.utils import enable_scoring_api
from vllm.logger import init_logger

if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient
    from vllm.entrypoints.logger import RequestLogger
    from vllm.tasks import SupportedTask
else:
    # Runtime placeholders: un-quoted annotations in this module (e.g.
    # ``RequestLogger | None``) are evaluated when functions are defined,
    # so these names must exist at runtime even though the real types are
    # only imported for type checkers.
    RequestLogger = object
    SupportedTask = object

logger = init_logger(__name__)


def register_pooling_api_routers(
    app: FastAPI,
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> None:
    """Include the pooling-related API routers on ``app``.

    Each router module is imported lazily, and only when the corresponding
    capability is available, so unsupported endpoints are never registered.

    Args:
        app: The FastAPI application to register routers on.
        supported_tasks: Tasks reported as supported by the engine.
        model_config: Model configuration used to decide which pooling
            endpoints apply. When ``None``, nothing is registered.
    """
    if model_config is None:
        return

    # Generic pooling endpoint: registered only when the model config
    # resolves a pooling task from the supported tasks.
    if model_config.get_pooling_task(supported_tasks) is not None:
        from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router

        app.include_router(pooling_router)

    if "classify" in supported_tasks:
        from vllm.entrypoints.pooling.classify.api_router import (
            router as classify_router,
        )

        app.include_router(classify_router)

    if "embed" in supported_tasks:
        from vllm.entrypoints.pooling.embed.api_router import router as embed_router

        app.include_router(embed_router)

    # Scoring needs both the task list and the model config to decide.
    if enable_scoring_api(supported_tasks, model_config):
        from vllm.entrypoints.pooling.scoring.api_router import router as score_router

        app.include_router(score_router)


def init_pooling_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
) -> None:
    """Create the pooling serving handlers and store them on ``state``.

    Sets ``state.serving_pooling``, ``state.serving_embedding``,
    ``state.serving_classification`` and ``state.serving_scores``. Each one
    is ``None`` when its task (or the scoring API) is not enabled.

    Args:
        engine_client: Engine client passed to every handler; its
            ``model_config`` is used to decide whether scoring is enabled.
        state: Application state. Must already provide
            ``openai_serving_models`` and ``openai_serving_render``.
        args: Parsed CLI arguments. Reads ``chat_template``,
            ``chat_template_content_format`` and
            ``trust_request_chat_template``.
        request_logger: Optional request logger forwarded to each handler.
        supported_tasks: Tasks reported as supported by the engine.
    """
    # Imported lazily so the serving modules load only when pooling state
    # is actually initialized.
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.pooling.classify.serving import ServingClassification
    from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.scoring.serving import ServingScores
    from vllm.tasks import POOLING_TASKS

    model_config = engine_client.model_config

    # Resolve the chat template once and share it across all handlers.
    resolved_chat_template = load_chat_template(args.chat_template)

    # NOTE(review): this gate (any task in POOLING_TASKS) differs from the
    # router gate in register_pooling_api_routers
    # (model_config.get_pooling_task(...) is not None). Confirm the two
    # always agree, otherwise a route may exist without a handler.
    state.serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            state.openai_serving_render,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if any(t in supported_tasks for t in POOLING_TASKS)
        else None
    )

    state.serving_embedding = (
        ServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if "embed" in supported_tasks
        else None
    )

    state.serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if "classify" in supported_tasks
        else None
    )

    state.serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if enable_scoring_api(supported_tasks, model_config)
        else None
    )