idefics3.py 24.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

import math
20
from collections.abc import Iterable, Mapping, Sequence
21
from typing import Annotated, Literal, TypeAlias
22
23
24

import torch
from torch import nn
25
26
27
28
29
30
from transformers import (
    BatchFeature,
    Idefics3Config,
    Idefics3ImageProcessor,
    Idefics3Processor,
)
31

32
from vllm.config import VllmConfig
33
from vllm.config.multimodal import BaseDummyOptions
34
35
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
38
from vllm.model_executor.models.module_mapping import MultiModelKeys
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
41
42
43
44
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
45
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
46
from vllm.multimodal.processing import (
47
    BaseDummyInputsBuilder,
48
49
50
51
52
53
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
54
from vllm.sequence import IntermediateTensors
55
from vllm.utils.tensor_schema import TensorSchema, TensorShape
56
57

from .idefics2_vision_model import (
58
59
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
60
61
62
63
64
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
)
65
from .llama import LlamaModel
66
from .utils import AutoWeightsLoader, maybe_prefix
67
68


69
class Idefics3ImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for Idefics3, flattened over all image tiles.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"]

    # One entry per tile ("patch") across every image in the batch.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bnp", 3, "h", "w"),
    ]

    # Pixel-level mask with the same tile count and spatial size as
    # ``pixel_values``.
    pixel_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("bnp", "h", "w"),
    ]

    # Tiles contributed by each image; used to regroup features per image.
    num_patches: Annotated[
        torch.Tensor,
        TensorShape("bn"),
    ]
83

84

85
class Idefics3ImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings for Idefics3 (passed through without
    running the vision tower).

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"]

    # Returned unchanged by the model's image-input processing.
    data: Annotated[
        torch.Tensor,
        TensorShape("bn", "f", "h"),
    ]
95
96


97
# Either form of image input accepted by the Idefics3 model.
ImageInputs: TypeAlias = (
    Idefics3ImagePixelInputs | Idefics3ImageEmbeddingInputs
)
98
99


