typing.py 3.09 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import time
4
from collections.abc import AsyncGenerator, Sequence
5
6
from dataclasses import dataclass, field
from typing import Any, Generic, TypeAlias, TypeVar
7

8
9
from fastapi import Request
from pydantic import ConfigDict
10

11
from vllm import PoolingParams, PoolingRequestOutput, PromptType
12
13
14
15
16
17
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
18
    CohereEmbedRequest,
19
20
21
22
23
24
25
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
26
    PoolingBytesResponse,
27
28
29
30
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingResponse,
)
31
32
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
from vllm.entrypoints.pooling.scoring.typing import ScoringData
33
from vllm.inputs import EngineInput
34
from vllm.lora.request import LoRARequest
35
36
37
38
39
40
41
42
43
44
45
46

# Union of the completion-style request models imported from the embed,
# classify and pooling protocol modules.
PoolingCompletionLikeRequest: TypeAlias = (
    EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | PoolingCompletionRequest
)

# Union of the chat-style request models imported from the embed, classify
# and pooling protocol modules.
PoolingChatLikeRequest: TypeAlias = (
    EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)

# Union of every request model handled in this module: the completion-like
# and chat-like aliases plus the IO-processor, scoring and Cohere embed
# requests. Member order is kept exactly as originally declared.
AnyPoolingRequest: TypeAlias = (
    PoolingCompletionLikeRequest
    | PoolingChatLikeRequest
    | IOProcessorRequest
    | ScoringRequest
    | CohereEmbedRequest
)

# Union of every response model imported from the classify, embed, pooling
# and scoring protocol modules. Member order is kept exactly as originally
# declared.
AnyPoolingResponse: TypeAlias = (
    ClassificationResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | PoolingResponse
    | PoolingBytesResponse
    | ScoringResponse
)
62
63
64
65
66
67
68
69
70
71
72
73

# Type variable constrained (via ``bound``) to the request models covered by
# ``AnyPoolingRequest``; used to parameterize generic classes in this module.
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)


@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Keyword-only bundle of state carried alongside one pooling request.

    Only ``request``, ``model_name`` and ``request_id`` must be supplied at
    construction. Every other field has a default (``None``, an empty list,
    or the construction timestamp). The dataclass is not frozen, so fields
    can be assigned after construction.
    """

    # The request object; its concrete type is the class's generic parameter.
    request: PoolingRequestT
    # The FastAPI request object, when one is available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Whole seconds since the epoch, captured when the instance is created.
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    # Either one PoolingParams or a list of them.
    # NOTE(review): the list form is presumably one entry per prompt -- confirm.
    pooling_params: PoolingParams | list[PoolingParams] | None = None
    engine_inputs: Sequence[EngineInput] | None = None
    prompt_request_ids: list[str] | None = None
    intermediates: Any | None = None

    # Async stream of (int, PoolingRequestOutput) pairs.
    # NOTE(review): the int looks like a prompt index -- confirm with callers.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # default_factory gives each instance its own list.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # No annotation, so ``@dataclass`` does not turn this into a field; it is
    # a plain class attribute shared by all instances.
    # NOTE(review): presumably kept for pydantic interop -- confirm it is needed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # for bi-encoder & late-interaction
    n_queries: int | None = None

89
90
91
92
93
94
95
96
97

@dataclass
class OfflineInputsContext:
    """Prompts plus optional per-call settings.

    ``prompts`` is the only required field and may be a single prompt, a
    sequence of prompts, or a ``ScoringData`` object. All other fields
    default to ``None``. Field order is part of the positional constructor
    signature and must not be rearranged.
    """

    prompts: PromptType | Sequence[PromptType] | ScoringData
    # Either one PoolingParams or a list of them.
    pooling_params: PoolingParams | list[PoolingParams] | None = None
    tokenization_kwargs: dict[str, Any] | None = None
    chat_template: str | None = None

    # for bi-encoder & late-interaction
    n_queries: int | None = None
99
100
101
102
103
104
105


@dataclass
class OfflineOutputsContext:
    """A list of pooling outputs plus an optional query count.

    ``outputs`` is required; ``n_queries`` defaults to ``None``.
    """

    outputs: list[PoolingRequestOutput]

    # for bi-encoder & late-interaction
    n_queries: int | None = None