# pooling_params.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Optional
5

6
import msgspec
7

8
from vllm.sampling_params import RequestOutputKind
9
from vllm.tasks import PoolingTask
10

11
12
13
if TYPE_CHECKING:
    from vllm.config import ModelConfig

14
15
16
17
18

class PoolingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """API parameters for pooling models.

    Attributes:
        dimensions: Reduce the dimensions of embeddings
                    if the model supports matryoshka representation.
        output_kind: How outputs are reported for the request. Pooling
                     only supports ``RequestOutputKind.FINAL_ONLY``; any
                     other value is rejected when the instance is created.
        task: The pooling task for this request. If left unset it is
              filled in by ``verify``. Internal use only.
        requires_token_ids: Internal use only.
    """

    dimensions: Optional[int] = None

    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    task: Optional[PoolingTask] = None
    """Internal use only."""

    requires_token_ids: bool = False
    """Internal use only."""

    def clone(self) -> "PoolingParams":
        """Return a new PoolingParams with every field copied from this one.

        All declared fields are passed explicitly (including
        ``output_kind``) so the copy stays faithful to the original.
        """
        return PoolingParams(
            dimensions=self.dimensions,
            output_kind=self.output_kind,
            task=self.task,
            requires_token_ids=self.requires_token_ids,
        )

    def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None:
        """Validate these params against ``task`` and ``model_config``.

        If ``self.task`` has not been set yet, it is set to ``task``.

        Args:
            task: The pooling task the request is being served with.
            model_config: Configuration of the serving model. Its
                ``is_matryoshka``, ``matryoshka_dimensions`` and
                ``served_model_name`` attributes are read when
                ``dimensions`` is set.

        Raises:
            ValueError: If ``self.task`` is already set to a different
                task; or, when ``dimensions`` is set, if the model does
                not support matryoshka representation, if ``dimensions``
                is not one of the model's listed matryoshka dimensions,
                or (when the model lists none) if ``dimensions`` is < 1.
        """
        if self.task is None:
            self.task = task
        elif self.task != task:
            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
            raise ValueError(msg)

        # NOTE: Task validation needs to be done against the model instance,
        # which is not available in model config. So, it's not included
        # in this method.

        if self.dimensions is not None:
            if not model_config.is_matryoshka:
                raise ValueError(
                    f'Model "{model_config.served_model_name}" does not '
                    f'support matryoshka representation, '
                    f'changing output dimensions will lead to poor results.')

            mds = model_config.matryoshka_dimensions
            if mds is not None:
                # The model lists explicit supported sizes; only those
                # are accepted.
                if self.dimensions not in mds:
                    raise ValueError(
                        f'Model "{model_config.served_model_name}" '
                        f'only supports {str(mds)} matryoshka dimensions, '
                        f'use other output dimensions will '
                        f'lead to poor results.')
            elif self.dimensions < 1:
                # No explicit list: only the lower bound is checked here.
                raise ValueError("Dimensions must be greater than 0")

    def __repr__(self) -> str:
        return (f"PoolingParams("
                f"dimensions={self.dimensions}, "
                f"task={self.task}, "
                f"requires_token_ids={self.requires_token_ids})")

    def __post_init__(self) -> None:
        # Use an explicit raise rather than `assert`: asserts are stripped
        # under `python -O`, which would silently accept other output kinds.
        if self.output_kind != RequestOutputKind.FINAL_ONLY:
            raise ValueError("For pooling output_kind has to be FINAL_ONLY")