"vllm/vscode:/vscode.git/clone" did not exist on "26c82c2762819425abeaea7fdcc87fbdf4768d0e"
pooling_params.py 1.97 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import TYPE_CHECKING, Any, Optional
4

5
import msgspec
6

7
8
9
if TYPE_CHECKING:
    from vllm.config import ModelConfig

10
11
12
13
14

class PoolingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """API parameters for pooling models. This is currently a placeholder.

    Attributes:
        dimensions: Reduce the dimensions of embeddings
                    if the model supports matryoshka representation.
        additional_data: Any additional data needed for pooling.
    """

    dimensions: Optional[int] = None
    additional_data: Optional[Any] = None

    def clone(self) -> "PoolingParams":
        """Return a deep copy of this PoolingParams instance.

        ``additional_data`` is deep-copied so that mutating the clone's
        payload never affects the instance it was cloned from.
        """
        # Function-local import keeps this class self-contained without
        # touching the module's top-level import block.
        import copy

        return PoolingParams(
            dimensions=self.dimensions,
            additional_data=copy.deepcopy(self.additional_data))

    def verify(self, model_config: "ModelConfig") -> None:
        """Validate these parameters against ``model_config``.

        Nothing is checked when ``dimensions`` is unset.

        Args:
            model_config: Model configuration; ``is_matryoshka``,
                ``matryoshka_dimensions`` and ``served_model_name``
                are read from it.

        Raises:
            ValueError: If ``dimensions`` is set and the model is not a
                matryoshka model, if the model declares a list of
                matryoshka dimensions that does not contain
                ``dimensions``, or if no such list is declared and
                ``dimensions`` is less than 1.
        """
        if self.dimensions is None:
            return

        if not model_config.is_matryoshka:
            raise ValueError(
                f'Model "{model_config.served_model_name}" does not '
                f'support matryoshka representation, '
                f'changing output dimensions will lead to poor results.')

        mds = model_config.matryoshka_dimensions
        if mds is None:
            # No explicit list of allowed sizes: only require a positive
            # dimension.
            if self.dimensions < 1:
                raise ValueError("Dimensions must be greater than 0")
            return

        if self.dimensions not in mds:
            raise ValueError(
                f'Model "{model_config.served_model_name}" '
                f'only supports {mds!s} matryoshka dimensions, '
                f'use other output dimensions will '
                f'lead to poor results.')

    def __repr__(self) -> str:
        """Return a debug representation listing both fields by name."""
        return (f"PoolingParams("
                f"dimensions={self.dimensions}, "
                f"additional_data={self.additional_data})")