gemma3_mm.py 25.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import math
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Any, Literal, Optional, TypedDict
6
7
8

import torch
from torch import nn
9
10
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
11

12
import vllm.envs as envs
13
14
15
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
16
from vllm.model_executor.models.module_mapping import MultiModelKeys
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
18
19
20
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
21
22
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
23
# yapf: disable
24
from vllm.multimodal.processing import (BaseMultiModalProcessor,
25
26
27
28
                                        BaseProcessingInfo, BoundPromptUpdate,
                                        PlaceholderFeaturesInfo,
                                        PromptReplacement, PromptTargetMatch,
                                        PromptUpdate, PromptUpdateDetails,
29
                                        find_mm_placeholders,
30
31
                                        replace_token_matches)
# yapf: enable
32
from vllm.multimodal.profiling import BaseDummyInputsBuilder
33
34
from vllm.sequence import IntermediateTensors

35
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
36
                         SupportsMultiModal, SupportsPP)
37
38
39
40
41
42
43
44
45
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)


class Gemma3ImagePixelInputs(TypedDict):
    """Pixel-value image inputs consumed by the Gemma3 vision tower."""

    type: Literal["pixel_values"]

    pixel_values: torch.Tensor
    """
    Shape: `(num_patches_total, num_channels, height, width)`

    `num_patches_total` is the total number of patches
    over each image over each prompt in the batch.
    """

    num_patches: torch.Tensor
    """
    Shape: `(batch_size * num_images)`

    Number of patches contributed by each image: the original image plus
    its pan-and-scan crops (i.e. `num_crops + 1`).
    """


Gemma3ImageInputs = Gemma3ImagePixelInputs


