"vscode:/vscode.git/clone" did not exist on "e89003ddd07a112a152521feb2c80d8bfa3c01fb"
typing.py 2.36 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any, Generic, TypeAlias, TypeVar
7

8
9
from fastapi import Request
from pydantic import ConfigDict
10

11
from vllm import PoolingRequestOutput
12
13
14
15
16
17
from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
18
    CohereEmbedRequest,
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
    IOProcessorRequest,
    PoolingChatRequest,
    PoolingCompletionRequest,
    PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
    RerankRequest,
    ScoreRequest,
    ScoreResponse,
)
35
from vllm.inputs import EngineInput
36
from vllm.lora.request import LoRARequest
37
38
39
40
41
42
43
44
45
46
47
48

# Completion-style request variants of the pooling endpoints (per their class
# names: embedding, classification and generic pooling).
PoolingCompletionLikeRequest: TypeAlias = (
    EmbeddingCompletionRequest
    | ClassificationCompletionRequest
    | PoolingCompletionRequest
)

# Chat-style request variants of the same three pooling endpoints.
PoolingChatLikeRequest: TypeAlias = (
    EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
)

# Union of every pooling request type enumerated in this module: the
# completion- and chat-style variants above plus IO-processor, rerank,
# score and Cohere-style embed requests.
AnyPoolingRequest: TypeAlias = (
    PoolingCompletionLikeRequest
    | PoolingChatLikeRequest
    | IOProcessorRequest
    | RerankRequest
    | ScoreRequest
    | CohereEmbedRequest
)

# Union of the pooling response types enumerated in this module.
AnyPoolingResponse: TypeAlias = (
    ClassificationResponse
    | EmbeddingResponse
    | EmbeddingBytesResponse
    | PoolingResponse
    | ScoreResponse
)

# Type variable used to make generic containers (e.g. PoolingServeContext)
# specific to one pooling request type; bounded so only members of
# AnyPoolingRequest are accepted by type checkers.
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)


@dataclass(kw_only=True)
class PoolingServeContext(Generic[PoolingRequestT]):
    """Mutable per-request state for serving a single pooling request.

    The dataclass is declared with ``kw_only=True``, so every field must be
    passed by keyword. That is also what allows the required ``model_name``
    and ``request_id`` fields to be declared after the defaulted
    ``raw_request`` field.

    Fields that default to ``None`` (or to an empty list) are presumably
    filled in by later stages of the serving pipeline.
    NOTE(review): confirm against the callers.
    """

    # The API request being served; its concrete type is the class parameter.
    request: PoolingRequestT
    # The originating FastAPI request object, when one is available.
    raw_request: Request | None = None
    model_name: str
    request_id: str
    # Creation timestamp as whole Unix seconds, captured when the instance
    # is constructed (default_factory runs per instance).
    created_time: int = field(default_factory=lambda: int(time.time()))
    lora_request: LoRARequest | None = None

    engine_inputs: list[EngineInput] | None = None
    # presumably one engine-level request id per entry of engine_inputs --
    # TODO confirm.
    prompt_request_ids: list[str] | None = None
    # Opaque, stage-specific scratch data; its type is not constrained here.
    intermediates: Any | None = None

    # Async stream of (int, PoolingRequestOutput) pairs; the int is
    # presumably the prompt index used to restore input order -- TODO confirm.
    result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
        None
    )
    # Collected outputs; a fresh empty list per instance via default_factory.
    final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)

    # NOTE(review): this is an unannotated class attribute, so it is not a
    # dataclass field. Pydantic honours `model_config` on BaseModel
    # subclasses, but for stdlib dataclasses it reads `__pydantic_config__`
    # instead, so this setting may have no effect -- confirm before relying
    # on it.
    model_config = ConfigDict(arbitrary_types_allowed=True)