terratorch.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Any, Callable, Optional, Union
23
24
25

import torch
import torch.nn as nn
26
27
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
                             InputDefinition, InputTypeEnum)
28
29
30
from transformers import BatchFeature

from vllm.config import VllmConfig
31
32
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
33
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
34
35
36
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
37
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
38
39
40
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalKwargsItems,
41
                                    MultiModalUUIDDict, PlaceholderRange)
42
43
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser)
44
from vllm.multimodal.processing import (BaseMultiModalProcessor,
45
                                        BaseProcessingInfo, PromptUpdate)
46
from vllm.multimodal.profiling import BaseDummyInputsBuilder
47
from vllm.sequence import IntermediateTensors
48

49
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
50
                         SupportsMultiModal)
51
52
from .interfaces_base import default_pooling_type

53
54
logger = init_logger(__name__)

55

56
57
58
def _terratorch_field_names(pretrained_cfg: dict):
    """Return the set of input names declared by a Terratorch config.

    Args:
        pretrained_cfg: The ``pretrained_cfg`` dict from the HF config; its
            ``"input"`` entry is expanded into terratorch's
            ``InputDefinition``.

    Returns:
        A set with the keys of ``InputDefinition.data``.
    """
    definition = InputDefinition(**pretrained_cfg["input"])
    return {field_name for field_name in definition.data.keys()}
59
60


61
62
63
64
65
66
def _terratorch_field_factory(
    pretrained_cfg: dict
) -> Callable[[Mapping[str, torch.Tensor]],
              Mapping[str, MultiModalFieldConfig]]:
    """Create the multimodal field-config callback for a Terratorch model.

    Every input declared in ``pretrained_cfg["input"]`` whose type is
    ``InputTypeEnum.tensor`` is assigned to the "image" modality and
    configured as a shared field with ``batch_size=1``. Inputs of any other
    type receive no field config.

    Args:
        pretrained_cfg: The ``pretrained_cfg`` dict from the HF config; its
            ``"input"`` entry is expanded into terratorch's
            ``InputDefinition``.

    Returns:
        A callable that takes the processed inputs and returns a fresh
        mapping of field name to ``MultiModalFieldConfig``.
    """
    # The input definition depends only on ``pretrained_cfg``, so parse it
    # once here instead of on every invocation of the returned callback.
    input_definition = InputDefinition(**pretrained_cfg["input"])
    tensor_field_names = [
        field_name
        for field_name, field_def in input_definition.data.items()
        if field_def.type == InputTypeEnum.tensor
    ]

    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
        # ``hf_inputs`` is part of the callback signature expected by the
        # callers but is not consulted: the field layout is fixed by the
        # model configuration.
        return {
            field_name: MultiModalFieldConfig.shared(batch_size=1,
                                                     modality="image")
            for field_name in tensor_field_names
        }

    return _terratorch_field_config
82
83


