terratorch.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from functools import cached_property
23
from typing import Any
24
25
26

import torch
import torch.nn as nn
27
28
29
30
31
32
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
33
34
35
from transformers import BatchFeature

from vllm.config import VllmConfig
36
from vllm.config.multimodal import BaseDummyOptions
37
from vllm.inputs import ModalityData, MultiModalDataDict, MultiModalInput, mm_input
38
from vllm.logger import init_logger
39
from vllm.model_executor.layers.pooler import IdentityPooler
40
41
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
43
44
45
46
47
48
49
50
51
52
53
54
55
from vllm.multimodal.inputs import (
    ImageItem,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    PlaceholderRange,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
56
    BaseDummyInputsBuilder,
57
58
    BaseMultiModalProcessor,
    BaseProcessingInfo,
59
    ProcessorInputs,
60
    PromptUpdate,
61
    TimingContext,
62
)
63
from vllm.sequence import IntermediateTensors
64

65
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
66
from .interfaces_base import attn_type
67

68
69
# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)

70

71
def _terratorch_field_names(input_definition: InputDefinition):
    """Return the names declared in ``input_definition.data`` as a set."""
    declared_inputs = input_definition.data
    return set(declared_inputs.keys())
73
74


75
76
77
78
79
def _terratorch_field_factory(
    input_definition: InputDefinition,
    *,
    is_shared: bool = True,  # True for unprocessed data, False for processed data
):
    """Create a callback that builds the field config for Terratorch inputs.

    Args:
        input_definition: Definition whose ``data`` mapping lists the model
            inputs by name.
        is_shared: When true every tensor input is described with
            ``MultiModalFieldConfig.shared(..., batch_size=1)``; otherwise
            with ``MultiModalFieldConfig.batched(...)``.

    Returns:
        A function that takes the HF inputs and returns a mapping from input
        name to ``MultiModalFieldConfig``.
    """

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `hf_inputs` is not read: the result depends only on the input
        # definition captured by the enclosing call. Entries whose type is
        # not `InputTypeEnum.tensor` are left out of the mapping.
        return {
            field_name: (
                MultiModalFieldConfig.shared("image", batch_size=1)
                if is_shared
                else MultiModalFieldConfig.batched("image")
            )
            for field_name, field_spec in input_definition.data.items()
            if field_spec.type == InputTypeEnum.tensor
        }

    return _terratorch_field_config
96
97


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that also accepts Terratorch dict inputs.

    Dict-valued image data is wrapped in ``DictEmbeddingItems`` using the
    field names and field factory derived from the input definition; any
    other image data is handled by the parent parser.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Store ``input_definition``; other arguments go to the parent."""
        super().__init__(*args, **kwargs)

        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse image data, treating a dict as a set of named inputs."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` if that key is absent."""
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)


125
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info backed by the ``pretrained_cfg`` of the HF config."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Input definition built from ``pretrained_cfg["input"]``.

        Evaluated once per instance because of ``cached_property``.
        """
        hf_config_dict = self.get_hf_config().to_dict()
        input_section = hf_config_dict["pretrained_cfg"]["input"]
        return InputDefinition(**input_section)

    def get_data_parser(self):
        """Return a ``TerratorchMultiModalDataParser`` for this model."""
        definition = self.input_definition
        hidden_size = self._get_expected_hidden_size()
        return TerratorchMultiModalDataParser(
            definition,
            expected_hidden_size=hidden_size,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; only ``"image"`` is listed."""
        # NOTE(review): `None` presumably means "no limit" - confirm against
        # the contract of `BaseProcessingInfo.get_supported_mm_limits`.
        limits: dict[str, int | None] = {"image": None}
        return limits


