terratorch.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Any, Callable, Optional, Union
23
24
25

import torch
import torch.nn as nn
26
27
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
                             InputDefinition, InputTypeEnum)
28
29
30
from transformers import BatchFeature

from vllm.config import VllmConfig
31
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
32
33
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
35
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
36
37
38
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalKwargsItems,
39
                                    MultiModalUUIDDict, PlaceholderRange)
40
41
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
                                   MultiModalDataItems, MultiModalDataParser)
42
from vllm.multimodal.processing import (BaseMultiModalProcessor,
43
                                        BaseProcessingInfo, PromptUpdate)
44
from vllm.multimodal.profiling import BaseDummyInputsBuilder
45
from vllm.sequence import IntermediateTensors
46

47
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
48
                         SupportsMultiModal)
49
50
from .interfaces_base import default_pooling_type

51

52
53
54
def _terratorch_field_names(pretrained_cfg: dict):
    """Return the names of all inputs declared in ``pretrained_cfg["input"]``.

    The ``"input"`` section is parsed with TerraTorch's ``InputDefinition``
    and the keys of its ``data`` mapping are collected into a set.
    """
    declared_inputs = InputDefinition(**pretrained_cfg["input"]).data
    return set(declared_inputs.keys())
55
56


57
58
59
60
61
62
def _terratorch_field_factory(
    pretrained_cfg: dict,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build a callable that maps Terratorch inputs to multimodal field configs.

    Every entry of ``pretrained_cfg["input"]`` whose type is
    ``InputTypeEnum.tensor`` becomes a field of the ``"image"`` modality,
    configured with ``MultiModalFieldConfig.shared(batch_size=1, ...)``.

    Args:
        pretrained_cfg: The ``pretrained_cfg`` section of the HF config. Its
            ``"input"`` entry is parsed with ``InputDefinition``.

    Returns:
        A function taking ``hf_inputs`` (whose contents it does not inspect)
        and returning a new ``{field_name: MultiModalFieldConfig}`` dict on
        every call.
    """
    input_definition = InputDefinition(**pretrained_cfg["input"])
    # The field layout depends only on the config, so resolve it once here
    # rather than re-parsing the input definition on every call.
    tensor_field_names = [
        field_name
        for field_name, field_def in input_definition.data.items()
        if field_def.type == InputTypeEnum.tensor
    ]

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # hf_inputs is intentionally unused: the layout comes from the config.
        # A fresh dict is built each call so callers may mutate the result.
        return {
            field_name: MultiModalFieldConfig.shared(batch_size=1,
                                                     modality="image")
            for field_name in tensor_field_names
        }

    return _terratorch_field_config
78
79


80
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info for Terratorch models: only ``"image"`` is supported."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limits.

        The single ``"image"`` modality is mapped to ``None``, presumably
        meaning no fixed limit is imposed (NOTE(review): confirm against
        ``BaseProcessingInfo``).
        """
        limits: dict[str, Optional[int]] = {"image": None}
        return limits


86
87
88
89
90
91
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder driven by TerraTorch's ``DummyDataGenerator``.

    The generator is created from the ``pretrained_cfg`` section of the HF
    config. Terratorch prompts carry no text, so the dummy text is empty.
    """

    def __init__(self, info: TerratorchProcessingInfo):
        super().__init__(info)
        hf_config_dict = self.info.get_hf_config().to_dict()
        self.dummy_data_generator = DummyDataGenerator(
            hf_config_dict["pretrained_cfg"])

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the (always empty) dummy prompt text."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy multimodal data from the config's 'input' section.

        ``seq_len`` and ``mm_counts`` are accepted for interface
        compatibility but are not used: the generator alone decides what
        dummy data is produced.
        """
        generator = self.dummy_data_generator
        return generator.get_dummy_mm_data()


class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that accepts a dict of named Terratorch input tensors.

    Dict-valued image data is wrapped in ``DictEmbeddingItems`` whose
    required fields and field configs come from ``pretrained_cfg``. Any other
    image data is delegated to the base parser.
    """

    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
        self._pretrained_cfg = pretrained_cfg
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        # Non-dict inputs follow the standard image parsing path.
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        cfg = self._pretrained_cfg
        # Keyword arguments are evaluated left to right, so the field names
        # are resolved before the field-config factory is built.
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(cfg),
            fields_factory=_terratorch_field_factory(cfg),
        )
128

129

130
131
132
133
134
135
136
137
138
139
140
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor that passes raw Terratorch inputs through.

    No HF processor or tokenizer is invoked. The input tensors are wrapped in
    a ``BatchFeature`` as-is, and the prompt is replaced by the single token
    id ``1`` with an empty image placeholder.
    """

    def __init__(
        self,
        info: TerratorchProcessingInfo,
        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
        *,
        cache: Optional[MultiModalProcessorOnlyCache] = None,
    ) -> None:
        # Set before the base initializer runs. NOTE(review): the base class
        # presumably calls _get_data_parser() from __init__, which reads it.
        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser that understands dicts of Terratorch tensors."""
        return TerratorchMultiModalDataParser(
            pretrained_cfg=self.pretrained_cfg)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Derive field configs from the config; ``hf_inputs`` is not read."""
        make_config = _terratorch_field_factory(self.pretrained_cfg)
        return make_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates: the prompt is never rewritten."""
        return []

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Package the raw inputs into ``MultiModalInputs`` without tokenizing.

        ``mm_data`` may either carry the inputs under an ``"image"`` key or be
        the input mapping itself; in the latter case it is wrapped so that it
        is treated as a single image item.
        """
        if "image" not in mm_data:
            raw_inputs = mm_data
            mm_data = {"image": mm_data}
        else:
            raw_inputs = mm_data["image"]

        items = self._to_mm_items(mm_data)
        if not tokenization_kwargs:
            tokenization_kwargs = {}

        hashes = self._hash_mm_items(items,
                                     hf_processor_mm_kwargs,
                                     tokenization_kwargs,
                                     mm_uuids=mm_uuids)

        # A zero-length placeholder: no prompt tokens stand in for the image.
        placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        processed = BatchFeature(raw_inputs)
        fields_config = self._get_mm_fields_config(processed,
                                                   hf_processor_mm_kwargs)
        kwargs_items = MultiModalKwargsItems.from_hf_inputs(
            processed, fields_config)

        # A single dummy token id is supplied because a prompt is mandatory.
        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
            mm_kwargs=kwargs_items,
            mm_hashes=hashes,
            mm_placeholders=placeholders,
        )


