protocol.py 1.39 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
5
from typing import Any, TypeAlias
6
7
8
9
10

from pydantic import (
    Field,
)

11
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
12
from vllm.entrypoints.pooling.base.protocol import (
13
    ChatRequestMixin,
14
    ClassifyRequestMixin,
15
16
17
    CompletionRequestMixin,
    PoolingBasicRequestMixin,
)
18
19
20
from vllm.utils import random_uuid


21
22
23
24
class ClassificationCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
    # Completion-style classification request. Every field is inherited from
    # the three mixins (in this MRO order); this class declares nothing new.
    # Kept as a comment, not a docstring, so the generated schema is unchanged.
    ...
25
26


27
28
29
class ClassificationChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
):
    # Chat-style classification request. Apart from the field below, its
    # fields come from the three mixins (in this MRO order).
    # The snippet markers delimit the section pulled into the docs; the
    # [start:...] marker needs a matching [end:...] marker to close it.
    # --8<-- [start:chat-classification-extra-params]
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=("Additional kwargs to pass to the HF processor."),
    )
    # --8<-- [end:chat-classification-extra-params]


# Either classification request model: the completion-style one or the
# chat-style one (operand order kept as-is; union order can matter).
ClassificationRequest: TypeAlias = (
    ClassificationCompletionRequest
    | ClassificationChatRequest
)


class ClassificationData(OpenAIBaseModel):
    # One entry of ClassificationResponse.data.
    # Documented with comments rather than a docstring because pydantic copies
    # a class docstring into the generated JSON-schema description.
    # Field declaration order is the serialized key order; keep it stable.

    # NOTE(review): presumably the position of the input this result
    # belongs to — confirm against the serving code.
    index: int
    # Required but nullable (no default is given). Presumably the predicted
    # class name, None when no label mapping exists — TODO confirm.
    label: str | None
    # Presumably per-class probabilities — confirm against the producer.
    probs: list[float]
    # Number of classes; presumably equals len(probs) — TODO confirm.
    num_classes: int


class ClassificationResponse(OpenAIBaseModel):
    # Response model holding a list of ClassificationData results plus usage
    # info. Comments, not a docstring, so the JSON schema is unaffected.
    # Field declaration order is the serialized key order; keep it stable.

    # "classify-" + random_uuid(). default_factory runs per instance, so each
    # response gets a fresh id (not one fixed at class-definition time).
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    # Constant object tag; presumably mirrors the OpenAI "list" object type.
    object: str = "list"
    # Creation time as whole seconds since the Unix epoch, taken per instance.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Required; presumably the name of the model that served the request.
    model: str
    data: list[ClassificationData]
    usage: UsageInfo