idefics3.py 24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

import math
20
from collections.abc import Iterable, Mapping, Sequence
21
from typing import Annotated, Literal, TypeAlias
22
23
24

import torch
from torch import nn
25
26
27
28
29
30
from transformers import (
    BatchFeature,
    Idefics3Config,
    Idefics3ImageProcessor,
    Idefics3Processor,
)
31

32
from vllm.config import VllmConfig
33
from vllm.config.multimodal import BaseDummyOptions
34
35
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
38
from vllm.model_executor.models.module_mapping import MultiModelKeys
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
41
42
43
44
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
45
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
46
47
48
49
50
51
52
53
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalDataItems,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
54
from vllm.multimodal.profiling import BaseDummyInputsBuilder
55
from vllm.sequence import IntermediateTensors
56
from vllm.utils.tensor_schema import TensorSchema, TensorShape
57
58

from .idefics2_vision_model import (
59
60
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
61
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
62
from .llama import LlamaModel
63
from .utils import AutoWeightsLoader, maybe_prefix
64
65


66
class Idefics3ImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for Idefics3, flattened over every image tile.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - c: Number of channels (3)
        - h: Height
        - w: Width

    ``num_patches`` holds, per image, how many of the ``bnp`` rows belong
    to that image.
    """

    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
80

81

82
class Idefics3ImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings supplied instead of pixels.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"]
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
92
93


94
# Image input accepted by the model: either precomputed embeddings or pixels.
ImageInputs: TypeAlias = Idefics3ImageEmbeddingInputs | Idefics3ImagePixelInputs
95
96


97
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Idefics3 processing helpers: image tiling geometry, per-image patch and
    token counts, and the placeholder text that represents an image in the
    prompt."""

    def get_hf_processor(self, **kwargs: object) -> Idefics3Processor:
        """Return the HF ``Idefics3Processor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # ``None``: no fixed per-prompt limit is declared for images.
        return {"image": None}

    def _resize_output_size(
        self,
        *,
        height: int,
        width: int,
        max_len: int | None = None,
        min_len: int = 1,
        max_size: int | None = None,
    ) -> tuple[int, int]:
        """Compute an aspect-preserving size whose longer side is ``max_len``.

        Args:
            height: Input height in pixels.
            width: Input width in pixels.
            max_len: Target length of the longer side; defaults to the
                current longer side.
            min_len: Lower bound applied to both output sides.
            max_size: Optional cap applied to ``max_len``.

        Returns:
            ``(height, width)``, each rounded up to an even number and then
            clamped to at least ``min_len``.
        """
        max_len = max(height, width) if max_len is None else max_len
        aspect_ratio = width / height

        if max_size is not None:
            max_len = min(max_len, max_size)

        # Pin the longer side to max_len and derive the other side from the
        # aspect ratio (truncated towards zero by int()).
        if width >= height:
            width = max_len
            height = int(width / aspect_ratio)
        else:
            height = max_len
            width = int(height * aspect_ratio)

        # Round odd sides up to the next even number.
        height += height % 2
        width += width % 2

        # Never go below the minimum side length.
        height = max(height, min_len)
        width = max(width, min_len)

        return height, width

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
        processor: Idefics3Processor | None = None,
    ) -> tuple[int, int]:
        """Size an image is resized to before tiling.

        Args:
            image_width: Original width in pixels.
            image_height: Original height in pixels.
            resolution_max_side: Target length of the longer side.
            processor: HF processor whose ``image_processor.size`` bounds
                ``resolution_max_side``. Defaults to ``get_hf_processor()``.
                Callers that built a processor with ``mm_kwargs`` overrides
                should pass it, so the bound uses the overridden ``size``.

        Returns:
            ``(height, width)`` after aspect-preserving resizing.

        Raises:
            ValueError: if ``resolution_max_side`` exceeds the processor's
                ``size["longest_edge"]``.
        """
        if processor is None:
            processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor

        max_image_size = image_processor.size["longest_edge"]
        if resolution_max_side > max_image_size:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`"
            )

        # Rescale so the longest edge equals resolution_max_side while
        # preserving the aspect ratio.
        return self._resize_output_size(
            height=image_height, width=image_width, max_len=resolution_max_side
        )

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> tuple[int, int]:
        """Return ``(grid_w, grid_h)``, the tile columns and rows of an image.

        The image is first resized so its longest side equals
        ``size["longest_edge"]``. It is then split into tiles of side
        ``max_image_size["longest_edge"]``. Both counts are 0 when the
        resized image already fits in a single tile.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_processor: Idefics3ImageProcessor = processor.image_processor

        max_image_size = image_processor.max_image_size["longest_edge"]
        size = image_processor.size["longest_edge"]
        assert size % max_image_size == 0, (
            "`longest_edge` in image_processor's `size` must be divisible by "
            "`longest_edge` in `max_image_size`, this may be caused by "
            "incorrect mm_kwargs override."
        )

        # Pass the same processor so the size bound honours mm_kwargs
        # overrides instead of the default processor's configuration.
        resized_height, resized_width = self._get_resize_output_image_size(
            image_width=image_width,
            image_height=image_height,
            resolution_max_side=size,
            processor=processor,
        )
        if resized_height > max_image_size or resized_width > max_image_size:
            grid_h = math.ceil(resized_height / max_image_size)
            grid_w = math.ceil(resized_width / max_image_size)
        else:
            grid_h = grid_w = 0
        return grid_w, grid_h

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> int:
        """Number of pixel patches (tiles plus global image) for one image."""
        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        # +1 for the global image that ``get_image_repl`` always appends.
        return grid_w * grid_h + 1

    def _get_image_token(
        self, processor: Idefics3Processor | None
    ) -> tuple[str, str, str]:
        """Return ``(image_token, fake_image_token, global_image_tag)``."""
        if processor is None:
            processor = self.get_hf_processor()

        image_token = processor.image_token
        fake_image_token = processor.fake_image_token
        global_image_token = processor.global_image_tag
        return image_token, fake_image_token, global_image_token

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> str:
        """Build the prompt text that replaces a single image.

        Each tile contributes ``fake_image_token`` + ``<row_R_col_C>`` +
        ``image_seq_len`` image tokens, and every row of tiles ends with a
        newline. The last part is a blank line, then the global image block,
        then a closing ``fake_image_token``. An image that is not tiled
        produces only the global image block and the closing token.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_token, fake_image_token, global_img_token = self._get_image_token(
            processor
        )
        image_seq_len = processor.image_seq_len

        p_img = image_token * image_seq_len
        global_img_placeholder = fake_image_token + global_img_token + p_img

        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        if grid_w == 0 and grid_h == 0:
            return global_img_placeholder + fake_image_token

        tiles_placeholder: list[str] = []
        for i in range(grid_h):
            for j in range(grid_w):
                # Build with an f-string instead of str.format on a template
                # that embeds the token text: braces in token strings would
                # otherwise break formatting.
                tiles_placeholder.append(
                    f"{fake_image_token}<row_{i + 1}_col_{j + 1}>{p_img}"
                )
            # Line break after the last tile in each row.
            tiles_placeholder.append("\n")

        return "".join(
            [
                *tiles_placeholder,
                "\n",
                global_img_placeholder,
                fake_image_token,
            ]
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor | None,
    ) -> int:
        """Number of ``image_token`` occurrences emitted for one image."""
        if processor is None:
            processor = self.get_hf_processor()

        num_patches = self.get_num_patches(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        return num_patches * processor.image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image whose side is the processor's ``size["longest_edge"]``."""
        processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor

        return ImageSize(
            width=image_processor.size["longest_edge"],
            height=image_processor.size["longest_edge"],
        )

294

295
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]):
    """Builds the dummy prompt text and dummy images for Idefics3
    (``BaseDummyInputsBuilder`` from ``vllm.multimodal.profiling``)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One ``image_token`` per requested image."""
        hf_processor = self.info.get_hf_processor()
        image_token = self.info._get_image_token(hf_processor)[0]
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Square dummy images with side ``max_image_size["longest_edge"]``.

        Any ``"image"`` entry in ``mm_options`` is forwarded as overrides.
        """
        image_processor: Idefics3ImageProcessor = (
            self.info.get_hf_processor().image_processor
        )
        edge = image_processor.max_image_size["longest_edge"]

        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=edge,
            height=edge,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}

326

327
class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Runs the HF Idefics3 processor and maps its outputs onto vLLM's
    multimodal field layout and prompt-replacement machinery."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Process ``prompt`` and ``mm_data``, then attach per-image patch counts.

        Without images, the prompt is tokenized directly because the
        composite HF processor does not support text-only input.
        """
        images = mm_data.get("images", [])
        if not images:
            token_ids = self.info.get_tokenizer().encode(prompt)
            token_ids = self._apply_hf_processor_tokens_only(token_ids)
            return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

        outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )

        image_items = (
            self._get_data_parser()
            .parse_mm_data({"image": images})
            .get_items("image", ImageProcessorItems)
        )
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        patch_counts: list[int] = []
        for idx in range(len(image_items)):
            size = image_items.get_image_size(idx)
            patch_counts.append(
                self.info.get_num_patches(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
            )
        outputs["num_patches"] = torch.tensor(patch_counts)

        # Drop the extra leading batch dimension (in place).
        for key in ("pixel_values", "pixel_attention_mask"):
            outputs[key].squeeze_(0)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Pixel tensors are flat over patches and split per image using
        ``num_patches``. Embeddings and counts are batched per image."""
        sizes = hf_inputs.get("num_patches", torch.empty(0))
        flat_by_patches = MultiModalFieldConfig.flat_from_sizes

        return {
            "pixel_values": flat_by_patches("image", sizes),
            "pixel_attention_mask": flat_by_patches("image", sizes),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "num_patches": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``image_token`` with the image's full placeholder text."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = self.info._get_image_token(hf_processor)[0]

        def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
            # Items are looked up lazily, only when a replacement is needed.
            size = mm_items.get_items("image", ImageProcessorItems).get_image_size(
                item_idx
            )
            repl_text = self.info.get_image_repl(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
            )
            # embed_text selects which text in the replacement is associated
            # with the image features (presumably the image_token runs).
            return PromptUpdateDetails.select_text(repl_text, embed_text=image_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_idefics3,
            )
        ]
422
423
424


class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear projection from pixel-shuffled vision features
    (``vision hidden_size * scale_factor**2``) to the text hidden size."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        in_features = config.vision_config.hidden_size * config.scale_factor**2
        self.proj = ReplicatedLinear(
            in_features,
            config.text_config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ReplicatedLinear returns a pair; only the projected tensor is used.
        projected, _bias = self.proj(x)
        return projected


class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision-encoder outputs, then projects them into the
    text embedding space via ``Idefics3SimpleMLP``."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "modality_projection"),
        )

    def pixel_shuffle(self, x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
        """Fold ``scale_factor x scale_factor`` neighbourhoods into channels.

        ``x`` is viewed as a square ``side x side`` grid of ``seq`` tokens,
        so ``seq`` must be a perfect square. The result has
        ``seq / scale_factor**2`` tokens of width
        ``embed_dim * scale_factor**2``.
        """
        batch, seq_len, dim = x.size()
        side = int(seq_len**0.5)
        reduced = int(side / scale_factor)
        widened = dim * scale_factor**2

        grid = x.view(batch, side, side, dim)
        grid = grid.view(batch, side, reduced, dim * scale_factor)
        grid = grid.permute(0, 2, 1, 3)
        grid = grid.reshape(batch, reduced, reduced, widened)
        grid = grid.permute(0, 2, 1, 3)
        return grid.reshape(batch, int(seq_len / scale_factor**2), widened)

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)


class Idefics3Model(nn.Module):
    """Idefics3 backbone: vision transformer, connector and Llama text model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        vision_config = config.vision_config

        self.config = config
        self.vocab_size = config.text_config.vocab_size

        self.vision_model = Idefics3VisionTransformer(
            vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Tokens per image tile: the patch grid shrunk by the pixel shuffle.
        patches_per_side = vision_config.image_size // vision_config.patch_size
        self.image_seq_len = int(patches_per_side**2 / config.scale_factor**2)
        self.image_token_id = config.image_token_id

    def image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower on image tiles, skipping all-zero tiles.

        A patch is marked as attended when any pixel inside it is set in
        ``pixel_attention_mask``.
        """
        # NOTE: the vision feature layer is already selected inside the
        # vision tower, so no layer selection is done here.
        target_dtype = self.vision_model.embeddings.patch_embedding.weight.dtype
        pixel_values = pixel_values.to(dtype=target_dtype)  # fp16 compatibility

        # Padding images are entirely zero; keep only real ones, and drop
        # the matching rows of the mask.
        values_per_image = pixel_values.shape[1:].numel()
        zero_counts = (pixel_values == 0.0).sum(dim=(-1, -2, -3))
        keep = zero_counts != values_per_image
        pixel_values = pixel_values[keep].contiguous()
        pixel_attention_mask = pixel_attention_mask[keep].contiguous()

        # Split the pixel mask into patch_size x patch_size windows.
        patch = self.config.vision_config.patch_size
        windows = pixel_attention_mask.unfold(
            dimension=1, size=patch, step=patch
        ).unfold(dimension=2, size=patch, step=patch)
        patch_attention_mask = (windows.sum(dim=(-1, -2)) > 0).bool()

        return self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.text_model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Forward straight to the text model."""
        return self.text_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )


572
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultiModalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder,
)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA):
    """Idefics3 generation model: ``Idefics3Model`` plus an LM head."""

    merge_by_field_config = True

    # Fused projection -> the checkpoint sub-modules it is assembled from.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt placeholder for ``modality``; only images are accepted."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        text_config = config.text_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        self.model = Idefics3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.image_token_id = config.image_token_id

        self.lm_head = ParallelLMHead(
            text_config.vocab_size,
            text_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if text_config.tie_word_embeddings:
            # Reuse the text model's input embedding matrix for the LM head.
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(text_config.vocab_size)

    def _parse_and_validate_image_input(self, **kwargs: object) -> ImageInputs | None:
        """Pop image kwargs and wrap them in the matching input schema.

        Embeddings take precedence over pixels. Returns ``None`` when
        neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if image_embeds is not None:
            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )
        if pixel_values is None:
            return None

        pixel_attention_mask = kwargs.pop("pixel_attention_mask")
        num_patches = kwargs.pop("num_patches")
        tile_side = self.config.vision_config.image_size

        return Idefics3ImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            num_patches=num_patches,
            resolve_bindings={"h": tile_side, "w": tile_side},
        )

    def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
        return self.model.image_pixels_to_features(
            inputs["pixel_values"],
            pixel_attention_mask=inputs["pixel_attention_mask"],
        )

    def _process_image_input(
        self,
        image_input: ImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Embeddings pass through unchanged. Pixel inputs are encoded and
        projected, then regrouped into one flattened tensor per image using
        ``num_patches``."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self.model.connector(self._process_image_pixels(image_input))
        per_image = features.split(image_input["num_patches"].tolist())
        return [chunk.flatten(0, 1) for chunk in per_image]

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text model. ``inputs_embeds`` is ignored whenever
        ``intermediate_tensors`` are supplied."""
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.model.text_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model",
        )