# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Any, Callable, Optional, Union
23
24
25

import torch
import torch.nn as nn
26
27
28
29
30
31
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
32
33
34
from transformers import BatchFeature

from vllm.config import VllmConfig
35
36
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
37
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
38
39
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
41
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptUpdate,
)
63
from vllm.multimodal.profiling import BaseDummyInputsBuilder
64
from vllm.sequence import IntermediateTensors
65

66
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
67
68
from .interfaces_base import default_pooling_type

69
70
logger = init_logger(__name__)

71

72
73
74
def _terratorch_field_names(pretrained_cfg: dict):
    """Return the set of input names declared in the Terratorch config.

    The ``"input"`` entry of ``pretrained_cfg`` is unpacked into an
    ``InputDefinition``; the keys of its ``data`` mapping are the names.
    """
    definition = InputDefinition(**pretrained_cfg["input"])
    declared_names = definition.data.keys()
    return set(declared_names)
75
76


77
def _terratorch_field_factory(
    pretrained_cfg: dict,
) -> Callable[
    [Mapping[str, torch.Tensor]],
    Mapping[str, MultiModalFieldConfig],
]:
    """Build the callable that produces the multimodal field configuration.

    Args:
        pretrained_cfg: The model's ``pretrained_cfg`` dictionary. Its
            ``"input"`` entry is unpacked into an ``InputDefinition`` each
            time the returned callable is invoked.

    Returns:
        A function that accepts the processed inputs and returns one batched
        ``MultiModalFieldConfig`` with modality ``"image"`` for every input
        whose declared type is ``InputTypeEnum.tensor``.
    """

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``hf_inputs`` is not inspected here: the resulting fields are
        # derived solely from the input definition in the configuration.
        input_definition = InputDefinition(**pretrained_cfg["input"])
        return {
            input_name: MultiModalFieldConfig.batched(modality="image")
            for input_name, input_spec in input_definition.data.items()
            if input_spec.type == InputTypeEnum.tensor
        }

    return _terratorch_field_config
96
97


98
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info for Terratorch models."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the supported modalities mapped to their item limits.

        Only the ``"image"`` modality is reported, with ``None`` as its
        limit.
        """
        return {"image": None}


103
104
105
106
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder backed by Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        """Create the dummy data generator from the HF ``pretrained_cfg``.

        Args:
            info: Processing info giving access to the HF configuration.
        """
        super().__init__(info)
        pretrained_cfg = self.info.get_hf_config().to_dict()["pretrained_cfg"]
        self.dummy_data_generator = DummyDataGenerator(pretrained_cfg)

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always an empty string."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy multimodal data produced by the data generator.

        Args:
            seq_len: Unused by this implementation.
            mm_counts: Unused by this implementation.
            mm_options: Profiling options; not supported. A warning is
                logged when a non-empty mapping is given and the options
                are otherwise ignored.

        Returns:
            Whatever ``DummyDataGenerator.get_dummy_mm_data()`` yields.
        """
        # Dummy data is generated based on the 'input' section
        # defined in the HF configuration file

        if mm_options:
            logger.warning(
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )

        return self.dummy_data_generator.get_dummy_mm_data()


class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that accepts Terratorch inputs given as a plain dict."""

    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
        """Store the Terratorch configuration and initialise the base parser.

        Args:
            pretrained_cfg: The model's ``pretrained_cfg`` dictionary.
            *args: Forwarded unchanged to ``MultiModalDataParser``.
            **kwargs: Forwarded unchanged to ``MultiModalDataParser``.
        """
        self._pretrained_cfg = pretrained_cfg
        super().__init__(*args, **kwargs)

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Parse the ``"image"`` modality payload.

        Dict payloads are wrapped in ``DictEmbeddingItems`` whose required
        fields are the input names declared in the Terratorch
        configuration. Any other payload is delegated to the base parser.
        """
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=terratorch_fields,
            fields_factory=_terratorch_field_factory(self._pretrained_cfg),
        )
152

153