141
142
143
144
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder that delegates to Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        """Create the dummy data generator from the HF ``pretrained_cfg``."""
        super().__init__(info)
        pretrained_cfg = self.info.get_hf_config().to_dict()["pretrained_cfg"]
        self.dummy_data_generator = DummyDataGenerator(pretrained_cfg)

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty string regardless of ``mm_counts``."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy multimodal data produced by the generator.

        ``seq_len`` and ``mm_counts`` are not read. A non-empty
        ``mm_options`` only triggers a warning.
        """
        # Dummy data is generated based on the 'input' section
        # defined in the HF configuration file
        generator = self.dummy_data_generator

        if mm_options:
            logger.warning(
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )

        return generator.get_dummy_mm_data()


170
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
    """Multimodal processor that passes Terratorch tensors through unchanged."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        is_shared: bool = True,
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Build the field config from the model's input definition.

        ``hf_processor_mm_kwargs`` is not read. ``is_shared`` is forwarded
        to ``_terratorch_field_factory``.
        """
        build_fields = _terratorch_field_factory(
            self.info.input_definition,
            is_shared=is_shared,
        )
        return build_fields(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates."""
        return []

    def apply(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> MultiModalInput:
        """Turn processor inputs into a ``MultiModalInput``.

        The passthrough data is converted to tensors with an extra leading
        dimension, wrapped as multimodal kwargs using the batched field
        config, and returned with a single-token prompt and one empty
        ``"image"`` placeholder.
        """
        mm_items = inputs.mm_data_items
        hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs

        with timing_ctx.record("apply_hf_processor"):
            _, passthrough_data = self._get_hf_mm_data(mm_items)
            batched_tensors = {}
            for field_name, field_value in passthrough_data.items():
                # Prepend a dimension of size one to every tensor.
                batched_tensors[field_name] = torch.as_tensor(
                    field_value
                ).unsqueeze(0)
            mm_processed_data = BatchFeature(batched_tensors, tensor_type="pt")

        fields_config = self._get_mm_fields_config(
            mm_processed_data,
            hf_processor_mm_kwargs,
            is_shared=False,
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            fields_config,
        )

        with timing_ctx.record("get_mm_hashes"):
            mm_hashes = inputs.get_mm_hashes(self.info.model_id)

        # A single zero-length placeholder at offset 0 for the image modality.
        empty_range = PlaceholderRange(offset=0, length=0)
        mm_placeholders = {"image": [empty_range]}

        return mm_input(
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


232
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM wrapper around a Terratorch model.

    The model is created by Terratorch's ``InferenceRunner`` from the
    ``pretrained_cfg`` section of the HF config. ``forward`` hands its
    keyword arguments to the runner and returns the runner's ``output``.
    """

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the placeholder string for ``modality``.

        Returns:
            ``None`` for any modality whose name starts with ``"image"``.

        Raises:
            ValueError: For every other modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inference runner, expose its model and set the pooler.

        ``prefix`` is accepted but not used.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return an uninitialised ``(input_ids.shape[0], 0)`` tensor.

        ``multimodal_embeddings`` and ``is_multimodal`` are not used.
        """
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        # NOTE(review): no device/dtype is passed, so torch's defaults apply
        # rather than those of `input_ids` - confirm this is intended.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inference runner on ``kwargs`` and return its ``output``.

        The positional/named tensor arguments are not used; only the extra
        keyword arguments are forwarded.
        """
        model_output = self.inference_runner.forward(**kwargs)
        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the wrapped model.

        Two kinds of entries are handled:

        * ``("state_dict", <dict>)``: every contained name is prefixed with
          ``inference_runner.``, names containing ``pos_embed`` are skipped
          and ``_timm_module.`` is stripped from the name. Names that match
          a registered buffer are loaded directly; the rest are queued for
          the ``AutoWeightsLoader``. Iteration over ``weights`` stops after
          this entry.
        * ``(<key>, <tensor>)``: queued for the ``AutoWeightsLoader`` under
          ``inference_runner.model.<key>``.

        Dict entries under any other key are ignored.

        Returns:
            The names reported by the ``AutoWeightsLoader`` together with
            the names of the buffers loaded here.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if isinstance(value, (dict, OrderedDict)):
                if key != "state_dict":
                    continue

                for name, weight in value.items():
                    name = f"inference_runner.{name}"

                    if "pos_embed" in name:
                        continue

                    # Stripped once here; a no-op when the substring is
                    # absent, so no second strip is needed further down.
                    name = name.replace("_timm_module.", "")

                    # this model requires a couple of buffers to be loaded
                    # that are not loadable with the AutoWeightsLoader
                    if name in model_buffers:
                        buffer = model_buffers[name]
                        weight_loader = getattr(
                            buffer, "weight_loader", default_weight_loader
                        )
                        weight_loader(buffer, weight)
                        loaded_buffers.append(name)
                    else:
                        params_list.append((name, weight))
                # Entries after the state dict are not consumed.
                break

            elif isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))