terratorch.py 10.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from functools import cached_property
23
from typing import Any
24
25
26

import torch
import torch.nn as nn
27
28
29
30
31
32
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
33
34
35
from transformers import BatchFeature

from vllm.config import VllmConfig
36
37
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
38
from vllm.model_executor.layers.pooler import IdentityPooler
39
40
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
59
    BaseDummyInputsBuilder,
60
61
62
63
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptUpdate,
)
64
from vllm.sequence import IntermediateTensors
65

66
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
67
from .interfaces_base import attn_type
68

69
70
# Module-level logger, created through vLLM's `init_logger` with this module's name.
logger = init_logger(__name__)

71

72
def _terratorch_field_names(input_definition: InputDefinition) -> set[str]:
    """Return the names of all inputs declared in ``input_definition``.

    Args:
        input_definition: Terratorch input definition; its ``data`` mapping
            is keyed by input name.

    Returns:
        A new set containing every key of ``input_definition.data``.
    """
    return set(input_definition.data.keys())
74
75


76
77
78
79
80
def _terratorch_field_factory(
    input_definition: InputDefinition,
    *,
    is_shared: bool = True,  # True for unprocessed data, False for processed data
):
    """Build a callable that produces the multimodal field configuration.

    Args:
        input_definition: Terratorch input definition; each entry of its
            ``data`` mapping describes one named input.
        is_shared: When true, every tensor input is declared with
            ``MultiModalFieldConfig.shared(..., batch_size=1)``; otherwise
            with ``MultiModalFieldConfig.batched(...)``.

    Returns:
        A function that accepts the HF inputs mapping and returns a mapping
        from input name to ``MultiModalFieldConfig``. Only inputs whose type
        is ``InputTypeEnum.tensor`` are included.
    """
    # All Terratorch inputs are registered under the single "image" modality.
    modality = "image"

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``hf_inputs`` is not inspected: the configuration is derived
        # solely from ``input_definition``.
        fields: dict[str, MultiModalFieldConfig] = {}
        for name, definition in input_definition.data.items():
            if definition.type == InputTypeEnum.tensor:
                fields[name] = (
                    MultiModalFieldConfig.shared(modality, batch_size=1)
                    if is_shared
                    else MultiModalFieldConfig.batched(modality)
                )

        return fields

    return _terratorch_field_config