84
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info for Terratorch models (image modality only)."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report the supported modalities and their per-prompt limits.

        Only "image" is listed, with a limit of ``None`` (presumably
        "unlimited" per ``BaseProcessingInfo`` — confirm upstream).
        """
        limits: dict[str, Optional[int]] = {"image": None}
        return limits


90
91
92
93
94
95
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder used when profiling Terratorch models.

    Dummy multimodal data comes from terratorch's ``DummyDataGenerator``,
    configured from the ``pretrained_cfg`` entry of the HF config.
    """

    def __init__(self, info: TerratorchProcessingInfo):
        super().__init__(info)
        pretrained_cfg = self.info.get_hf_config().to_dict()["pretrained_cfg"]
        self.dummy_data_generator = DummyDataGenerator(pretrained_cfg)

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty dummy prompt; no text is generated."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy multimodal data from the ``DummyDataGenerator``.

        The data is derived from the 'input' section of the HF config;
        ``seq_len`` and ``mm_counts`` are not used. Any ``mm_options`` are
        ignored, with a warning.
        """
        if mm_options:
            logger.warning("Configurable multimodal profiling "
                           "options are not supported for Terratorch. "
                           "They are ignored for now.")

        generator = self.dummy_data_generator
        return generator.get_dummy_mm_data()


class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts a dict of named tensors as image data."""

    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
        # Keep the config before delegating to the base initializer.
        self._pretrained_cfg = pretrained_cfg
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse image data, handling the Terratorch dict-of-tensors form.

        Non-dict data falls back to the base parser. A dict is wrapped in
        ``DictEmbeddingItems`` whose required fields and field configs are
        derived from the model's ``pretrained_cfg``.
        """
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        cfg = self._pretrained_cfg
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(cfg),
            fields_factory=_terratorch_field_factory(cfg),
        )
139

140

141
142
143
144
145
146
147
148
149
150
151
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor for Terratorch models.

    Inputs are dicts of named tensors described by the ``input`` section of
    the model's ``pretrained_cfg``. ``apply`` wraps them directly in a
    ``BatchFeature`` as a single "image" item.
    """

    def __init__(
            self,
            info: TerratorchProcessingInfo,
            dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
            *,
            cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
        # Read the config before the base initializer runs (original order).
        hf_config = info.get_hf_config()
        self.pretrained_cfg = hf_config.to_dict()["pretrained_cfg"]
        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser that understands dict-of-tensor image inputs."""
        parser = TerratorchMultiModalDataParser(
            pretrained_cfg=self.pretrained_cfg)
        return parser

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the per-field config derived from ``pretrained_cfg``."""
        field_config_fn = _terratorch_field_factory(self.pretrained_cfg)
        return field_config_fn(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates; the prompt is not modified."""
        return []

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Turn raw Terratorch inputs into ``MultiModalInputs``.

        ``mm_data`` may be either ``{"image": {...}}`` or the field dict
        itself. ``prompt`` is ignored: the result always carries the single
        token id ``[1]`` and one zero-length image placeholder at offset 0.
        """
        # Normalize to the {"image": ...} form expected by the parser.
        if "image" not in mm_data:
            mm_data = {"image": mm_data}
        image_data = mm_data["image"]

        mm_items = self._to_mm_items(mm_data)
        tokenization_kwargs = tokenization_kwargs or {}

        mm_hashes = self._hash_mm_items(mm_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_uuids=mm_uuids)
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        processed = BatchFeature(image_data)
        fields_config = self._get_mm_fields_config(processed,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(processed,
                                                         fields_config)

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


211
@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM wrapper around a Terratorch model.

    The model consumes raw multimodal tensors (registered under the "image"
    modality) and does not use token embeddings. It is exposed as a pooling
    model whose ``forward`` returns the output of terratorch's
    ``InferenceRunner``.
    """

    # Class-level capability flags consumed by vLLM.
    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for ``modality``.

        Image modalities have no placeholder string (``None``).

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        # The runner builds the underlying terratorch model; the same model
        # object is also exposed as ``self.model``.
        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)}, )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return an empty ``(input_ids.shape[0], 0)`` embedding tensor.

        We do not really use any input tokens, so no embeddings need to be
        computed. However, the input prompt must carry token ids (one token
        is passed), and the size of the dummy embedding tensor must reflect
        that.
        """
        # Allocate on the same device as the token ids so a CPU tensor is
        # not mixed into batches that live on an accelerator.
        return torch.empty((input_ids.shape[0], 0), device=input_ids.device)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the Terratorch model and return its ``output`` attribute.

        ``input_ids``, ``positions``, ``intermediate_tensors`` and
        ``inputs_embeds`` are accepted for interface compatibility but are
        not used. Only ``kwargs`` (the raw multimodal inputs) are forwarded
        to the inference runner.
        """
        model_output = self.inference_runner.forward(**kwargs)

        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load a Terratorch checkpoint into this wrapper.

        Two kinds of entries in ``weights`` are handled:

        * ``("state_dict", mapping)``: every ``name -> tensor`` in the
          mapping is renamed to ``inference_runner.<name>`` with any
          ``_timm_module.`` infix removed. Names containing ``pos_embed``
          are skipped. Iteration over ``weights`` stops after this entry.
          Other mapping-valued entries are ignored.
        * ``(name, tensor)``: loaded as ``inference_runner.model.<name>``.

        Tensors whose final name is a registered buffer are copied with the
        buffer's ``weight_loader`` (or ``default_weight_loader``), because
        ``AutoWeightsLoader`` cannot load them. Everything else is passed to
        ``AutoWeightsLoader``.

        Returns:
            The names of all parameters and buffers that were loaded.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if isinstance(value, (dict, OrderedDict)):
                # Only the nested "state_dict" mapping is loaded.
                if key != "state_dict":
                    continue

                for name, weight in value.items():
                    name = f"inference_runner.{name}"

                    # NOTE(review): positional embeddings are deliberately
                    # not loaded; presumably the model rebuilds them itself
                    # — confirm against terratorch.
                    if "pos_embed" in name:
                        continue

                    name = name.replace("_timm_module.", "")

                    # This model requires a couple of buffers to be loaded
                    # that are not loadable with the AutoWeightsLoader.
                    if name in model_buffers:
                        buffer = model_buffers[name]
                        weight_loader = getattr(buffer, "weight_loader",
                                                default_weight_loader)
                        weight_loader(buffer, weight)
                        loaded_buffers.append(name)
                    else:
                        params_list.append((name, weight))
                break

            elif isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))