# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from typing import TYPE_CHECKING

6
7
from fastapi import FastAPI

8
9
10
from vllm.config import ModelConfig
from vllm.logger import init_logger

11
12
13
14
15
16
if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient
    from vllm.entrypoints.logger import RequestLogger
    from vllm.tasks import SupportedTask
else:
    # Runtime placeholders: some annotations in this module reference these
    # names unquoted (e.g. `request_logger: RequestLogger | None`), so they
    # must resolve to *something* when the annotation is evaluated at
    # function-definition time, without importing the real modules.
    RequestLogger = object
    SupportedTask = object

logger = init_logger(__name__)


def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the score endpoint should be served.

    Scoring is enabled whenever an embedding task ("embed" or
    "token_embed") is supported. Otherwise a "classify" model qualifies
    only when ``model_config`` is provided and its HF config reports
    exactly one label (a missing ``num_labels`` attribute counts as 0).

    Args:
        supported_tasks: Tasks the loaded model supports.
        model_config: Model configuration used to inspect ``num_labels``.

    Returns:
        True if the score API should be enabled, False otherwise.
    """
    has_embedding_task = "embed" in supported_tasks or "token_embed" in supported_tasks
    if has_embedding_task:
        return True

    # Without a config, or without classification, there is nothing to score.
    if model_config is None or "classify" not in supported_tasks:
        return False

    label_count = getattr(model_config.hf_config, "num_labels", 0)
    if label_count == 1:
        return True

    logger.debug_once("Score API is only enabled for num_labels == 1.")
    return False

42

43
def register_pooling_api_routers(
    app: FastAPI,
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> None:
    """Mount the pooling-family API routers on ``app``.

    Each router module is imported lazily inside the branch that needs it,
    so routers for unsupported tasks are never imported or registered.

    Args:
        app: FastAPI application to register routers on.
        supported_tasks: Tasks the loaded model supports.
        model_config: Model configuration. When ``None``, no router is
            registered at all.
    """
    if model_config is None:
        return

    # The generic pooling router is mounted only when the model config can
    # resolve a pooling task from the supported tasks.
    pooling_task = model_config.get_pooling_task(supported_tasks)
    if pooling_task is not None:
        from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router

        app.include_router(pooling_router)

    if "classify" in supported_tasks:
        from vllm.entrypoints.pooling.classify.api_router import (
            router as classify_router,
        )

        app.include_router(classify_router)

    if "embed" in supported_tasks:
        from vllm.entrypoints.pooling.embed.api_router import router as embed_router

        app.include_router(embed_router)

    # Same gate as the ServingScores handler, so the route and its handler
    # are enabled together.
    if enable_scoring_api(supported_tasks, model_config):
        from vllm.entrypoints.pooling.score.api_router import router as score_router

        app.include_router(score_router)


def init_pooling_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
) -> None:
    """Create the pooling-family serving handlers and attach them to ``state``.

    Sets ``state.serving_pooling``, ``state.serving_embedding``,
    ``state.serving_classification`` and ``state.serving_scores``. Each is
    set to a handler instance when its task is supported and to ``None``
    otherwise.

    ``state.openai_serving_models`` (and ``state.openai_serving_render`` when
    a pooling task is supported) are read here, not created. They must
    already be populated by the caller.

    Args:
        engine_client: Engine client shared by all handlers. Its
            ``model_config`` drives the scoring gate.
        state: Application state object that receives the handlers.
        args: Parsed CLI arguments. Reads ``chat_template``,
            ``chat_template_content_format``, ``trust_request_chat_template``
            and ``log_error_stack``.
        request_logger: Optional request logger passed to every handler.
        supported_tasks: Tasks the loaded model supports.
    """
    # Imported lazily to avoid pulling in serving modules at package import.
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.pooling.classify.serving import ServingClassification
    from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.score.serving import ServingScores
    from vllm.tasks import POOLING_TASKS

    model_config = engine_client.model_config

    resolved_chat_template = load_chat_template(args.chat_template)

    state.serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            state.openai_serving_render,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if any(t in supported_tasks for t in POOLING_TASKS)
        else None
    )

    state.serving_embedding = (
        ServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if "embed" in supported_tasks
        else None
    )

    state.serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )
        if "classify" in supported_tasks
        else None
    )

    # The scoring handler reuses the resolved chat template as its score
    # template. It is gated exactly like the score router.
    state.serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if enable_scoring_api(supported_tasks, model_config)
        else None
    )