"vllm/vscode:/vscode.git/clone" did not exist on "4f6593b058dc7ba66d887442ba5763c6c1b3886e"
terratorch.py 10.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
"""Wrapper around `Terratorch` models"""
19

20
from collections import OrderedDict
21
from collections.abc import Iterable, Mapping, Sequence
22
from functools import cached_property
23
from typing import Any
24
25
26

import torch
import torch.nn as nn
27
28
29
30
31
32
from terratorch.vllm import (
    DummyDataGenerator,
    InferenceRunner,
    InputDefinition,
    InputTypeEnum,
)
33
34
35
from transformers import BatchFeature

from vllm.config import VllmConfig
36
37
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
38
from vllm.model_executor.layers.pooler import IdentityPooler
39
40
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.multimodal.inputs import (
    ImageItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
    PlaceholderRange,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
59
    BaseDummyInputsBuilder,
60
61
62
63
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptUpdate,
)
64
from vllm.sequence import IntermediateTensors
65
from vllm.utils import length_from_prompt_token_ids_or_embeds
66

67
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
68
from .interfaces_base import attn_type
69

70
71
logger = init_logger(__name__)

72

73
def _terratorch_field_names(input_definition: InputDefinition) -> set[str]:
    """Return the names of the inputs declared in ``input_definition``.

    The names are the keys of ``input_definition.data``, collected into a set.
    """
    # NOTE(review): keys are assumed to be strings, matching how the field
    # factory in this file uses them as field names -- confirm.
    return set(input_definition.data.keys())
75
76


