__init__.py 4.11 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from typing import TYPE_CHECKING

6
7
from fastapi import FastAPI

8
9
10
11
12
13
if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient
14
15
16
17
18
    from vllm.entrypoints.logger import RequestLogger
    from vllm.tasks import SupportedTask
else:
    RequestLogger = object
    SupportedTask = object
19

20

21
22
23
def register_pooling_api_routers(
    app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
):
    """Mount the pooling-related API routers on ``app``.

    The generic pooling router is always included. The classify and embed
    routers are included only when their task name appears in
    ``supported_tasks``. The score router is included when any of
    ``"score"``, ``"embed"`` or ``"token_embed"`` appears there.

    Each router module is imported right before it is mounted, so modules
    for routers that are not mounted are never imported by this function.

    Args:
        app: Application the routers are attached to.
        supported_tasks: Task names used to decide which routers to mount.
    """
    from vllm.entrypoints.pooling.pooling.api_router import (
        router as base_pooling_router,
    )

    app.include_router(base_pooling_router)

    has_classify = "classify" in supported_tasks
    has_embed = "embed" in supported_tasks
    # Score API handles score/rerank for:
    # - "score" task (score_type: cross-encoder models)
    # - "embed" task (score_type: bi-encoder models)
    # - "token_embed" task (score_type: late interaction models)
    has_score = (
        "score" in supported_tasks or has_embed or "token_embed" in supported_tasks
    )

    if has_classify:
        from vllm.entrypoints.pooling.classify.api_router import (
            router as classification_router,
        )

        app.include_router(classification_router)

    if has_embed:
        from vllm.entrypoints.pooling.embed.api_router import (
            router as embedding_router,
        )

        app.include_router(embedding_router)

    if has_score:
        from vllm.entrypoints.pooling.score.api_router import (
            router as scoring_router,
        )

        app.include_router(scoring_router)


def init_pooling_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    """Populate ``state`` with the pooling serving handlers.

    Four attributes are always assigned on ``state``: ``serving_pooling``,
    ``serving_embedding``, ``serving_classification`` and ``serving_scores``.
    Each one is either a freshly constructed handler or ``None`` when the
    tasks it serves are absent from ``supported_tasks``.

    Args:
        engine_client: Passed as the first positional argument to every
            handler that gets constructed.
        state: Object that receives the handlers; its
            ``openai_serving_models`` (and, for the pooling handler,
            ``openai_serving_render``) attributes are read when a handler
            is built.
        args: Parsed options; ``chat_template`` is always read, the other
            options only when a handler that needs them is built.
        request_logger: Forwarded to every constructed handler.
        supported_tasks: Task names used to decide which handlers to build.
    """
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.pooling.classify.serving import ServingClassification
    from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.score.serving import ServingScores
    from vllm.tasks import POOLING_TASKS

    chat_template = load_chat_template(args.chat_template)

    def chat_options():
        # Built on demand so that ``args`` is only consulted when a handler
        # is actually being constructed.
        return dict(
            request_logger=request_logger,
            chat_template=chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
        )

    # Generic pooling handler: enabled by any pooling task.
    if any(task in supported_tasks for task in POOLING_TASKS):
        state.serving_pooling = OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            state.openai_serving_render,
            **chat_options(),
        )
    else:
        state.serving_pooling = None

    if "embed" in supported_tasks:
        state.serving_embedding = ServingEmbedding(
            engine_client,
            state.openai_serving_models,
            **chat_options(),
        )
    else:
        state.serving_embedding = None

    if "classify" in supported_tasks:
        state.serving_classification = ServingClassification(
            engine_client,
            state.openai_serving_models,
            **chat_options(),
        )
    else:
        state.serving_classification = None

    # Score API handles score/rerank for:
    # - "score" task (score_type: cross-encoder models)
    # - "embed" task (score_type: bi-encoder models)
    # - "token_embed" task (score_type: late interaction models)
    score_capable_tasks = ("embed", "score", "token_embed")
    if any(task in supported_tasks for task in score_capable_tasks):
        state.serving_scores = ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=chat_template,
            log_error_stack=args.log_error_stack,
        )
    else:
        state.serving_scores = None