154
155
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor for Terratorch models.

    The prompt is not used: ``apply`` always emits a single placeholder
    token id and forwards the tensors found in the multimodal data.
    """

    def __init__(
        self,
        info: TerratorchProcessingInfo,
        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
        *,
        cache: Optional[MultiModalProcessorOnlyCache] = None,
    ) -> None:
        """Read ``pretrained_cfg`` from the HF config, then init the base.

        ``self.pretrained_cfg`` is assigned before calling the base
        initialiser, as in the original ordering.
        """
        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser aware of the Terratorch input definition."""
        return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configuration derived from ``pretrained_cfg``.

        ``hf_processor_mm_kwargs`` is not used.
        """
        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates; the prompt is never modified."""
        return []

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Turn the raw multimodal data into ``MultiModalInputs``.

        Args:
            prompt: Unused by this implementation.
            mm_data: Either ``{"image": {name: tensor, ...}}`` or the inner
                ``{name: tensor, ...}`` mapping directly.
            hf_processor_mm_kwargs: Passed on to hashing and to the field
                configuration lookup.
            tokenization_kwargs: Passed on to hashing; ``None`` is treated
                as an empty mapping.
            mm_uuids: Optional UUIDs passed on to hashing.

        Returns:
            Inputs with ``prompt_token_ids=[1]``, one zero-length ``"image"``
            placeholder at offset 0, the hashes and the multimodal kwargs.
        """
        # Accept both the wrapped and the bare form of the input mapping.
        raw_tensors = mm_data["image"] if "image" in mm_data else mm_data
        # Every tensor gets a new leading dimension of size 1.
        image_data = {k: v.unsqueeze(0) for k, v in raw_tensors.items()}

        mm_data = {"image": image_data}

        mm_items = self._to_mm_items(mm_data)
        tokenization_kwargs = tokenization_kwargs or {}
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        mm_processed_data = BatchFeature(image_data)

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


223
@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM wrapper around a Terratorch model run via ``InferenceRunner``."""

    # Class-level flags; their consumers live outside this file.
    merge_by_field_config = True
    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the placeholder string for a modality.

        Returns:
            ``None`` for any modality whose name starts with ``"image"``.

        Raises:
            ValueError: For every other modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the inference runner and the pooler.

        Args:
            vllm_config: Supplies the HF config (its ``pretrained_cfg``
                entry configures the ``InferenceRunner``) and the pooler
                config, which must not be ``None``.
            prefix: Unused by this implementation.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)},
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return an empty ``(input_ids.shape[0], 0)`` dummy embedding.

        All arguments other than ``input_ids`` are ignored.
        """
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the inference runner on the keyword inputs.

        Only ``**kwargs`` are forwarded to ``InferenceRunner.forward``; the
        positional vLLM arguments are ignored. The ``output`` attribute of
        the runner's result is returned.
        """
        model_output = self.inference_runner.forward(**kwargs)

        return model_output.output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load parameters and buffers from ``weights``.

        Two kinds of entries are handled:

        * A tensor value is queued as a parameter named
          ``inference_runner.model.<key>``.
        * A mapping value under the key ``"state_dict"`` is expanded. Each
          name is prefixed with ``inference_runner.``, names containing
          ``pos_embed`` are skipped and ``_timm_module.`` is stripped.
          Names matching a registered buffer are loaded immediately; the
          rest are queued as parameters. Iteration over ``weights`` stops
          after the first such mapping.

        Mapping values under any other key are ignored.

        Returns:
            The names loaded by ``AutoWeightsLoader`` together with the
            names of the buffers loaded directly.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))
                continue

            is_state_dict = isinstance(value, (dict, OrderedDict))
            if not is_state_dict or key != "state_dict":
                continue

            for name, weight in value.items():
                name = f"inference_runner.{name}"

                if "pos_embed" in name:
                    continue

                # Stripped once here, so no second strip is needed below.
                name = name.replace("_timm_module.", "")

                # this model requires a couple of buffers to be loaded
                # that are not loadable with the AutoWeightsLoader
                if name in model_buffers:
                    buffer = model_buffers[name]
                    weight_loader = getattr(
                        buffer, "weight_loader", default_weight_loader
                    )
                    weight_loader(buffer, weight)
                    loaded_buffers.append(name)
                else:
                    params_list.append((name, weight))
            break

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))