"vllm/vscode:/vscode.git/clone" did not exist on "b8afa8b95a4eee008a9b72440620113e5bfbe962"
idefics3.py 23.7 KB
Newer Older
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

19
from collections.abc import Iterable, Mapping, Sequence
20
from typing import Annotated, Literal, TypeAlias
21
22
23

import torch
from torch import nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    Idefics3Config,
    Idefics3ImageProcessor,
    Idefics3Processor,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
34
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
37
from vllm.model_executor.models.module_mapping import MultiModelKeys
38
from vllm.multimodal import MULTIMODAL_REGISTRY
39
40
41
42
43
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
44
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
45
from vllm.multimodal.processing import (
46
    BaseDummyInputsBuilder,
47
48
49
50
51
52
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
53
from vllm.sequence import IntermediateTensors
54
from vllm.utils.tensor_schema import TensorSchema, TensorShape
55
56

from .idefics2_vision_model import (
57
58
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
59
60
61
62
63
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
)
64
from .llama import LlamaModel
65
from .utils import AutoWeightsLoader, maybe_prefix
66
67


68
class Idefics3ImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as raw pixel patches.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    # Discriminator tag identifying this schema variant.
    type: Literal["pixel_values"]
    # One entry per patch; the channel dimension is fixed to 3.
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    # Per-patch mask sharing the spatial dimensions of `pixel_values`.
    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
    # One value per image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
82

83

84
class Idefics3ImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator tag identifying this schema variant.
    type: Literal["image_embeds"]
    # Embedding tensor laid out as (bn, f, h).
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
94
95


96
# Union of the two accepted image input schemas: raw pixel patches or
# precomputed image embeddings (distinguished by their `type` field).
ImageInputs: TypeAlias = Idefics3ImagePixelInputs | Idefics3ImageEmbeddingInputs
97
98


99
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Processor access, size arithmetic and prompt layout for Idefics3 images."""

    def get_hf_processor(self, **kwargs: object) -> Idefics3Processor:
        """Return the `Idefics3Processor` obtained from the processing context."""
        return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; only "image" is listed (as None)."""
        limits: dict[str, int | None] = {"image": None}
        return limits

    def _resize_output_size(
        self,
        *,
        height: int,
        width: int,
        max_len: int | None = None,
        min_len: int = 1,
        max_size: int | None = None,
    ) -> tuple[int, int]:
        """
        Rescale `(height, width)` so that the longer side hits a target length.

        The target is `max_len` (defaulting to the longer input side), capped
        by `max_size` when that is given. The aspect ratio is kept up to
        integer truncation, odd results are bumped to the next even number and
        both sides are finally clamped to at least `min_len`.

        Returns:
            The new `(height, width)` pair.
        """
        target_len = max_len
        if target_len is None:
            target_len = max(height, width)
        aspect_ratio = width / height

        if max_size is not None:
            target_len = min(target_len, max_size)

        # The longer side takes the target; the other follows the aspect ratio.
        if width >= height:
            new_width = target_len
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = target_len
            new_width = int(new_height * aspect_ratio)

        # Adding the parity bit rounds an odd value up to the next even one.
        new_height += new_height % 2
        new_width += new_width % 2

        return max(new_height, min_len), max(new_width, min_len)

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
    ) -> tuple[int, int]:
        """
        Compute the `(height, width)` of an image after rescaling its longest
        edge to `resolution_max_side` while preserving the aspect ratio.

        Raises:
            ValueError: If `resolution_max_side` exceeds the image processor's
                configured `size["longest_edge"]`.
        """
        image_processor: Idefics3ImageProcessor = (
            self.get_hf_processor().image_processor
        )
        longest_edge_limit = image_processor.size["longest_edge"]
        if resolution_max_side > longest_edge_limit:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`"
            )

        resized_height, resized_width = self._resize_output_size(
            height=image_height,
            width=image_width,
            max_len=resolution_max_side,
        )
        return resized_height, resized_width

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor,
        mm_kwargs: Mapping[str, object],
    ) -> tuple[int, int, int]:
        """
        Ask the HF image processor how an image of this size is tiled.

        Returns:
            The result of `get_number_of_image_patches`; callers in this class
            unpack it as `(num_patches, grid_rows, grid_cols)`.
        """
        image_processor: Idefics3ImageProcessor = processor.image_processor
        merged_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)

        return image_processor.get_number_of_image_patches(
            image_height,
            image_width,
            merged_kwargs,
        )

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor,
        mm_kwargs: Mapping[str, object],
    ) -> int:
        """Return the number of patches produced for an image of this size."""
        patch_count, _rows, _cols = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
            mm_kwargs=mm_kwargs,
        )
        return patch_count

    def _get_image_token(self, processor: Idefics3Processor) -> tuple[str, str, str]:
        """Return `(image_token, fake_image_token, global_image_tag)`."""
        return (
            processor.image_token,
            processor.fake_image_token,
            processor.global_image_tag,
        )

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor,
        mm_kwargs: Mapping[str, object],
    ) -> str:
        """
        Build the prompt text that replaces one image placeholder.

        For a tiled image the text holds one `<row_R_col_C>`-tagged segment per
        tile (rows separated by newlines), followed by a blank line, the
        global-image segment and a closing fake token. When the grid is 0x0,
        only the global-image segment and the closing fake token are returned.
        """
        image_token, fake_token, global_token = self._get_image_token(processor)

        image_span = image_token * processor.image_seq_len
        global_segment = fake_token + global_token + image_span
        # NOTE: `.format` is applied to the whole template, tokens included.
        tile_template = fake_token + "<row_{n_h}_col_{n_w}>" + image_span

        _, grid_h, grid_w = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
            mm_kwargs=mm_kwargs,
        )
        if grid_w == 0 and grid_h == 0:
            return global_segment + fake_token

        pieces: list[str] = []
        for row in range(1, grid_h + 1):
            row_tiles = [
                tile_template.format(n_h=row, n_w=col)
                for col in range(1, grid_w + 1)
            ]
            # A line break closes every row that contains at least one tile.
            if row_tiles:
                pieces.extend(row_tiles)
                pieces.append("\n")

        pieces.extend(("\n", global_segment, fake_token))
        return "".join(pieces)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Idefics3Processor,
        mm_kwargs: Mapping[str, object],
    ) -> int:
        """Return patches-per-image multiplied by `processor.image_seq_len`."""
        patch_count = self.get_num_patches(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
            mm_kwargs=mm_kwargs,
        )
        return patch_count * processor.image_seq_len
265

266

267
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]):
    """Builds placeholder text and images for the requested image count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated once per requested image."""
        image_count = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        image_token = self.info._get_image_token(hf_processor)[0]

        return image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """
        Return dummy square images sized by the image processor's
        `max_image_size["longest_edge"]`, one per requested image.
        """
        image_count = mm_counts.get("image", 0)

        processor_kwargs = mm_processor_kwargs or {}
        hf_processor = self.info.get_hf_processor(**processor_kwargs)
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        edge = image_processor.max_image_size["longest_edge"]

        # Optional per-modality overrides are forwarded untouched.
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=edge,
            height=edge,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}

299

300
class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Runs the HF processor and describes Idefics3 prompt replacements."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor and attach a per-image `num_patches` tensor.

        Without images the prompt is only tokenized. Otherwise
        `input_data_format` defaults to "channels_last" (caller kwargs win),
        and the leading batch dimension of `pixel_values` and
        `pixel_attention_mask` is squeezed away in place.
        """
        images = mm_data.get("images", [])
        if not images:
            # Text-only input not supported in composite processor
            tokenizer = self.info.get_tokenizer()
            token_ids = tokenizer.encode(prompt)
            token_ids = self._apply_hf_processor_tokens_only(token_ids)
            return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

        merged_kwargs: dict[str, object] = {"input_data_format": "channels_last"}
        merged_kwargs.update(mm_kwargs)

        outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            merged_kwargs,
            tok_kwargs,
        )

        parsed = self.info.parse_mm_data({"image": images}, validate=False)
        image_items = parsed.get_items("image", ImageProcessorItems)
        sizes = [image_items.get_image_size(idx) for idx in range(len(image_items))]
        hf_processor = self.info.get_hf_processor(**merged_kwargs)

        patch_counts = [
            self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
                mm_kwargs=merged_kwargs,
            )
            for size in sizes
        ]
        outputs["num_patches"] = torch.tensor(patch_counts)

        # Remove the extra batch dimension
        for key in ("pixel_values", "pixel_attention_mask"):
            outputs[key].squeeze_(0)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe how each processor output maps onto image items.

        Pixel tensors are flat and sliced by `num_patches`; embeddings and the
        patch counts themselves are batched per image.
        """
        patch_counts = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", patch_counts
            ),
            "pixel_attention_mask": MultiModalFieldConfig.flat_from_sizes(
                "image", patch_counts
            ),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "num_patches": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Return the single replacement rule that expands each image token into
        the full tiled-image text, with only the image tokens selected for
        embedding.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = self.info._get_image_token(hf_processor)[0]

        def _replacement_for(item_idx: int) -> PromptUpdateDetails:
            # Looked up lazily so nothing is touched until an item is expanded.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            expanded_text = self.info.get_image_repl(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
                mm_kwargs=hf_processor_mm_kwargs,
            )
            return PromptUpdateDetails.select_text(
                expanded_text,
                embed_text=image_token,
            )

        image_rule = PromptReplacement(
            modality="image",
            target=image_token,
            replacement=_replacement_for,
        )
        return [image_rule]
395
396
397


class Idefics3SimpleMLP(nn.Module):
    """Single bias-free linear projection from vision to text hidden size."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # The input width is the vision hidden size times scale_factor**2.
        in_features = config.vision_config.hidden_size * (config.scale_factor**2)
        out_features = config.text_config.hidden_size
        self.proj = ReplicatedLinear(
            in_features,
            out_features,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x`; the layer's second return value is discarded."""
        projected, _unused = self.proj(x)
        return projected