201
@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """Attention-free pooling wrapper around a TerraTorch model.

    The model is built by TerraTorch's ``InferenceRunner`` from the
    ``pretrained_cfg`` section of the HF config. It consumes raw multimodal
    inputs only; token ids are ignored.
    """

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return ``None`` for image modalities (no placeholder text is used).

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)})

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Return a zero-width embedding tensor with one row per token id."""
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that. Allocate on the same device as
        # input_ids so no cross-device copy is needed downstream.
        return torch.empty((input_ids.shape[0], 0), device=input_ids.device)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the TerraTorch model and return its ``output`` attribute.

        Only ``kwargs`` is forwarded to ``InferenceRunner.forward``
        (presumably the raw multimodal input fields; verify against the
        caller); the token/position arguments are ignored.
        """
        model_output = self.inference_runner.forward(**kwargs)

        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint entries into the wrapped model.

        Two kinds of entries in ``weights`` are consumed:

        * A ``"state_dict"`` entry whose value is a dict. Each of its items
          is renamed under the ``inference_runner.`` prefix. Names containing
          ``pos_embed`` are skipped, and ``_timm_module.`` segments are
          removed. Names matching a registered buffer are loaded directly
          into that buffer; the rest go to ``AutoWeightsLoader``. Iteration
          over ``weights`` stops after this entry.
        * Tensor entries, which go to ``AutoWeightsLoader`` under the
          ``inference_runner.model.`` prefix.

        All other entries are ignored.

        Returns:
            The names loaded by ``AutoWeightsLoader`` together with the names
            of the buffers that were loaded directly.
        """
        params_list: list[tuple[str, torch.Tensor]] = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers: list[str] = []

        for key, value in weights:
            if isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))
                continue

            if not (isinstance(value, (dict, OrderedDict))
                    and key == "state_dict"):
                continue

            for name, weight in value.items():
                name = f"inference_runner.{name}"

                # NOTE(review): positional embeddings are deliberately not
                # loaded; presumably the model regenerates them.
                if "pos_embed" in name:
                    continue

                name = name.replace("_timm_module.", "")

                # this model requires a couple of buffers to be loaded
                # that are not loadable with the AutoWeightsLoader
                if name in model_buffers:
                    buffer = model_buffers[name]
                    weight_loader = getattr(buffer, "weight_loader",
                                            default_weight_loader)
                    weight_loader(buffer, weight)
                    loaded_buffers.append(name)
                else:
                    params_list.append((name, weight))

            # The state dict holds the full checkpoint; stop consuming here.
            break

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))