# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from functools import cached_property
23
from typing import Any
24
25
26

import torch
import torch.nn as nn
27
28
29
30
31
32
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
33
34
35
from transformers import BatchFeature

from vllm.config import VllmConfig
36
37
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
38
from vllm.model_executor.layers.pooler import IdentityPooler
39
40
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    PlaceholderRange,
50
    mm_inputs,
51
52
53
54
55
56
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
57
    MultiModalUUIDItems,
58
59
)
from vllm.multimodal.processing import (
60
    BaseDummyInputsBuilder,
61
62
63
64
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptUpdate,
)
65
from vllm.sequence import IntermediateTensors
66

67
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
68
from .interfaces_base import attn_type
69

70
71
logger = init_logger(__name__)

72

73
def _terratorch_field_names(input_definition: InputDefinition):
    """Return the keys of ``input_definition.data`` as a set."""
    field_names = input_definition.data.keys()
    return set(field_names)
75
76


77
78
79
80
81
def _terratorch_field_factory(
    input_definition: InputDefinition,
    *,
    is_shared: bool = True,  # True for unprocessed data, False for processed data
):
    """Build a callable that maps Terratorch inputs to field configs.

    Args:
        input_definition: Definition whose ``data`` mapping lists the
            model inputs by name.
        is_shared: When true, every field is configured with
            ``MultiModalFieldConfig.shared(..., batch_size=1)``; otherwise
            with ``MultiModalFieldConfig.batched(...)``.

    Returns:
        A function taking ``hf_inputs`` and returning one field config per
        entry of ``input_definition.data`` whose type is
        ``InputTypeEnum.tensor``. Entries of any other type are left out.
    """
    modality = "image"

    def _make_field() -> MultiModalFieldConfig:
        # A fresh config object is created for every field.
        if is_shared:
            return MultiModalFieldConfig.shared(modality, batch_size=1)
        return MultiModalFieldConfig.batched(modality)

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``hf_inputs`` is not inspected: the set of fields is driven purely
        # by the input definition.
        return {
            field_name: _make_field()
            for field_name, spec in input_definition.data.items()
            if spec.type == InputTypeEnum.tensor
        }

    return _terratorch_field_config
98
99


100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that accepts Terratorch inputs given as a plain dict."""

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Store ``input_definition``; remaining arguments go to the base class."""
        super().__init__(*args, **kwargs)

        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the ``"image"`` modality.

        A ``dict`` is wrapped in ``DictEmbeddingItems`` using the field
        names and field factory derived from the stored input definition.
        Anything else is handed to the base-class implementation.
        """
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` if that key is absent."""
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)


127
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info exposing the Terratorch input definition."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Input definition built from ``pretrained_cfg["input"]`` of the HF config.

        The value is computed on first access and cached on the instance.
        """
        hf_config_dict = self.get_hf_config().to_dict()
        input_section = hf_config_dict["pretrained_cfg"]["input"]
        return InputDefinition(**input_section)

    def get_data_parser(self):
        """Create a ``TerratorchMultiModalDataParser`` for this model."""
        definition = self.input_definition
        hidden_size = self._get_expected_hidden_size()
        return TerratorchMultiModalDataParser(
            definition,
            expected_hidden_size=hidden_size,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: ``"image"`` mapped to ``None``."""
        limits: dict[str, int | None] = {"image": None}
        return limits


