terratorch.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from functools import cached_property
23
from typing import Any
24
25
26

import torch
import torch.nn as nn
27
28
29
30
31
32
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
33
34
35
from transformers import BatchFeature

from vllm.config import VllmConfig
36
37
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
38
from vllm.model_executor.layers.pooler import IdentityPooler
39
40
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    PlaceholderRange,
50
    mm_inputs,
51
52
53
54
55
56
57
58
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
59
    BaseDummyInputsBuilder,
60
61
    BaseMultiModalProcessor,
    BaseProcessingInfo,
62
    ProcessorInputs,
63
    PromptUpdate,
64
    TimingContext,
65
)
66
from vllm.sequence import IntermediateTensors
67

68
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
69
from .interfaces_base import attn_type
70

71
72
# Module-level logger; used below to warn when dummy-input profiling options
# are passed and ignored.
logger = init_logger(__name__)

73

74
def _terratorch_field_names(input_definition: InputDefinition):
75
    return set(input_definition.data.keys())
76
77


78
79
80
81
82
def _terratorch_field_factory(
    input_definition: InputDefinition,
    *,
    is_shared: bool = True,  # True for unprocessed data, False for processed data
):
    """Build a callback mapping Terratorch input fields to field configs.

    Only inputs whose declared ``type`` equals ``InputTypeEnum.tensor`` get
    an entry; all entries are attached to the "image" modality.

    Args:
        input_definition: Terratorch input definition; its ``data`` mapping
            names each model input and carries its declared type.
        is_shared: When True, every field uses
            ``MultiModalFieldConfig.shared("image", batch_size=1)``;
            otherwise ``MultiModalFieldConfig.batched("image")``.

    Returns:
        A callable taking ``hf_inputs`` and returning a dict from field
        name to a freshly built ``MultiModalFieldConfig``.
    """
    modality = "image"

    def _build_config():
        # A new config object is created for each field.
        if is_shared:
            return MultiModalFieldConfig.shared(modality, batch_size=1)
        return MultiModalFieldConfig.batched(modality)

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # The layout is derived from the input definition alone;
        # ``hf_inputs`` is not consulted.
        return {
            field_name: _build_config()
            for field_name, field_spec in input_definition.data.items()
            if field_spec.type == InputTypeEnum.tensor
        }

    return _terratorch_field_config
99
100


101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Multimodal data parser that understands Terratorch's dict inputs.

    Image data given as a dict is wrapped in ``DictEmbeddingItems`` whose
    required fields and field configs come from the model's input
    definition. Any other image data falls back to the base parser.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Field layout used when wrapping dict-shaped image data.
        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        # Inputs may arrive without an "image" key (e.g. the field dict
        # itself); nest them under "image" so the image path handles them.
        normalized = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(normalized)


