typing.py 3.22 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import time
4
from collections.abc import AsyncGenerator, Sequence
5
6
from dataclasses import dataclass, field
from typing import Any, Generic, TypeAlias, TypeVar
7

8
9
from fastapi import Request
from pydantic import ConfigDict
10

11
from vllm import PoolingParams, PoolingRequestOutput, PromptType
12
13
14
15
from vllm.inputs import DataPrompt, EngineInput
from vllm.lora.request import LoRARequest

from .classify.protocol import (
16
17
18
19
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
20
from .embed.protocol import (
21
    CohereEmbedRequest,
22
23
24
25
26
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
27
from .pooling.protocol import (
28
    IOProcessorRequest,
29
    PoolingBytesResponse,
30
31
32
33
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingResponse,
)
34
35
from .scoring.protocol import ScoringRequest, ScoringResponse
from .scoring.typing import ScoringData
36
37
38
39
40
41
42
43
44
45
46
47

# Union of the completion-style request protocols (embed / classify / pooling).
PoolingCompletionLikeRequest: TypeAlias = (
    EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | PoolingCompletionRequest
)

# Union of the chat-style request protocols (embed / classify / pooling).
PoolingChatLikeRequest: TypeAlias = (
    EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)

# Union of every request type; used as the bound of ``PoolingRequestT`` below.
AnyPoolingRequest: TypeAlias = (
    PoolingCompletionLikeRequest
    | PoolingChatLikeRequest
    | IOProcessorRequest
    | ScoringRequest
    | CohereEmbedRequest
)

# Union of every response protocol type imported above (plain and *Bytes
# variants).
AnyPoolingResponse: TypeAlias = (
    ClassificationResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | PoolingResponse
    | PoolingBytesResponse
    | ScoringResponse
)

# Type parameter for PoolingServeContext; restricted to members of
# AnyPoolingRequest so a context is always parameterised by a known request.
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)


@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Mutable per-request state for serving one pooling request.

    Generic over the concrete request type (``PoolingRequestT``, bounded by
    ``AnyPoolingRequest``). Declared with ``kw_only=True`` so that required
    fields (``model_name``, ``request_id``, ``pooling_params``) may follow
    fields that have defaults; every field must be passed by keyword.
    """

    # NOTE(review): this is a pydantic ``ConfigDict`` on a *stdlib* dataclass.
    # It is unannotated, so ``dataclass`` does not treat it as a field, and
    # stdlib dataclasses never read it -- it appears to be inert. Kept as a
    # plain class attribute for backward compatibility; confirm nothing reads
    # ``PoolingServeContext.model_config`` before removing it.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    request: PoolingRequestT
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Either a single PoolingParams or a list of them; presumably one per
    # prompt in the list case -- confirm against the caller.
    pooling_params: PoolingParams | list[PoolingParams]

    # Integer seconds since the epoch, captured when the context is built.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None
    engine_inputs: Sequence[EngineInput] | None = None
    prompt_request_ids: list[str] | None = None

    # Async stream of (int, PoolingRequestOutput) pairs; presumably the int is
    # the prompt index -- confirm against the producer.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # default_factory gives each instance its own list (no shared mutable default).
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    ## for Long Text Embedding with Chunked Processing
    original_engine_inputs: Sequence[EngineInput] | None = None

    ## for bi-encoder & late-interaction
    n_queries: int | None = None

    ## for IOProcessorResponse
    response: Any | None = None

    ## for flash-late-interaction
    query_final_res_batch: list[PoolingRequestOutput] | None = None

@dataclass
class OfflineInputsContext:
    """Inputs bundled for the offline pooling path (consumer not shown here).

    ``prompts`` and ``pooling_params`` have no defaults, so they are required
    and may be given positionally in that order; the remaining fields are
    optional.
    """

    prompts: PromptType | Sequence[PromptType] | DataPrompt | ScoringData
    pooling_params: PoolingParams | Sequence[PoolingParams]
    # Optional tokenizer keyword arguments; None when not supplied.
    tokenization_kwargs: dict[str, Any] | None = None
    chat_template: str | None = None

    ## for bi-encoder & late-interaction
    n_queries: int | None = None


@dataclass
class OfflineOutputsContext:
    """Outputs collected on the offline pooling path.

    ``outputs`` is required; ``n_queries`` is optional and defaults to None.
    """

    outputs: list[PoolingRequestOutput]

    ## for bi-encoder & late-interaction
    n_queries: int | None = None