registry.py 5.71 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Union
6

7
import torch
8
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
9
from typing_extensions import TypeVar
10

11
from vllm.logger import init_logger
12
from vllm.transformers_utils.processor import cached_processor_from_config
13
from vllm.utils import get_allowed_kwarg_only_overrides
14
from vllm.utils.jsontree import JSONTree, json_map_leaves
15
16

if TYPE_CHECKING:
17
    from vllm.config import ModelConfig
18
19
20
21
    from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
    ModelConfig = Any
    AnyTokenizer = Any

# Type variables for the generic accessors on `InputContext`.
# `typing_extensions.TypeVar` is used because `_C` and `_P` rely on the
# PEP 696 `default=` argument, so an un-parameterized call such as
# `get_hf_config()` is typed as returning the base class.
_T = TypeVar("_T")  # any class passed to `init_processor`
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)

logger = init_logger(__name__)


@dataclass(frozen=True)
class InputContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Args:
            typ: The expected config class, or a tuple of acceptable
                classes. Defaults to `PretrainedConfig`.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError("Invalid type of HuggingFace config. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(hf_config)}")

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        The work is delegated to `cached_processor_from_config`, which
        receives `typ` as `processor_cls` together with any extra `kwargs`.
        NOTE(review): the caching and the type check below are presumably
        performed by that helper; confirm against its implementation.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.

        Keyword arguments passed here take precedence over the
        `mm_processor_kwargs` stored in the model's multimodal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        # `mm_processor_kwargs` may be unset (None); treat that as "no
        # configured overrides".
        base_kwargs = mm_config.mm_processor_kwargs or {}

        # Later entries win, so call-site kwargs override configured ones.
        merged_kwargs = {**base_kwargs, **kwargs}

        return typ(**merged_kwargs)


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An `InputContext` that also carries the tokenizer, so that
    HuggingFace processors can be constructed with it and invoked.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as `InputContext.get_hf_processor`, but always forwards this
        context's `tokenizer` to the processor.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Union[Mapping[str, object], None] = None,
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        Args:
            hf_processor: The processor to invoke; it must be callable.
            data: Inputs forwarded as keyword arguments to the processor.
            kwargs: Extra processor options (default: none). They are
                merged with the model's multimodal processor kwargs and
                then filtered down to those `hf_processor` accepts.

        Returns:
            The processor output with floating-point tensors cast to the
            model dtype. It is a `BatchFeature` if the processor returned
            one; otherwise it is the re-mapped output tree.

        Raises:
            ValueError: If calling `hf_processor` itself fails.
        """
        assert callable(hf_processor)

        # A `None` default avoids sharing one mutable `{}` across calls.
        if kwargs is None:
            kwargs = {}

        mm_config = self.model_config.get_multimodal_config()
        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)

        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def maybe_cast_dtype(x):
            # This mimics the behavior of transformers.BatchFeature
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype=self.model_config.dtype)
            return x

        # Guard only the processor call. Errors raised by the
        # post-processing below must not be misreported as a failure to
        # apply the processor.
        try:
            output = hf_processor(**data,
                                  **allowed_kwargs,
                                  return_tensors="pt")
        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={allowed_kwargs}")

            raise ValueError(msg) from exc

        # this emulates output.to(dtype=self.model_config.dtype)
        if isinstance(output, BatchFeature):
            cast_output = json_map_leaves(maybe_cast_dtype, output.data)
            return BatchFeature(cast_output)

        cast_output = json_map_leaves(maybe_cast_dtype, output)

        logger.warning_once(
            f"{type(hf_processor).__name__} did not return `BatchFeature`. "
            "Make sure to match the behaviour of `ProcessorMixin` when "
            "implementing custom processors.")
        return cast_output