class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision features, then projects them to the text width."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "modality_projection"),
        )

    def pixel_shuffle(self, x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
        """
        Fold each `scale_factor` x `scale_factor` neighbourhood of tokens into
        one token with a correspondingly wider embedding.

        `x` is `(batch, seq, embed_dim)` and the sequence is treated as a
        square grid of side `int(seq**0.5)`. The result is
        `(batch, seq / scale_factor**2, embed_dim * scale_factor**2)`.
        """
        bsz, seq, embed_dim = x.size()
        side = int(seq**0.5)
        grid = x.view(bsz, side, side, embed_dim)

        # Merge neighbouring columns into the embedding dimension.
        reduced_w = int(side / scale_factor)
        grid = grid.view(bsz, side, reduced_w, embed_dim * scale_factor)
        grid = grid.permute(0, 2, 1, 3)

        # Merge neighbouring rows likewise, then restore row-major order.
        reduced_h = int(side / scale_factor)
        merged_dim = embed_dim * (scale_factor**2)
        grid = grid.reshape(bsz, reduced_w, reduced_h, merged_dim)
        grid = grid.permute(0, 2, 1, 3)

        reduced_seq = int(seq / (scale_factor**2))
        return grid.reshape(bsz, reduced_seq, merged_dim)

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply pixel shuffle with the configured scale, then the projection."""
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)


class Idefics3Model(nn.Module):
    """Vision tower, connector and Llama text model bundled together."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        vision_config = config.vision_config

        self.config = config
        self.vocab_size = self.config.text_config.vocab_size

        # Submodules are registered in this order: vision, connector, text.
        self.vision_model = Idefics3VisionTransformer(
            vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Patches per image after dividing by scale_factor**2.
        patches_per_side = vision_config.image_size // vision_config.patch_size
        self.image_seq_len = int(patches_per_side**2 / config.scale_factor**2)
        self.image_token_id = self.config.image_token_id

    def image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode pixel patches with the vision model.

        Images consisting entirely of zeros are treated as padding and dropped
        (from both the pixels and the mask) before encoding. The pixel-level
        mask is reduced to a patch-level mask in which a patch is kept when
        the mask values inside it sum to more than zero.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        vision_dtype = self.vision_model.embeddings.patch_embedding.weight.dtype
        pixel_values = pixel_values.to(dtype=vision_dtype)  # fp16 compatibility

        # Remove padding images - padding images are full 0.
        values_per_image = pixel_values.shape[1:].numel()
        zero_counts = (pixel_values == 0.0).sum(dim=(-1, -2, -3))
        is_real_image = zero_counts != values_per_image
        pixel_values = pixel_values[is_real_image].contiguous()

        # Remove padding images from the mask as well.
        pixel_attention_mask = pixel_attention_mask[is_real_image].contiguous()

        # Cut the mask into non-overlapping patch_size x patch_size windows.
        patch_size = self.config.vision_config.patch_size
        windows = pixel_attention_mask.unfold(
            dimension=1, size=patch_size, step=patch_size
        ).unfold(dimension=2, size=patch_size, step=patch_size)
        patch_attention_mask = (windows.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        return self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the text model."""
        return self.text_model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text model and return its output unchanged."""
        return self.text_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )


545
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultiModalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder,
)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA):
    """Idefics3 model with an LM head, image embedding and weight loading."""

    # Fused module name -> names of the original projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """
        Return the prompt placeholder for an image item.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        with self._mark_composite_model(
            vllm_config,
            language_targets=LlamaModel,
            tower_targets={"image": (Idefics3VisionTransformer, Idefics3Connector)},
        ):
            self.model = Idefics3Model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "model"),
            )

        self.image_token_id = self.config.image_token_id

        text_config = config.text_config
        self.lm_head = ParallelLMHead(
            text_config.vocab_size,
            text_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the head shares the input embedding weight.
        if self.config.text_config.tie_word_embeddings:
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(text_config.vocab_size)

    def _parse_and_validate_image_input(self, **kwargs: object) -> ImageInputs | None:
        """
        Turn raw keyword inputs into a typed image-input schema.

        Returns None when neither `pixel_values` nor `image_embeds` is given;
        embeddings take precedence when both are present. The pixel path also
        requires `pixel_attention_mask` and `num_patches`, and binds height
        and width to the vision config's `image_size`.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        # Defensive guard: the checks above leave only the pixel path.
        if pixel_values is None:
            raise AssertionError("This line should be unreachable.")

        pixel_attention_mask = kwargs.pop("pixel_attention_mask")
        num_patches = kwargs.pop("num_patches")
        image_size = self.config.vision_config.image_size

        return Idefics3ImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            num_patches=num_patches,
            resolve_bindings={"h": image_size, "w": image_size},
        )

    def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
        """Encode the pixel patches in `inputs` with the vision model."""
        return self.model.image_pixels_to_features(
            inputs["pixel_values"],
            pixel_attention_mask=inputs["pixel_attention_mask"],
        )

    def _process_image_input(
        self,
        image_input: ImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """
        Produce image embeddings for the language model.

        Precomputed embeddings are returned as-is. Pixel inputs are encoded,
        passed through the connector, split per image by `num_patches`, and
        each chunk has its first two dimensions flattened together.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        vision_features = self._process_image_pixels(image_input)
        connected = self.model.connector(vision_features)

        patches_per_image = image_input["num_patches"].tolist()
        return [
            chunk.flatten(0, 1) for chunk in connected.split(patches_per_image)
        ]

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """
        Run the text model.

        `inputs_embeds` is dropped whenever `intermediate_tensors` is
        supplied; extra keyword arguments are accepted but unused.
        """
        embeds = inputs_embeds if intermediate_tensors is None else None

        return self.model.text_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map hidden states to logits via the logits processor and LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load `weights` with `AutoWeightsLoader` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Return `num_image_tokens` multiplied by `scale_factor**2`."""
        return num_image_tokens * self.config.scale_factor**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Return `num_vision_tokens` floor-divided by `scale_factor**2`."""
        return num_vision_tokens // self.config.scale_factor**2