# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from typing import TypeAlias

from pydantic import Field

from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
    ChatRequestMixin,
    ClassifyRequestMixin,
    CompletionRequestMixin,
    PoolingBasicRequestMixin,
)
from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid

# Module-level logger (not referenced by the definitions in this module as
# shown; kept for callers/extensions that may log from here).
logger = init_logger(__name__)


class ClassificationCompletionRequest(
    PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
    # Classification request with completion-style (prompt) input. The request
    # fields read below (truncate_prompt_tokens, truncation_side,
    # add_special_tokens, use_activation) are declared on the mixins, which
    # are defined outside this module.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the tokenization parameters for this request.

        Args:
            model_config: Model configuration. Supplies ``max_model_len`` and
                the optional ``encoder_config`` mapping.

        Returns:
            A ``TokenizeParams`` whose total-token budget is the model's
            ``max_model_len``, with no budget reserved for output tokens.
        """
        # encoder_config may be None/empty; treat that as "no overrides".
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            # This is a pooling ("classify") request that generates no
            # tokens, so the whole budget goes to the prompt.
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            truncation_side=self.truncation_side,
            # Lower-casing is opt-in via the encoder config; default off.
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            # Presumably the setting name reported when the limit is
            # exceeded — confirm against TokenizeParams.
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self) -> PoolingParams:
        """Return ``PoolingParams`` for the ``"classify"`` pooling task.

        ``use_activation`` is forwarded unchanged from the request.
        """
        return PoolingParams(
            task="classify",
            use_activation=self.use_activation,
        )


class ClassificationChatRequest(
    PoolingBasicRequestMixin, ChatRequestMixin, ClassifyRequestMixin
):
    # Classification request with chat-style input (via ChatRequestMixin).
    # The request fields read below (truncate_prompt_tokens, truncation_side,
    # add_special_tokens, use_activation) are declared on the mixins, which
    # are defined outside this module.

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        """Build the tokenization parameters for this request.

        Args:
            model_config: Model configuration. Supplies ``max_model_len`` and
                the optional ``encoder_config`` mapping.

        Returns:
            A ``TokenizeParams`` whose total-token budget is the model's
            ``max_model_len``, with no budget reserved for output tokens.
        """
        # encoder_config may be None/empty; treat that as "no overrides".
        encoder_config = model_config.encoder_config or {}

        return TokenizeParams(
            max_total_tokens=model_config.max_model_len,
            # This is a pooling ("classify") request that generates no
            # tokens, so the whole budget goes to the prompt.
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            truncation_side=self.truncation_side,
            # Lower-casing is opt-in via the encoder config; default off.
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            # Presumably the setting name reported when the limit is
            # exceeded — confirm against TokenizeParams.
            max_total_tokens_param="max_model_len",
        )

    def to_pooling_params(self) -> PoolingParams:
        """Return ``PoolingParams`` for the ``"classify"`` pooling task.

        ``use_activation`` is forwarded unchanged from the request.
        """
        return PoolingParams(
            task="classify",
            use_activation=self.use_activation,
        )


# Either request shape accepted for classification: completion-style input or
# chat-style input. Member order is kept as-is; pydantic may consult it when
# validating the union.
ClassificationRequest: TypeAlias = (
    ClassificationCompletionRequest | ClassificationChatRequest
)


# A single classification result; ClassificationResponse.data is a list of
# these. Field order is also the serialized field order.
class ClassificationData(OpenAIBaseModel):
    # Presumably the position of the input this result belongs to — confirm.
    index: int
    # Class label; may be None (presumably when no label name is available —
    # confirm against the producer).
    label: str | None
    # Presumably one score/probability per class — confirm.
    probs: list[float]
    # Number of classes; presumably equals len(probs) — confirm.
    num_classes: int


# Top-level classification response: a list of ClassificationData results plus
# token usage. Field order is also the serialized field order.
class ClassificationResponse(OpenAIBaseModel):
    # Fresh id per instance, formatted as "classify-<random_uuid()>".
    id: str = Field(default_factory=lambda: f"classify-{random_uuid()}")
    # Fixed object-type tag (presumably following the OpenAI "list" convention).
    object: str = "list"
    # Creation time: time.time() truncated to whole seconds since the epoch.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: list[ClassificationData]
    usage: UsageInfo