# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Any
23
24
25

import torch
import torch.nn as nn
26
27
28
29
30
31
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
32
33
34
from transformers import BatchFeature

from vllm.config import VllmConfig
35
36
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
37
from vllm.model_executor.layers.pooler import IdentityPooler
38
39
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
41
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
59
    BaseDummyInputsBuilder,
60
61
62
63
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptUpdate,
)
64
from vllm.sequence import IntermediateTensors
65
from vllm.utils import length_from_prompt_token_ids_or_embeds
66

67
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
68
from .interfaces_base import attn_type
69

70
71
logger = init_logger(__name__)

72

73
def _terratorch_field_names(input_definition: InputDefinition):
    """Return the names of all inputs declared in ``input_definition``.

    The names are the keys of ``input_definition.data``, copied into a new
    ``set`` so the caller gets a collection independent of the definition.
    """
    declared_names = input_definition.data.keys()
    return set(declared_names)
75
76


77
78
79
80
81
82
83
def _terratorch_field_factory(input_definition: InputDefinition):
    """Build the callback that maps Terratorch inputs to field configs.

    The returned callable ignores the ``hf_inputs`` argument it receives: its
    result is derived solely from ``input_definition``. Every entry of
    ``input_definition.data`` whose ``type`` is ``InputTypeEnum.tensor`` is
    registered as a shared ``"image"`` field with ``batch_size=1``; entries of
    any other type are left out of the mapping.
    """

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            field_name: MultiModalFieldConfig.shared("image", batch_size=1)
            for field_name, field_spec in input_definition.data.items()
            if field_spec.type == InputTypeEnum.tensor
        }

    return _terratorch_field_config
90
91


92
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info for Terratorch; ``"image"`` is the only modality listed."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits.

        Only ``"image"`` is present and its limit is ``None``
        (presumably "no fixed cap" -- confirm against ``BaseProcessingInfo``).
        """
        return dict(image=None)


97
98
99
100
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder that delegates to Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        super().__init__(info)
        # The generator is configured from the "pretrained_cfg" section of the
        # HF configuration.
        hf_config_dict = self.info.get_hf_config().to_dict()
        pretrained_cfg = hf_config_dict["pretrained_cfg"]
        self.dummy_data_generator = DummyDataGenerator(pretrained_cfg)

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always the empty string."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy multimodal data produced by the Terratorch generator.

        The data is generated based on the 'input' section defined in the HF
        configuration file. ``seq_len`` and ``mm_counts`` are not consulted,
        and a non-empty ``mm_options`` only triggers a warning.
        """
        if mm_options:
            # Profiling options cannot be honoured; tell the user and move on.
            logger.warning(
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )

        generator = self.dummy_data_generator
        return generator.get_dummy_mm_data()


class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that accepts Terratorch inputs given as a plain ``dict``."""

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Store ``input_definition``; remaining arguments go to the base class."""
        super().__init__(*args, **kwargs)
        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the ``"image"`` entry of the multimodal data.

        A ``dict`` is wrapped in ``DictEmbeddingItems`` whose required fields
        and field factory both come from the stored input definition. Anything
        else is handled by the base-class implementation.
        """
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` when that key is absent."""
        image_keyed = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(image_keyed)

152