143
144
145
146
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder backed by Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        """Create the dummy data generator from the HF ``pretrained_cfg``."""
        super().__init__(info)
        pretrained_cfg = self.info.get_hf_config().to_dict()["pretrained_cfg"]
        self.dummy_data_generator = DummyDataGenerator(pretrained_cfg)

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always the empty string."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy multimodal data produced by the data generator.

        The generator was configured in ``__init__`` from the
        ``pretrained_cfg`` section of the HF configuration. ``seq_len`` and
        ``mm_counts`` are not used. A non-empty ``mm_options`` only triggers
        a warning; its contents are ignored.
        """
        if mm_options:
            unsupported_notice = (
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )
            logger.warning(unsupported_notice)

        generator = self.dummy_data_generator
        return generator.get_dummy_mm_data()


172
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
    """Multimodal processor that passes Terratorch inputs through as tensors."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        is_shared: bool = True,
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configs derived from the model's input definition.

        ``hf_processor_mm_kwargs`` is not used. ``is_shared`` is forwarded
        to ``_terratorch_field_factory``.
        """
        return _terratorch_field_factory(
            self.info.input_definition,
            is_shared=is_shared,
        )(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates (always an empty list)."""
        return []

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems | None = None,
        hf_processor_mm_kwargs: Mapping[str, object] | None = None,
        tokenization_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalInputs:
        """Turn parsed multimodal items into model inputs.

        ``prompt`` is not read: the returned token ids are always ``[1]``,
        and the ``"image"`` modality gets a single zero-length placeholder
        at offset 0. Every passthrough value is converted with
        ``torch.as_tensor`` and given a new leading dimension.
        """
        hf_processor_mm_kwargs = (
            {} if hf_processor_mm_kwargs is None else hf_processor_mm_kwargs
        )
        tokenization_kwargs = (
            {} if tokenization_kwargs is None else tokenization_kwargs
        )

        item_hashes = self._hash_mm_items(
            mm_items,
            mm_uuid_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        _, raw_fields = self._get_hf_mm_data(mm_items)
        batched_tensors = {
            field_name: torch.as_tensor(field_value).unsqueeze(0)
            for field_name, field_value in raw_fields.items()
        }
        processed = BatchFeature(batched_tensors, tensor_type="pt")

        placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        # The processed data is described with batched (non-shared) fields.
        kwargs_items = MultiModalKwargsItems.from_hf_inputs(
            processed,
            self._get_mm_fields_config(
                processed,
                hf_processor_mm_kwargs,
                is_shared=False,
            ),
        )

        return mm_inputs(
            prompt_token_ids=[1],
            mm_kwargs=kwargs_items,
            mm_hashes=item_hashes,
            mm_placeholders=placeholders,
        )


237
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM model that delegates inference to a Terratorch ``InferenceRunner``.

    The runner is built from the ``pretrained_cfg`` section of the HF
    configuration, and its output is exposed through an ``IdentityPooler``.
    """

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for a multimodal item.

        Returns:
            ``None`` for any modality whose name starts with ``"image"``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inference runner from the HF ``pretrained_cfg``.

        ``prefix`` is accepted but not used here.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return a dummy embedding tensor of shape ``(input_ids.shape[0], 0)``.

        Only ``input_ids.shape[0]`` is read; all other arguments are ignored.
        """
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inference runner on ``kwargs`` and return its ``output``.

        The named tensor arguments are not used; only the extra keyword
        arguments are forwarded to ``self.inference_runner.forward``.
        """
        model_output = self.inference_runner.forward(**kwargs)
        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights given as flat tensors or as a nested ``state_dict``.

        Each ``(key, value)`` item is handled as follows:

        * A tensor is queued for ``AutoWeightsLoader`` under the name
          ``inference_runner.model.<key>``.
        * A dict under the key ``"state_dict"`` has its entries renamed to
          ``inference_runner.<name>`` with any ``"_timm_module."`` segment
          removed. Entries whose name contains ``"pos_embed"`` are skipped.
          Entries matching a registered buffer are loaded directly; the
          rest are queued for ``AutoWeightsLoader``. Iteration over
          ``weights`` stops after this item.
        * Anything else is ignored.

        Returns:
            The names reported by ``AutoWeightsLoader`` together with the
            names of the buffers that were loaded directly.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []

        for key, value in weights:
            if isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))
                continue

            if not isinstance(value, (dict, OrderedDict)) or key != "state_dict":
                continue

            for name, weight in value.items():
                name = f"inference_runner.{name}"

                # NOTE(review): the reason "pos_embed" entries are dropped
                # is not visible in this file - confirm before changing.
                if "pos_embed" in name:
                    continue

                # ``str.replace`` is a no-op when the segment is absent, so
                # no membership test is needed. The name is already
                # normalised before the buffer lookup below.
                name = name.replace("_timm_module.", "")

                # this model requires a couple of buffers to be loaded
                # that are not loadable with the AutoWeightsLoader
                if name in model_buffers:
                    buffer = model_buffers[name]
                    weight_loader = getattr(
                        buffer, "weight_loader", default_weight_loader
                    )
                    weight_loader(buffer, weight)
                    loaded_buffers.append(name)
                else:
                    params_list.append((name, weight))

            # Items after the "state_dict" entry are not consumed.
            break

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))