protocol.py 1.57 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
4
from typing import Any, TypeAlias
5
6
7
8
9

from pydantic import (
    Field,
)

10
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
11
from vllm.entrypoints.pooling.base.protocol import (
12
    ChatRequestMixin,
13
    CompletionRequestMixin,
14
    EmbedRequestMixin,
15
16
    PoolingBasicRequestMixin,
)
17
18
19
from vllm.utils import random_uuid


20
21
22
class EmbeddingCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
    # Embedding request variant composed entirely of the mixins listed in
    # its bases; it declares no fields of its own, hence the bare `pass`.
    #
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    pass
26

27

28
29
30
class EmbeddingChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, EmbedRequestMixin
):
    # Embedding request variant built on ChatRequestMixin. Apart from the
    # fields inherited from its bases, it adds one optional field.

    # Optional keyword arguments; None (the default) when the caller
    # supplies nothing. The description string below is what ends up in
    # the generated schema.
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )


# Either of the two embedding request models defined in this module.
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


class EmbeddingResponseData(OpenAIBaseModel):
    # One element of the `data` list carried by EmbeddingResponse.

    # Required integer; presumably the position of the matching input in
    # the request -- confirm against the code that builds the response.
    index: int
    # Type tag for this entry; defaults to the literal "embedding".
    object: str = "embedding"
    # Either a list of floats or a single string. NOTE(review): the string
    # form is presumably an encoded vector (e.g. base64) -- confirm.
    embedding: list[float] | str


class EmbeddingResponse(OpenAIBaseModel):
    # Top-level embeddings response: identification metadata, the list of
    # per-item results, and usage accounting.

    # Generated per instance when not supplied: "embd-" followed by the
    # value returned by random_uuid().
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # Type tag for the envelope; defaults to the literal "list".
    object: str = "list"
    # Whole-second Unix timestamp captured when the instance is built
    # (time.time() truncated by int()), unless the caller provides one.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Required string; no default.
    model: str
    # Required list of EmbeddingResponseData entries.
    data: list[EmbeddingResponseData]
    # Required UsageInfo instance.
    usage: UsageInfo


class EmbeddingBytesResponse(OpenAIBaseModel):
    # Response container whose payload is raw bytes rather than the typed
    # `data` list used by EmbeddingResponse.

    # Required list of byte strings making up the payload.
    content: list[bytes]
    # Optional string-to-string mapping; None when no headers are set.
    headers: dict[str, str] | None = None
    # Media type label for the payload; defaults to generic binary.
    media_type: str = "application/octet-stream"