100
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Idefics3.

    Provides access to the HF processor and computes the image tiling grid.
    It also computes the placeholder text and token counts for each image.
    """

    def get_hf_processor(self, **kwargs: object) -> Idefics3Processor:
        """Return the HF ``Idefics3Processor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only the image modality is supported; its count limit is ``None``.
        return {"image": None}

    def _resize_output_size(
        self,
        *,
        height: int,
        width: int,
        max_len: int | None = None,
        min_len: int = 1,
        max_size: int | None = None,
    ) -> tuple[int, int]:
        """Rescale ``(height, width)`` so the longer side equals ``max_len``.

        The aspect ratio is preserved, and the shorter side is truncated to an
        int. Both sides are then rounded up to an even value and finally
        clamped to at least ``min_len``.

        Args:
            height: Input height.
            width: Input width.
            max_len: Target length of the longer side. Defaults to the
                current longer side.
            min_len: Lower bound applied to both output sides.
            max_size: Optional cap applied to ``max_len``.

        Returns:
            The resized ``(height, width)``.
        """
        # Set default value for max_len if not provided
        max_len = max(height, width) if max_len is None else max_len
        aspect_ratio = width / height

        # Handle the maximum size constraint
        if max_size is not None:
            max_len = min(max_len, max_size)

        # Adjust dimensions according to the aspect ratio
        if width >= height:
            width = max_len
            height = int(width / aspect_ratio)
        else:
            height = max_len
            width = int(height * aspect_ratio)

        # Ensure both width and height are even (if needed)
        height += height % 2
        width += width % 2

        # Ensure dimensions are not smaller than the minimum length
        height = max(height, min_len)
        width = max(width, min_len)

        return height, width

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
        processor: Idefics3Processor | None = None,
    ) -> tuple[int, int]:
        """Return the ``(height, width)`` an image is resized to before tiling.

        Args:
            image_width: Original image width.
            image_height: Original image height.
            resolution_max_side: Target length of the longer side.
            processor: HF processor whose ``image_processor.size`` bounds
                ``resolution_max_side``. Defaults to ``get_hf_processor()``.
                Callers holding a processor built with mm_kwargs overrides
                should pass it, so the bound reflects those overrides.

        Raises:
            ValueError: If ``resolution_max_side`` exceeds the processor's
                ``size["longest_edge"]``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_processor: Idefics3ImageProcessor = processor.image_processor
        max_image_size = image_processor.size["longest_edge"]

        if resolution_max_side > max_image_size:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`"
            )

        # Find the output size, when rescaling the longest edge to max_len and
        # preserving the aspect ratio
        return self._resize_output_size(
            height=image_height, width=image_width, max_len=resolution_max_side
        )

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> tuple[int, int]:
        """Return the ``(grid_w, grid_h)`` tile grid for an image.

        ``(0, 0)`` means the resized image fits within
        ``max_image_size["longest_edge"]`` on both sides, so it is not tiled.

        Raises:
            ValueError: If ``size["longest_edge"]`` is not a multiple of
                ``max_image_size["longest_edge"]``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_processor: Idefics3ImageProcessor = processor.image_processor

        max_image_size = image_processor.max_image_size["longest_edge"]
        size = image_processor.size["longest_edge"]

        # This validates possibly user-supplied overrides, so use an explicit
        # check rather than ``assert`` (stripped under ``python -O``).
        if size % max_image_size != 0:
            raise ValueError(
                "`longest_edge` in image_processor's `size` must be divisible by "
                "`longest_edge` in `max_image_size`, this may be caused by "
                "incorrect mm_kwargs override."
            )

        # Pass the same processor through so the bound check uses its
        # (possibly overridden) ``size`` rather than a default processor's.
        resized_height, resized_width = self._get_resize_output_image_size(
            image_width=image_width,
            image_height=image_height,
            resolution_max_side=size,
            processor=processor,
        )
        if resized_height > max_image_size or resized_width > max_image_size:
            grid_h = math.ceil(resized_height / max_image_size)
            grid_w = math.ceil(resized_width / max_image_size)
        else:
            grid_h = grid_w = 0
        return grid_w, grid_h

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> int:
        """Return the tile count for an image: grid tiles plus one global image."""
        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        return grid_w * grid_h + 1

    def _get_image_token(
        self, processor: Idefics3Processor | None
    ) -> tuple[str, str, str]:
        """Return ``(image_token, fake_image_token, global_image_tag)``."""
        if processor is None:
            processor = self.get_hf_processor()

        image_token = processor.image_token
        fake_image_token = processor.fake_image_token
        global_image_token = processor.global_image_tag
        return image_token, fake_image_token, global_image_token

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> str:
        """Build the prompt text that replaces one image token.

        An untiled image becomes
        ``<fake><global><image>*image_seq_len<fake>``. A tiled image becomes
        one ``<fake><row_i_col_j><image>*image_seq_len`` per tile, with a
        newline after each row. These are followed by a newline and the same
        global-image block.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_token, fake_image_token, global_img_token = self._get_image_token(
            processor
        )
        image_seq_len = processor.image_seq_len

        p_img = image_token * image_seq_len
        global_img_placeholder = fake_image_token + global_img_token + p_img

        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        if grid_w == 0 and grid_h == 0:
            return global_img_placeholder + fake_image_token

        tiles_placeholder = list[str]()
        for i in range(grid_h):
            for j in range(grid_w):
                # Build the row/col tag directly instead of calling
                # ``str.format`` on a template that also contains the
                # processor's token strings. A brace in any of those tokens
                # would otherwise be parsed as a format field.
                tiles_placeholder.append(
                    f"{fake_image_token}<row_{i + 1}_col_{j + 1}>{p_img}"
                )
                # Add line break if it is the last tile in the row
                if j == grid_w - 1:
                    tiles_placeholder.append("\n")

        return "".join(
            [
                *tiles_placeholder,
                "\n",
                global_img_placeholder,
                fake_image_token,
            ]
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> int:
        """Return ``num_patches * image_seq_len`` for an image."""
        if processor is None:
            processor = self.get_hf_processor()

        num_patches = self.get_num_patches(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        return num_patches * processor.image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square whose side is the processor's ``size["longest_edge"]``."""
        processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor

        return ImageSize(
            width=image_processor.size["longest_edge"],
            height=image_processor.size["longest_edge"],
        )