153
154
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor that forwards Terratorch tensors without a prompt."""

    def __init__(
        self,
        info: TerratorchProcessingInfo,
        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
        *,
        cache: MultiModalProcessorOnlyCache | None = None,
    ) -> None:
        """Read the input definition from the HF config, then init the base class."""
        hf_config_dict = info.get_hf_config().to_dict()
        input_section = hf_config_dict["pretrained_cfg"]["input"]
        # Assigned before the base initialiser runs, as in the original ordering
        # (presumably because the base class may request the data parser during
        # construction -- confirm against ``BaseMultiModalProcessor``).
        self._input_definition = InputDefinition(**input_section)

        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser bound to this processor's input definition."""
        definition = self._input_definition
        return TerratorchMultiModalDataParser(definition)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field configs derived from the input definition."""
        build_fields = _terratorch_field_factory(self._input_definition)
        return build_fields(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates (always an empty list)."""
        return []

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Turn raw Terratorch data into ``MultiModalInputs``.

        ``prompt`` is not used: the result always carries the single token id
        ``1`` and one zero-length ``"image"`` placeholder at offset 0. The
        tensors are taken from ``mm_data["image"]`` when that key exists and
        from ``mm_data`` itself otherwise.
        """
        parsed_items = self._to_mm_items(mm_data)
        item_hashes = self._hash_mm_items(
            parsed_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs or {},
            mm_uuids=mm_uuids,
        )

        raw_inputs = mm_data.get("image", mm_data)
        image_features = BatchFeature(raw_inputs, tensor_type="pt")
        placeholder_map = {"image": [PlaceholderRange(offset=0, length=0)]}

        fields_config = self._get_mm_fields_config(
            image_features, hf_processor_mm_kwargs
        )
        kwargs_items = MultiModalKwargsItems.from_hf_inputs(
            image_features, fields_config
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=[1],
            mm_kwargs=kwargs_items,
            mm_hashes=item_hashes,
            mm_placeholders=placeholder_map,
        )


217
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM model wrapper that runs a Terratorch model via ``InferenceRunner``."""

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Returns:
            ``None`` for any modality whose name starts with ``"image"``.

        Raises:
            ValueError: For every other modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inference runner from the HF config's ``pretrained_cfg``.

        ``prefix`` is accepted but not used here.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        # A pooler config must be present even though the pooler itself is a
        # plain identity.
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return an uninitialised ``(input_ids.shape[0], 0)`` tensor.

        All arguments other than ``input_ids`` are ignored.
        """
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inference runner on the multimodal keyword tensors.

        Every value in ``kwargs`` gets a new leading dimension via
        ``unsqueeze(0)`` before being passed to the runner. ``input_ids`` and
        ``inputs_embeds`` are only used to determine the input length;
        ``positions`` and ``intermediate_tensors`` are not used.

        Returns:
            The runner's ``output`` expanded along its first dimension to the
            input length (assumes that dimension has size 1 or already equals
            the input length, as ``Tensor.expand`` requires -- TODO confirm).
        """
        input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)

        batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
        model_output = self.inference_runner.forward(**batched_kwargs).output

        # The leading dimension of hidden states needs to equal input length
        return model_output.expand(
            input_len, *(-1 for _ in range(model_output.ndim - 1))
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the wrapped model.

        Two kinds of entries are recognised in ``weights``:

        * a mapping stored under the key ``"state_dict"``: each of its entries
          is renamed to ``inference_runner.<name>`` with ``"_timm_module."``
          removed, entries containing ``"pos_embed"`` are skipped, and
          iteration over ``weights`` stops once this mapping is processed;
        * a plain tensor: queued as ``inference_runner.model.<key>``.

        Mappings stored under any other key are ignored.

        Returns:
            The names reported by ``AutoWeightsLoader`` together with the
            names of the buffers that were loaded directly.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []

        for key, value in weights:
            if isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))
                continue

            if not isinstance(value, (dict, OrderedDict)) or key != "state_dict":
                continue

            for name, weight in value.items():
                name = f"inference_runner.{name}"

                if "pos_embed" in name:
                    continue

                # ``str.replace`` leaves names without the marker untouched,
                # so no membership test is needed -- here or later on.
                name = name.replace("_timm_module.", "")

                if name not in model_buffers:
                    params_list.append((name, weight))
                    continue

                # this model requires a couple of buffers to be loaded
                # that are not loadable with the AutoWeightsLoader
                buffer = model_buffers[name]
                weight_loader = getattr(
                    buffer, "weight_loader", default_weight_loader
                )
                weight_loader(buffer, weight)
                loaded_buffers.append(name)

            # Nothing after the state dict is consumed.
            break

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(loaded_buffers)