97
98


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser driven by a Terratorch input definition.

    Dictionary inputs are wrapped as ``DictEmbeddingItems`` under the
    "image" modality; anything else is delegated to the base parser.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Store ``input_definition`` and forward the rest to the base parser."""
        super().__init__(*args, **kwargs)
        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data, treating a dict as named Terratorch inputs."""
        # Anything that is not a plain dict follows the default image path.
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under "image" when that key is absent."""
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)


126
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info that exposes the Terratorch input definition."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Input definition built from ``pretrained_cfg["input"]`` of the HF config.

        Computed once per instance and cached.
        """
        pretrained_cfg = self.get_hf_config().to_dict()["pretrained_cfg"]
        return InputDefinition(**pretrained_cfg["input"])

    def get_data_parser(self) -> TerratorchMultiModalDataParser:
        """Return a data parser bound to this model's input definition."""
        return TerratorchMultiModalDataParser(
            self.input_definition,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the supported modalities: only "image", with a ``None`` limit."""
        return {"image": None}


142
143
144
145
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder backed by Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        """Create the dummy data generator from the HF ``pretrained_cfg`` section."""
        super().__init__(info)
        self.dummy_data_generator = DummyDataGenerator(
            self.info.get_hf_config().to_dict()["pretrained_cfg"]
        )

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt; ``mm_counts`` is not used."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy multimodal data produced by the Terratorch generator.

        ``seq_len`` and ``mm_counts`` are not used. A non-empty
        ``mm_options`` only triggers a warning and is otherwise ignored.
        """
        # Dummy data is generated based on the 'input' section
        # defined in the HF configuration file

        if mm_options:
            logger.warning(
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )

        return self.dummy_data_generator.get_dummy_mm_data()


171
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
    """Multimodal processor that passes Terratorch inputs through as tensors.

    No prompt updates are produced; ``apply`` emits a fixed single-token
    prompt and a zero-length "image" placeholder.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        is_shared: bool = True,
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configuration derived from the input definition.

        Args:
            hf_inputs: Inputs forwarded to the field-config factory.
            hf_processor_mm_kwargs: Not used here.
            is_shared: Selects shared (``True``) or batched (``False``)
                field configs; see ``_terratorch_field_factory``.
        """
        factory = _terratorch_field_factory(
            self.info.input_definition,
            is_shared=is_shared,
        )
        return factory(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates."""
        return []

    @staticmethod
    def _to_batched_tensor(value: object) -> torch.Tensor:
        """Copy ``value`` into a tensor and prepend a batch dimension of size 1.

        An existing tensor is copied with ``detach().clone()``, which PyTorch
        documents as the equivalent of ``torch.tensor(tensor)`` but without
        the copy-construct ``UserWarning``. Other values go through
        ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            tensor = value.detach().clone()
        else:
            tensor = torch.tensor(value)
        return tensor.unsqueeze(0)

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Build ``MultiModalInputs`` from passthrough multimodal data.

        Args:
            prompt: Not used; the returned prompt is always ``[1]``.
            mm_items: Parsed multimodal items to hash and pass through.
            hf_processor_mm_kwargs: Forwarded to hashing and field config.
            tokenization_kwargs: Forwarded to hashing; ``None`` becomes ``{}``.
            mm_uuids: Optional UUIDs forwarded to hashing.

        Returns:
            Multimodal inputs whose kwargs hold each passthrough value as a
            tensor with a leading batch dimension of size 1.
        """
        if tokenization_kwargs is None:
            tokenization_kwargs = {}

        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        _, passthrough_data = self._get_hf_mm_data(mm_items)
        mm_processed_data = BatchFeature(
            {
                name: self._to_batched_tensor(value)
                for name, value in passthrough_data.items()
            },
            tensor_type="pt",
        )
        # Zero-length placeholder: no prompt tokens are reserved for the data.
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            self._get_mm_fields_config(
                mm_processed_data,
                hf_processor_mm_kwargs,
                is_shared=False,
            ),
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


233
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM wrapper that runs a Terratorch model through ``InferenceRunner``.

    The wrapped model consumes raw multimodal inputs only and is exposed as
    a pooling model with an identity pooler.
    """

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Returns:
            ``None`` for any modality whose name starts with "image".

        Raises:
            ValueError: For every other modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the Terratorch inference runner from the HF config.

        Args:
            vllm_config: vLLM configuration; the ``pretrained_cfg`` section
                of its HF config is passed to ``InferenceRunner``.
            prefix: Not used.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return an uninitialized ``(input_ids.shape[0], 0)`` tensor.

        Only ``input_ids.shape[0]`` is read; all other arguments are ignored.
        """
        # We do not really use any input tokens and therefore no embeddings
        # need to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        # NOTE(review): the tensor uses the default device and dtype rather
        # than those of ``input_ids`` -- confirm this is what callers expect.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inference runner on the keyword inputs.

        The positional arguments are not used; only ``**kwargs`` is forwarded
        to ``InferenceRunner.forward``. Returns the ``output`` attribute of
        the runner's result.
        """
        model_output = self.inference_runner.forward(**kwargs)
        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the wrapped model.

        Args:
            weights: Pairs of ``(key, value)``. A dict value under the key
                ``"state_dict"`` is unpacked entry by entry; a tensor value
                is treated as a parameter of ``inference_runner.model``.

        Returns:
            Names loaded by ``AutoWeightsLoader`` together with the names of
            the buffers that were loaded directly.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if isinstance(value, (dict, OrderedDict)):
                if key == "state_dict":
                    for name, weight in value.items():
                        name = f"inference_runner.{name}"

                        if "pos_embed" in name:
                            continue

                        # Drop the "_timm_module." segment once, up front.
                        # `replace` is a no-op when the segment is absent.
                        name = name.replace("_timm_module.", "")

                        # this model requires a couple of buffers to be loaded
                        # that are not loadable with the AutoWeightsLoader
                        if name in model_buffers:
                            buffer = model_buffers[name]
                            weight_loader = getattr(
                                buffer, "weight_loader", default_weight_loader
                            )
                            weight_loader(buffer, weight)
                            loaded_buffers.append(name)
                        else:
                            params_list.append((name, weight))
                    # Once the state dict is consumed, later entries are skipped.
                    break

            elif isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(loaded_buffers)