297

298
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]):
    """Builds dummy text prompts and dummy images for Idefics3."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per requested image."""
        image_token = self.info._get_image_token(self.info.get_hf_processor())[0]
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` square dummy images.

        The side length is the image processor's
        ``max_image_size["longest_edge"]``. ``mm_options["image"]``, when
        present, is forwarded as overrides.
        """
        image_processor: Idefics3ImageProcessor = (
            self.info.get_hf_processor().image_processor
        )
        edge = image_processor.max_image_size["longest_edge"]

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=edge,
            height=edge,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}

329

330
class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Multimodal processor wrapping the HF Idefics3 processor.

    It adds a per-image ``num_patches`` count to the HF outputs and declares
    how those outputs are split per image. It also expands each image token
    in the prompt into the full placeholder text.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach a ``num_patches`` tensor.

        Prompts without images are tokenized directly instead of going
        through the composite HF processor.
        """
        images = mm_data.get("images", [])

        # Text-only input not supported in composite processor
        if not images:
            token_ids = self.info.get_tokenizer().encode(prompt)
            token_ids = self._apply_hf_processor_tokens_only(token_ids)
            return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

        # Caller-supplied kwargs win over the default channel layout.
        mm_kwargs = {"input_data_format": "channels_last", **mm_kwargs}
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)

        image_items = self.info.parse_mm_data(
            {"image": images}, validate=False
        ).get_items("image", ImageProcessorItems)
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        patch_counts: list[int] = []
        for idx in range(len(image_items)):
            size = image_items.get_image_size(idx)
            patch_counts.append(
                self.info.get_num_patches(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
            )
        outputs["num_patches"] = torch.tensor(patch_counts)

        # Remove the extra batch dimension
        for key in ("pixel_values", "pixel_attention_mask"):
            outputs[key].squeeze_(0)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Pixel tensors are split per image by ``num_patches``; others are batched."""
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
            "pixel_attention_mask": MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "num_patches": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with that image's placeholder text."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = self.info._get_image_token(hf_processor)[0]

        def build_replacement(item_idx: int) -> PromptUpdateDetails:
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            replacement_text = self.info.get_image_repl(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
            )
            # Only the image-token positions receive embeddings.
            return PromptUpdateDetails.select_text(
                replacement_text,
                embed_text=image_token,
            )

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=build_replacement,
            )
        ]
423
424
425


class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear projection from shuffled vision features to the
    text model's hidden size."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # The connector's pixel shuffle packs scale_factor**2 vision tokens
        # into each input vector, hence the widened input dimension.
        in_features = config.vision_config.hidden_size * config.scale_factor**2
        out_features = config.text_config.hidden_size

        self.proj = ReplicatedLinear(
            in_features,
            out_features,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ``self.proj`` returns a pair; only the first element is used.
        projected, _ = self.proj(x)
        return projected


class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision-tower outputs, then projects them to text size."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "modality_projection"),
        )

    def pixel_shuffle(self, x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
        """Fold spatial neighbours into the channel dimension.

        ``x`` is ``(bsz, seq, embed_dim)``, and ``seq`` must be a perfect
        square for the first ``view`` to succeed. The result is
        ``(bsz, seq / scale_factor**2, embed_dim * scale_factor**2)``.
        """
        bsz, seq, embed_dim = x.size()
        side = int(seq**0.5)
        reduced_side = int(side / scale_factor)
        packed_dim = embed_dim * (scale_factor**2)

        # Lay tokens out on a side x side grid, then merge each run of
        # `scale_factor` columns into the channel dimension.
        x = x.view(bsz, side, side, embed_dim)
        x = x.view(bsz, side, reduced_side, embed_dim * scale_factor)
        # Transpose so rows can be merged the same way, then transpose back.
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(bsz, reduced_side, reduced_side, packed_dim)
        x = x.permute(0, 2, 1, 3)
        return x.reshape(bsz, int(seq / (scale_factor**2)), packed_dim)

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)


class Idefics3Model(nn.Module):
    """Idefics3 backbone: vision tower, connector, and Llama text model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.text_config.vocab_size

        self.vision_model = Idefics3VisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Tokens per tile: (patches per side)**2 reduced by the pixel
        # shuffle's scale_factor**2.
        vision_cfg = config.vision_config
        patches_per_tile = (vision_cfg.image_size // vision_cfg.patch_size) ** 2
        self.image_seq_len = int(patches_per_tile / (config.scale_factor**2))
        self.image_token_id = config.image_token_id

    def image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image tiles with the vision tower.

        Tiles whose values are all exactly zero are treated as padding.
        They are dropped, together with the matching rows of
        ``pixel_attention_mask``, before encoding.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        target_dtype = self.vision_model.embeddings.patch_embedding.weight.dtype
        pixel_values = pixel_values.to(dtype=target_dtype)  # fp16 compatibility

        # A tile is padding when every one of its values equals zero.
        values_per_tile = pixel_values.shape[1:].numel()
        zero_counts = (pixel_values == 0.0).sum(dim=(-1, -2, -3))
        keep = zero_counts != values_per_tile

        pixel_values = pixel_values[keep].contiguous()
        pixel_attention_mask = pixel_attention_mask[keep].contiguous()

        # Cut the pixel mask into non-overlapping patch_size x patch_size
        # windows; a patch is attended when its window sum is positive.
        patch_size = self.config.vision_config.patch_size
        windows = pixel_attention_mask.unfold(
            dimension=1, size=patch_size, step=patch_size
        ).unfold(dimension=2, size=patch_size, step=patch_size)
        patch_attention_mask = (windows.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        return self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the text model."""
        return self.text_model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text model; images are expected to be merged upstream."""
        return self.text_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )


573
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultiModalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder,
)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA):
    """Idefics3 vision-language model with a language-modeling head."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (images only)."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        with self._mark_composite_model(
            vllm_config,
            language_targets=LlamaModel,
            tower_targets={"image": (Idefics3VisionTransformer, Idefics3Connector)},
        ):
            self.model = Idefics3Model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "model"),
            )

        self.image_token_id = config.image_token_id

        text_config = config.text_config
        self.lm_head = ParallelLMHead(
            text_config.vocab_size,
            text_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if text_config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output head weight.
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(text_config.vocab_size)

    def _parse_and_validate_image_input(self, **kwargs: object) -> ImageInputs | None:
        """Build typed image inputs from ``kwargs``; ``None`` if no images."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        # Precomputed embeddings take precedence over raw pixels.
        if image_embeds is not None:
            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        if pixel_values is None:
            return None

        tile_side = self.config.vision_config.image_size
        return Idefics3ImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_attention_mask=kwargs.pop("pixel_attention_mask"),
            num_patches=kwargs.pop("num_patches"),
            resolve_bindings={"h": tile_side, "w": tile_side},
        )

    def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the pixel inputs."""
        return self.model.image_pixels_to_features(
            inputs["pixel_values"],
            pixel_attention_mask=inputs["pixel_attention_mask"],
        )

    def _process_image_input(
        self,
        image_input: ImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Turn image inputs into per-image embeddings for the text model."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self.model.connector(self._process_image_pixels(image_input))

        # Regroup the flat tile features by image and merge tile/token dims.
        split_sizes = image_input["num_patches"].tolist()
        return [chunk.flatten(0, 1) for chunk in features.split(split_sizes)]

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # When intermediate tensors are supplied (presumably a non-first
        # pipeline stage), any inputs_embeds are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.model.text_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Vision-encoder tokens before the pixel shuffle (x scale_factor**2)."""
        return num_image_tokens * self.config.scale_factor**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Tokens after the pixel shuffle (// scale_factor**2)."""
        return num_vision_tokens // self.config.scale_factor**2