77
78
79
80
81
82
83
def _terratorch_field_factory(input_definition: InputDefinition):
    """Create the field-config callback for ``input_definition``.

    The returned callable takes the processed inputs and returns a mapping
    from input name to ``MultiModalFieldConfig``. Only inputs whose ``type``
    equals ``InputTypeEnum.tensor`` get an entry; every entry is a shared
    field of the ``"image"`` modality with ``batch_size=1``.
    """

    def _terratorch_field_config(
        hf_inputs: Mapping[str, torch.Tensor],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``hf_inputs`` is not inspected: the configuration is derived solely
        # from the input definition captured by the closure.
        return {
            name: MultiModalFieldConfig.shared("image", batch_size=1)
            for name, input_spec in input_definition.data.items()
            if input_spec.type == InputTypeEnum.tensor
        }

    return _terratorch_field_config
90
91


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Data parser that understands dict-shaped Terratorch inputs.

    Dict inputs are wrapped in ``DictEmbeddingItems`` using the field names
    and field factory derived from the stored input definition; anything
    else is delegated to the parent parser.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Forward ``*args``/``**kwargs`` to the parent and keep the definition."""
        super().__init__(*args, **kwargs)
        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse ``data`` for the ``"image"`` modality."""
        # Non-dict payloads follow the default image parsing path.
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` when that key is absent."""
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)


119
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info exposing the Terratorch input definition and parser."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Input definition built from the HF config.

        Reads the ``"input"`` entry of the config's ``"pretrained_cfg"``
        section and passes it as keyword arguments to ``InputDefinition``.
        The result is cached on the instance.
        """
        pretrained_cfg = self.get_hf_config().to_dict()["pretrained_cfg"]
        return InputDefinition(**pretrained_cfg["input"])

    def get_data_parser(self):
        """Return a ``TerratorchMultiModalDataParser`` for this model."""
        return TerratorchMultiModalDataParser(
            self.input_definition,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare ``"image"`` as the only supported modality."""
        # NOTE(review): ``None`` presumably means "no per-prompt limit" --
        # confirm against the BaseProcessingInfo contract.
        return {"image": None}


135
136
137
138
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
    """Dummy-input builder that delegates to Terratorch's ``DummyDataGenerator``."""

    def __init__(self, info: TerratorchProcessingInfo):
        """Create the dummy data generator from the HF ``pretrained_cfg`` section."""
        super().__init__(info)
        self.dummy_data_generator = DummyDataGenerator(
            self.info.get_hf_config().to_dict()["pretrained_cfg"]
        )

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt; ``mm_counts`` is not consulted."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy multimodal data produced by the Terratorch generator.

        ``seq_len`` and ``mm_counts`` are not used. A non-empty ``mm_options``
        is ignored after logging a warning.
        """
        # Dummy data is generated based on the 'input' section
        # defined in the HF configuration file

        if mm_options:
            logger.warning(
                "Configurable multimodal profiling "
                "options are not supported for Terratorch. "
                "They are ignored for now."
            )

        return self.dummy_data_generator.get_dummy_mm_data()


164
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
    """Multimodal processor for Terratorch models.

    Prompt text is not used: ``apply`` always emits a single token id and
    passes the raw input tensors through as multimodal keyword arguments.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config derived from the model's input definition."""
        return _terratorch_field_factory(self.info.input_definition)(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return no prompt updates."""
        return []

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Build the ``MultiModalInputs`` for one request.

        ``prompt`` is not used; the returned ``prompt_token_ids`` is always
        ``[1]``. The data under ``mm_data["image"]`` (or ``mm_data`` itself
        when that key is missing) is converted to PyTorch tensors through
        ``BatchFeature`` and forwarded as multimodal kwargs.
        """
        mm_items = self._to_mm_items(mm_data)
        tokenization_kwargs = tokenization_kwargs or {}
        mm_hashes = self._hash_mm_items(
            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
        )

        mm_processed_data = BatchFeature(
            mm_data.get("image", mm_data), tensor_type="pt"
        )
        # A single zero-length placeholder at offset 0 is registered for the
        # "image" modality.
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
            mm_processed_data,
            self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
        )

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=[1],
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholders,
        )


213
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
    TerratorchMultiModalProcessor,
    info=TerratorchProcessingInfo,
    dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
    """vLLM wrapper that runs a Terratorch model through ``InferenceRunner``."""

    supports_multimodal_raw_input_only = True
    is_pooling_model = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return ``None`` for image modalities.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inference runner from the HF ``pretrained_cfg`` section.

        ``prefix`` is accepted but not used here.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

        self.inference_runner = InferenceRunner(config)
        self.model = self.inference_runner.model

        # A pooler config must be present, although only its presence is
        # checked: the pooler itself is always the identity pooler.
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = IdentityPooler()

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return an uninitialized ``(input_ids.shape[0], 0)`` tensor.

        All arguments other than ``input_ids`` are ignored.
        """
        # We do not really use any input tokens and therefore no embeddings
        # to be calculated. However, due to the mandatory token ids in
        # the input prompt we pass one token and the size of the dummy
        # embedding tensors must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the Terratorch model on the tensors passed via ``kwargs``.

        Each keyword tensor gets a new leading dimension of size 1 before it
        is handed to ``InferenceRunner.forward``. ``positions`` and
        ``intermediate_tensors`` are not used.

        Returns:
            The runner's ``output`` expanded along its first dimension to the
            input length.
        """
        input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)

        batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
        model_output = self.inference_runner.forward(**batched_kwargs).output

        # The leading dimension of hidden states needs to equal input length
        # NOTE(review): ``expand`` only succeeds when the leading dimension of
        # the output is 1 (or already ``input_len``) -- confirm for all models.
        trailing_dims = [-1] * (model_output.ndim - 1)
        return model_output.expand(input_len, *trailing_dims)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the wrapped model.

        Two entry shapes are handled:

        * a ``"state_dict"`` entry whose value is a dict: each item is
          prefixed with ``inference_runner.``, ``pos_embed`` entries are
          skipped and ``_timm_module.`` is stripped from names. Names that
          match a registered buffer are loaded directly; the rest are queued.
          Iteration over ``weights`` stops after this entry.
        * a plain tensor entry: queued under ``inference_runner.model.<key>``.

        Queued entries are loaded with ``AutoWeightsLoader``.

        Returns:
            The names reported by ``AutoWeightsLoader`` together with the
            names of the buffers that were loaded directly.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if isinstance(value, (dict, OrderedDict)):
                if key == "state_dict":
                    for name, weight in value.items():
                        name = f"inference_runner.{name}"

                        if "pos_embed" in name:
                            continue

                        # ``str.replace`` is a no-op when the marker is absent.
                        name = name.replace("_timm_module.", "")

                        # this model requires a couple of buffers to be loaded
                        # that are not loadable with the AutoWeightsLoader
                        if name in model_buffers:
                            buffer = model_buffers[name]
                            weight_loader = getattr(
                                buffer, "weight_loader", default_weight_loader
                            )
                            weight_loader(buffer, weight)
                            loaded_buffers.append(name)
                        else:
                            params_list.append((name, weight))
                    break

            elif isinstance(value, torch.Tensor):
                params_list.append((f"inference_runner.model.{key}", value))

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))