128
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info derived from the ``pretrained_cfg`` of the HF config."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Input definition built from ``pretrained_cfg["input"]`` (cached)."""
        hf_config_dict = self.get_hf_config().to_dict()
        input_section = hf_config_dict["pretrained_cfg"]["input"]
        return InputDefinition(**input_section)

    def get_data_parser(self):
        """Return a data parser bound to this model's input definition."""
        definition = self.input_definition
        hidden_size = self._get_expected_hidden_size()
        return TerratorchMultiModalDataParser(
            definition,
            expected_hidden_size=hidden_size,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only the "image" modality is declared; its limit is None
        # (presumably "no explicit limit" — confirm against BaseProcessingInfo).
        return dict(image=None)


144
145
146
147
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder backed by Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        super().__init__(info)
        hf_config_dict = self.info.get_hf_config().to_dict()
        self.dummy_data_generator = DummyDataGenerator(
            hf_config_dict["pretrained_cfg"]
        )

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No dummy text is produced; only multimodal dummy data is used.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy multimodal data generated from the model's config.

        The data comes from the 'input' section of the HF configuration
        (via ``DummyDataGenerator``). ``seq_len`` and ``mm_counts`` are not
        used, and any ``mm_options`` are ignored with a warning.
        """
        if mm_options:
            logger.warning(
                "Configurable multimodal profiling options are not "
                "supported for Terratorch. They are ignored for now."
            )

        return self.dummy_data_generator.get_dummy_mm_data()


173
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
    """Multimodal processor that forwards Terratorch inputs without a HF processor.

    Passthrough data is given a leading batch dimension and packaged as
    multimodal kwargs. The prompt is a single placeholder token with no
    prompt updates.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        is_shared: bool = True,
    ) -> Mapping[str, MultiModalFieldConfig]:
        build_fields = _terratorch_field_factory(
            self.info.input_definition,
            is_shared=is_shared,
        )
        return build_fields(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # The prompt is never rewritten for this model.
        return []

    def apply(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> MultiModalInputs:
        """Turn raw Terratorch inputs into ``MultiModalInputs``.

        Each passthrough value is converted with ``torch.as_tensor`` and
        given a leading dimension of size 1 via ``unsqueeze(0)``. The
        resulting batch is described with batched (non-shared) field
        configs.
        """
        mm_items = inputs.mm_data_items
        hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs

        with timing_ctx.record("apply_hf_processor"):
            _, passthrough_data = self._get_hf_mm_data(mm_items)
            batched_tensors = {}
            for field_name, raw_value in passthrough_data.items():
                batched_tensors[field_name] = torch.as_tensor(raw_value).unsqueeze(0)
            mm_processed_data = BatchFeature(batched_tensors, tensor_type="pt")

        fields_config = self._get_mm_fields_config(
            mm_processed_data,
            hf_processor_mm_kwargs,
            is_shared=False,
        )
        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            fields_config,
        )

        with timing_ctx.record("get_mm_hashes"):
            mm_hashes = inputs.get_mm_hashes(self.info.model_id)

        # One token id is supplied because a prompt is mandatory; the image
        # placeholder is zero-length at offset 0.
        return mm_inputs(
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders={"image": [PlaceholderRange(offset=0, length=0)]},
        )


235
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM wrapper running a Terratorch model as an attention-free pooling model.

    The model is built by Terratorch's ``InferenceRunner`` from the
    ``pretrained_cfg`` section of the HF config. Its output is returned
    unchanged through an ``IdentityPooler``.
    """

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return no placeholder for image modalities; reject all others.

        Raises:
            ValueError: If ``modality`` does not start with "image".
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        # Exposed directly so that its parameters appear under ``model.``.
        self.model = self.inference_runner.model

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return a zero-width embedding tensor with one row per input id.

        We do not really use any input tokens and therefore no embeddings
        need to be calculated. However, due to the mandatory token ids in
        the input prompt we pass one token, and the size of the dummy
        embedding tensor must reflect that.
        """
        # Allocate on the same device as the token ids so downstream code
        # does not see a CPU/accelerator mismatch.
        return torch.empty((input_ids.shape[0], 0), device=input_ids.device)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the Terratorch model on the multimodal kwargs.

        ``input_ids``, ``positions``, ``intermediate_tensors`` and
        ``inputs_embeds`` are accepted for interface compatibility but not
        used; only ``kwargs`` are forwarded to the inference runner.
        """
        model_output = self.inference_runner.forward(**kwargs)
        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the wrapped Terratorch model.

        Two entry kinds are handled:

        * a ``"state_dict"`` entry whose value is a dict of name -> tensor;
          names are prefixed with ``"inference_runner."``, and iteration over
          ``weights`` stops once it has been processed;
        * plain tensor entries, prefixed with ``"inference_runner.model."``.

        Dict-valued entries under any other key are ignored.

        Returns:
            Names of parameters loaded by ``AutoWeightsLoader`` together with
            the names of buffers loaded directly.
        """
        params_list: list[tuple[str, torch.Tensor]] = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers: list[str] = []

        for key, value in weights:
            if isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))
                continue

            if not isinstance(value, (dict, OrderedDict)) or key != "state_dict":
                continue

            for name, weight in value.items():
                name = f"inference_runner.{name}"

                # Positional embeddings are skipped.
                # NOTE(review): presumably regenerated by the model — confirm.
                if "pos_embed" in name:
                    continue

                # Drop the "_timm_module." segment from checkpoint names.
                name = name.replace("_timm_module.", "")

                # This model requires a couple of buffers to be loaded that
                # are not loadable with the AutoWeightsLoader.
                if name in model_buffers:
                    buffer = model_buffers[name]
                    weight_loader = getattr(
                        buffer, "weight_loader", default_weight_loader
                    )
                    weight_loader(buffer, weight)
                    loaded_buffers.append(name)
                else:
                    params_list.append((name, weight))
            break

        # Load the remaining model parameters.
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(loaded_buffers)