class Gemma3ProcessingInfo(BaseProcessingInfo):
    """Processing metadata and token-count helpers for Gemma3 images."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Gemma3Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None`: no fixed upper bound on the number of images per prompt.
        return {"image": None}

    def _resolve_image_kwargs(
        self,
        processor: Gemma3Processor,
        keys: set[str],
    ) -> dict[str, Any]:
        """Resolve the image-processing options named in ``keys``.

        Each value is taken from the HF image processor's attribute of the
        same name; when that attribute is ``None``, it falls back to the
        processor's merged ``images_kwargs`` defaults.
        """
        image_processor = processor.image_processor
        kwargs = processor._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
        )

        images_kwargs = kwargs["images_kwargs"]

        def _resolve_kw(key: str):
            val = getattr(image_processor, key)
            if val is None:
                val = images_kwargs[key]

            return val

        return {k: _resolve_kw(k) for k in keys}

    def get_num_crops(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor] = None,
    ) -> int:
        """Return the number of pan-and-scan crops for an image.

        Mirrors ``Gemma3ImageProcessor.pan_and_scan``. Returns 0 when
        pan-and-scan is disabled, when the aspect ratio is below the
        activation threshold, or when the crops would be smaller than the
        minimum crop size.

        Args:
            image_width: Width of the input image in pixels.
            image_height: Height of the input image in pixels.
            processor: HF processor to read options from; a default one is
                created when ``None``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor, {
                "do_pan_and_scan", "pan_and_scan_min_crop_size",
                "pan_and_scan_max_num_crops",
                "pan_and_scan_min_ratio_to_activate"
            })

        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
        pan_and_scan_min_crop_size = images_kwargs[
            "pan_and_scan_min_crop_size"]
        pan_and_scan_max_num_crops = images_kwargs[
            "pan_and_scan_max_num_crops"]
        pan_and_scan_min_ratio_to_activate = images_kwargs[
            "pan_and_scan_min_ratio_to_activate"]

        if not do_pan_and_scan:
            return 0

        if envs.VLLM_USE_V1:
            logger.warning_once(
                "`do_pan_and_scan=True` has suboptimal results on V1 "
                "because of the simplified attention pattern being used.")

        # Based on Gemma3ImageProcessor.pan_and_scan
        if image_width >= image_height:
            # Landscape / square: only split along the width.
            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_w = min(
                int(math.floor(image_width / pan_and_scan_min_crop_size)),
                int(math.floor(image_width / image_height + 0.5)),
            )

            # Clamp to [2, pan_and_scan_max_num_crops].
            num_crops_w = max(2, num_crops_w)
            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
            num_crops_h = 1
        else:
            # Portrait: only split along the height.
            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_h = min(
                int(math.floor(image_height / pan_and_scan_min_crop_size)),
                int(math.floor(image_height / image_width + 0.5)),
            )

            # Clamp to [2, pan_and_scan_max_num_crops].
            num_crops_h = max(2, num_crops_h)
            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
            num_crops_w = 1

        crop_size_w = int(math.ceil(image_width / num_crops_w))
        crop_size_h = int(math.ceil(image_height / num_crops_h))

        # Pan-and-scan is skipped when the resulting crops are too small.
        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return 0

        return num_crops_w * num_crops_h

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor] = None,
    ) -> PromptUpdateDetails[str]:
        """Build the prompt text that replaces one ``boi_token``.

        Without crops this is the processor's ``full_image_sequence``.
        With crops, the original image and each crop get their own
        ``full_image_sequence``, joined by a short text template. Only the
        tokenizer's image-token positions are selected via
        ``PromptUpdateDetails.select_token_id``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        boi_token = processor.boi_token

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        if num_crops == 0:
            image_text = boi_token
        else:
            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
            image_text = (
                f"Here is the original image {boi_token} and here are some "
                f"crops to help you see better {crops_image_tokens}")

        repl_full = image_text.replace(boi_token,
                                       processor.full_image_sequence)

        tokenizer = processor.tokenizer
        vocab = tokenizer.get_vocab()
        image_token_id = vocab[tokenizer.image_token]

        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor] = None,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        Each of the ``num_crops + 1`` patches (original image plus crops)
        contributes ``processor.image_seq_length`` tokens.
        """
        if processor is None:
            processor = self.get_hf_processor()

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        image_seq_len = processor.image_seq_length

        return (num_crops + 1) * image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size that yields the maximum number of crops.

        The image is portrait with aspect ratio
        ``max(max_num_crops, ceil(min_ratio))``. That ratio reaches the
        pan-and-scan activation threshold and asks for at least
        ``max_num_crops`` splits. The width equals
        ``pan_and_scan_min_crop_size``, so every crop passes the
        minimum-size check. If pan-and-scan is disabled, any size gives
        zero crops and the choice is irrelevant.
        """
        processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor, {
                "pan_and_scan_max_num_crops",
                "pan_and_scan_min_crop_size",
                "pan_and_scan_min_ratio_to_activate",
            })
        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
        min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
        min_ratio = images_kwargs["pan_and_scan_min_ratio_to_activate"]

        # A crop narrower than `min_crop_size` disables pan-and-scan
        # entirely, so the width must be at least that large.
        aspect = max(max_num_crops, int(math.ceil(min_ratio)))
        width = min_crop_size

        # h:w = aspect:1, which produces `max_num_crops` crops.
        return ImageSize(height=aspect * width, width=width)

class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
    """Builds dummy text and images for Gemma3 memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``boi_token`` per requested dummy image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.boi_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized by the processing info.

        The size comes from ``get_image_size_with_most_features``. The
        ``seq_len`` argument is not used here.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
    """Multimodal processor for Gemma3 image inputs.

    On top of the HF processor call, it re-derives per-image pan-and-scan
    crop counts and expands each ``boi_token`` into its full replacement.
    It also keeps multi-newline tokens consistent with how the tokenizer
    merges them.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
        )

        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
        if (images := mm_data.get("images")) is not None:
            parsed_images = (self._get_data_parser().parse_mm_data({
                "image":
                images
            }).get_items("image", ImageProcessorItems))
            image_sizes = [
                parsed_images.get_image_size(i)
                for i in range(len(parsed_images))
            ]
            hf_processor = self.info.get_hf_processor(**mm_kwargs)

            num_crops = [
                self.info.get_num_crops(image_width=size.width,
                                        image_height=size.height,
                                        processor=hf_processor)
                for size in image_sizes
            ]
            processed_outputs["num_crops"] = torch.tensor(num_crops)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_crops = hf_inputs.get("num_crops", torch.empty(0))

        # Each image owns `num_crops + 1` rows of `pixel_values`
        # (the original image plus its crops).
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_crops + 1),
            num_crops=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.boi_token

        def get_replacement_gemma3(item_idx: int):
            # Looked up lazily so that no image items are required until a
            # replacement is actually requested.
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)
            return self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                processor=hf_processor,
            )

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_gemma3,
            )
        ]

    def _get_newline_token_ids(self) -> tuple[int, int, int, int]:
        """Return the token IDs for one, two, three and four newlines."""
        vocab = self.info.get_tokenizer().get_vocab()
        return (
            vocab["\n"],
            vocab["\n\n"],
            vocab["\n\n\n"],
            vocab["\n\n\n\n"],
        )

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
        mm_item_counts: Mapping[str, int],
    ) -> list[int]:
        token_ids = super()._apply_token_matches(
            prompt,
            mm_matches,
            mm_item_counts,
        )

        # "\n\n\n" and "\n\n\n\n" are single tokens
        # Since our replacement can insert "\n\n" next to "\n"
        # tokens, we have to combine them to be consistent with
        # the output of the tokenizer
        newline_1, newline_2, newline_3, newline_4 = \
            self._get_newline_token_ids()

        token_ids = replace_token_matches(
            token_ids,
            [newline_1, newline_2],
            [newline_3],
        )
        token_ids = replace_token_matches(
            token_ids,
            [newline_2, newline_1],
            [newline_3],
        )
        token_ids = replace_token_matches(
            token_ids,
            [newline_2, newline_2],
            [newline_4],
        )

        return token_ids

    def _find_mm_placeholders(
        self,
        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
        newline_1, newline_2, newline_3, newline_4 = \
            self._get_newline_token_ids()

        def get_repl_toks(tok: int) -> list[int]:
            # Undo the merges done in `_apply_token_matches`.
            if tok == newline_3:
                return [newline_1, newline_2]
            if tok == newline_4:
                return [newline_2, newline_2]

            return [tok]

        # Expanded token sequence, plus for each expanded token the index of
        # the token in `new_token_ids` that produced it.
        repl_token_ids = list[int]()
        repl_orig_idxs = list[int]()
        for orig_idx, orig_tok in enumerate(new_token_ids):
            repl_toks = get_repl_toks(orig_tok)
            repl_token_ids.extend(repl_toks)
            repl_orig_idxs.extend([orig_idx] * len(repl_toks))

        repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
                                     mm_item_counts)

        # Map placeholder start positions back onto `new_token_ids`.
        return {
            modality: [
                PlaceholderFeaturesInfo(
                    modality=p.modality,
                    item_idx=p.item_idx,
                    start_idx=repl_orig_idxs[p.start_idx],
                    tokens=p.tokens,
                    is_embed=p.is_embed,
                ) for p in placeholders
            ]
            for modality, placeholders in repls.items()
        }

class Gemma3MultiModalProjector(nn.Module):
    """Pool vision-tower features and project them into the text space.

    Features laid out on a ``patches_per_image x patches_per_image`` grid
    are average-pooled with a square kernel of size ``kernel_size``. They
    are then normalized with ``GemmaRMSNorm`` and multiplied by a learned
    ``(vision_hidden, text_hidden)`` projection matrix.
    """

    def __init__(self, config: Gemma3Config):
        super().__init__()

        vision_cfg = config.vision_config
        text_cfg = config.text_config

        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_cfg.hidden_size, text_cfg.hidden_size))

        self.mm_soft_emb_norm = GemmaRMSNorm(vision_cfg.hidden_size,
                                             eps=vision_cfg.layer_norm_eps)

        # Side length of the patch grid produced by the vision tower.
        self.patches_per_image = int(vision_cfg.image_size //
                                     vision_cfg.patch_size)
        # Side length of the pooled token grid.
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size,
                                     stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        # The reshape below requires dim 1 to hold patches_per_image**2
        # entries; dim 2 becomes the pooling channel dim.
        n_images, _, channels = vision_outputs.shape
        grid = self.patches_per_image

        # (n, patches, channels) -> (n, channels, grid, grid)
        spatial = vision_outputs.transpose(1, 2).reshape(
            n_images, channels, grid, grid).contiguous()

        # Pool, then return to a (n, tokens, channels) layout.
        pooled = self.avg_pool(spatial).flatten(2).transpose(1, 2)

        projected = torch.matmul(self.mm_soft_emb_norm(pooled),
                                 self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)


@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
                                        info=Gemma3ProcessingInfo,
                                        dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsLoRA):
    """Gemma3 vision-language model.

    Combines a SigLIP vision tower, the Gemma3 multimodal projector and a
    Gemma3 causal language model.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        # Used by `prepare_attn_masks` to build the local (sliding) masks.
        self.sliding_window = getattr(config.text_config,
                                      "interleaved_sliding_window", None)

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = Gemma3MultiModalProjector(config)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma3ForCausalLM"],
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def dtype(self):
        # dtype of the first registered parameter.
        return next(self.parameters()).dtype

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that each patch is ``(3, image_size, image_size)``."""
        image_size = self.config.vision_config.image_size
        expected_dims = (3, image_size, image_size)
        if data.shape[1:] != expected_dims:
            raise ValueError(
                "The expected shape of pixel values per image per batch is "
                f"{expected_dims}. You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
        """Extract pixel inputs from ``kwargs``; ``None`` when absent.

        Pops ``pixel_values``, ``num_crops`` and ``image_embeds`` from
        ``kwargs``. Raises ``ValueError`` on unexpected types.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_crops = kwargs.pop("num_crops", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, "Gemma3 does not support image_embeds."
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(num_crops, (torch.Tensor, list)):
            raise ValueError("Incorrect type of num_crops. "
                             f"Got type: {type(num_crops)}")

        pixel_values = flatten_bn(pixel_values, concat=True)
        num_crops = flatten_bn(num_crops, concat=True)

        return Gemma3ImagePixelInputs(
            type="pixel_values",
            pixel_values=self._validate_pixel_values(pixel_values),
            # Original image plus its pan-and-scan crops.
            num_patches=num_crops + 1,
        )

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        return vision_tower(pixel_values)

    def _process_image_input(
        self,
        image_input: Gemma3ImageInputs,
    ) -> list[torch.Tensor]:
        """Encode and project images; return one embedding per image.

        The patch embeddings belonging to an image (``num_patches`` rows)
        are concatenated along the token dimension.
        """
        assert self.vision_tower is not None

        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]

        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )
        image_embeds = self.multi_modal_projector(image_features)

        return [
            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and scatter image embeddings into them.

        Image embeddings fill the positions holding
        ``config.image_token_index``.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            if vision_embeddings is not None:
                kwargs = self.prepare_attn_masks(
                    input_ids,
                    positions,
                    mask_dtype=self.dtype,
                    **kwargs,
                )
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds,
                                                  **kwargs)

        return hidden_states

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        """Build per-sequence global (and optional local) attention masks.

        Adds ``has_images``, ``seq_lens``, ``global_attn_masks`` and
        ``local_attn_masks`` to ``kwargs`` and returns it. Global masks are
        causal, except that image-token positions attend to each other in
        both directions.
        """
        kwargs["has_images"] = True
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        start_indices = (positions == 0).cpu().nonzero().flatten().tolist()
        end_indices = start_indices[1:] + [len(input_ids)]
        seq_lens = [
            end_idx - start_idx
            for start_idx, end_idx in zip(start_indices, end_indices)
        ]
        kwargs["seq_lens"] = seq_lens

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx
            # Create a global causal mask.
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            # Fill the lower triangle with 0.
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider the bidirectional attention between image tokens.
            # Entries where both row and column are image tokens reach 2.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = (input_token_ids == self.config.image_token_index)
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            if self.sliding_window is not None:
                # Create a local causal mask with sliding window (1024).
                # Positions more than `sliding_window` tokens back are
                # masked out; the rest inherit the global mask.
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask,
                                             diagonal=-self.sliding_window)
                local_attn_mask = torch.where(local_attn_mask == 0,
                                              global_attn_mask, float("-inf"))
                local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower")