# gemma3_mm.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import math
4
from collections.abc import Iterable, Mapping, Sequence
5
from typing import Annotated, Any, Literal, Optional
6
7
8

import torch
from torch import nn
9
10
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
11

12
import vllm.envs as envs
13
from vllm.config import VllmConfig
14
from vllm.config.multimodal import BaseDummyOptions
15
16
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
17
from vllm.model_executor.models.module_mapping import MultiModelKeys
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    MultiModalPromptUpdatesApplyResult,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
    replace_token_matches,
)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils.tensor_schema import TensorSchema, TensorShape
39

40
41
42
43
44
45
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
46
from .siglip import SiglipVisionModel
47
48
49
50
51
52
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
53
54
55
56

# Module-level logger (used e.g. for the pan-and-scan warning in
# Gemma3ProcessingInfo.get_num_crops).
logger = init_logger(__name__)


57
class Gemma3ImagePixelInputs(TensorSchema):
    """
    Pixel inputs for Gemma3 images (each image's original view plus any
    pan-and-scan crops, all stacked along the first dimension).

    Dimensions:
        - p: Number of patches total (over each image over each prompt in the
          batch)
        - c: Number of channels (3)
        - h: Height of each patch
        - w: Width of each patch
        - bn: Batch size * number of images
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Pixel data of every view (original image or crop) of every image.
    pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]

    # Per image, how many consecutive entries of `pixel_values` belong to it
    # (1 for the original view + the number of pan-and-scan crops).
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


Gemma3ImageInputs = Gemma3ImagePixelInputs


class Gemma3ProcessingInfo(BaseProcessingInfo):
    """Model-specific processing information for Gemma3.

    Gives access to the HF config/processor and computes how many image
    tokens (including pan-and-scan crops) each image expands to.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Gemma3Config)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None`: the model imposes no fixed limit on images per prompt.
        return {"image": None}

    def _resolve_image_kwargs(
        self,
        processor: Gemma3Processor,
        keys: set[str],
    ) -> dict[str, Any]:
        """Resolve the image-processing settings named in ``keys``.

        A value set on ``processor.image_processor`` wins; when that attribute
        is ``None``, the merged ``images_kwargs`` default from
        ``Gemma3ProcessorKwargs`` is used instead.
        """
        image_processor = processor.image_processor
        kwargs = processor._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
        )

        images_kwargs = kwargs["images_kwargs"]

        def _resolve_kw(key: str):
            val = getattr(image_processor, key)
            if val is None:
                val = images_kwargs[key]

            return val

        return {k: _resolve_kw(k) for k in keys}

    def get_num_crops(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> int:
        """Return the number of pan-and-scan crops generated for an image.

        Based on ``Gemma3ImageProcessor.pan_and_scan``. Returns 0 when
        pan-and-scan is disabled, when the aspect ratio (long side / short
        side) is below ``pan_and_scan_min_ratio_to_activate``, or when the
        resulting crops would be smaller than ``pan_and_scan_min_crop_size``.
        Otherwise the image is split along its longer side only, into
        ``max(2, n)`` crops capped at ``pan_and_scan_max_num_crops``.

        Args:
            image_width: Width of the image.
            image_height: Height of the image.
            processor: HF processor to read settings from; when ``None`` the
                default processor is fetched.
        """
        if processor is None:
            processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor,
            {
                "do_pan_and_scan",
                "pan_and_scan_min_crop_size",
                "pan_and_scan_max_num_crops",
                "pan_and_scan_min_ratio_to_activate",
            },
        )

        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
        pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
        pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
        pan_and_scan_min_ratio_to_activate = images_kwargs[
            "pan_and_scan_min_ratio_to_activate"
        ]

        if not do_pan_and_scan:
            return 0

        if envs.VLLM_USE_V1:
            logger.warning_once(
                "`do_pan_and_scan=True` has suboptimal results on V1 "
                "because of the simplified attention pattern being used."
            )

        # Based on Gemma3ImageProcessor.pan_and_scan
        if image_width >= image_height:
            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_w = min(
                int(math.floor(image_width / pan_and_scan_min_crop_size)),
                int(math.floor(image_width / image_height + 0.5)),
            )

            num_crops_w = max(2, num_crops_w)
            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
            num_crops_h = 1
        else:
            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_h = min(
                int(math.floor(image_height / pan_and_scan_min_crop_size)),
                int(math.floor(image_height / image_width + 0.5)),
            )

            num_crops_h = max(2, num_crops_h)
            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
            num_crops_w = 1

        crop_size_w = int(math.ceil(image_width / num_crops_w))
        crop_size_h = int(math.ceil(image_height / num_crops_h))

        # Crops that would end up too small disable pan-and-scan entirely.
        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return 0

        return num_crops_w * num_crops_h

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> PromptUpdateDetails[str]:
        """Build the prompt text that replaces one ``boi_token`` for an image.

        Without crops this is the processor's ``full_image_sequence``. With
        crops, the original view and every crop each get a
        ``full_image_sequence``, joined by fixed framing text. Only the image
        token IDs inside the text are marked as embedding positions.
        """
        if processor is None:
            processor = self.get_hf_processor()

        boi_token = processor.boi_token

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        if num_crops == 0:
            image_text = boi_token
        else:
            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
            image_text = (
                f"Here is the original image {boi_token} and here are some "
                f"crops to help you see better {crops_image_tokens}"
            )

        repl_full = image_text.replace(boi_token, processor.full_image_sequence)

        tokenizer = processor.tokenizer
        vocab = tokenizer.get_vocab()
        image_token_id = vocab[tokenizer.image_token]

        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> int:
        """Return the number of image tokens: one ``image_seq_length`` block
        for the original view plus one per pan-and-scan crop."""
        if processor is None:
            processor = self.get_hf_processor()

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        image_seq_len = processor.image_seq_length

        return (num_crops + 1) * image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size that yields the maximum number of crops.

        ``get_num_crops`` only produces ``pan_and_scan_max_num_crops`` crops
        when each crop is at least ``pan_and_scan_min_crop_size`` on both
        sides. The image is therefore sized as ``max_num_crops`` square crops
        of side ``max(min_crop_size, 50)`` stacked vertically.

        The previous fixed 50-pixel width made crops smaller than
        ``min_crop_size`` (e.g. the HF default of 256). ``get_num_crops`` then
        returned 0, and the "largest" image was underestimated.
        """
        processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor,
            {"pan_and_scan_max_num_crops", "pan_and_scan_min_crop_size"},
        )
        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
        min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]

        # h:w = max_num_crops:1, with every crop at least min_crop_size wide.
        crop_size = max(int(min_crop_size), 50)
        return ImageSize(height=crop_size * max_num_crops, width=crop_size)

245
246

class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
    """Builds dummy Gemma3 prompts and images (see vllm.multimodal.profiling)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``boi_token`` per requested dummy image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.boi_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized by ``get_image_size_with_most_features``.

        ``seq_len`` is not used here. If ``mm_options`` has an ``"image"``
        entry, it is forwarded as ``overrides`` to the dummy image generator.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }
275

276
277
278
279
280
281
282

class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
    """Multimodal processor for Gemma3.

    Expands each ``boi_token`` in the prompt into the image token sequence
    (including pan-and-scan crops). It also reconciles newline tokens that the
    expansion places next to existing newlines with the tokenizer's merged
    multi-newline tokens.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and add a per-image ``num_patches`` tensor."""
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )

        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
        if (images := mm_data.get("images")) is not None:
            parsed_images = (
                self._get_data_parser()
                .parse_mm_data({"image": images})
                .get_items("image", ImageProcessorItems)
            )
            image_sizes = [
                parsed_images.get_image_size(i) for i in range(len(parsed_images))
            ]
            hf_processor = self.info.get_hf_processor(**mm_kwargs)

            num_crops = [
                self.info.get_num_crops(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
                for size in image_sizes
            ]
            # Each image contributes its original view plus its crops. The
            # explicit integer dtype keeps `num_patches` usable as sizes even
            # when the list is empty (torch.tensor([]) is floating point).
            processed_outputs["num_patches"] = (
                torch.tensor(num_crops, dtype=torch.long) + 1
            )

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Split ``pixel_values`` per image using ``num_patches`` as sizes."""
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            num_patches=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``boi_token`` with the expansion for its image."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.boi_token

        def get_replacement_gemma3(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)
            return self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                processor=hf_processor,
            )

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_gemma3,
            )
        ]

    def _get_newline_token_ids(self) -> tuple[int, int, int, int]:
        """Return the token IDs for one, two, three and four consecutive
        newlines (in that order)."""
        vocab = self.info.get_tokenizer().get_vocab()
        return vocab["\n"], vocab["\n\n"], vocab["\n\n\n"], vocab["\n\n\n\n"]

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
        """Apply the replacements, then merge adjacent newline tokens."""
        token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)

        # "\n\n\n" and "\n\n\n\n" are single tokens
        # Since our replacement can insert "\n\n" next to "\n"
        # tokens, we have to combine them to be consistent with
        # the output of the tokenizer
        newline_1, newline_2, newline_3, newline_4 = self._get_newline_token_ids()

        # Applied in this order: (1+2)->3, (2+1)->3, (2+2)->4.
        for pattern, merged in (
            ([newline_1, newline_2], [newline_3]),
            ([newline_2, newline_1], [newline_3]),
            ([newline_2, newline_2], [newline_4]),
        ):
            token_ids = replace_token_matches(token_ids, pattern, merged)

        return token_ids, res

    def _find_mm_placeholders(
        self,
        new_token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """Locate placeholders while seeing through merged newline tokens.

        Matching runs on a copy of ``new_token_ids`` in which each
        three-newline token is split into [one, two] newlines and each
        four-newline token into [two, two]. Each match's ``start_idx`` is then
        mapped back to the index of the original token it came from.
        """
        # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
        newline_1, newline_2, newline_3, newline_4 = self._get_newline_token_ids()
        expansions = {
            newline_3: [newline_1, newline_2],
            newline_4: [newline_2, newline_2],
        }

        repl_token_ids = list[int]()
        repl_orig_idxs = list[int]()
        for orig_idx, orig_tok in enumerate(new_token_ids):
            repl_toks = expansions.get(orig_tok, [orig_tok])
            repl_token_ids.extend(repl_toks)
            # Every expanded token remembers the original token it came from.
            repl_orig_idxs.extend([orig_idx] * len(repl_toks))

        repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)

        return {
            modality: [
                PlaceholderFeaturesInfo(
                    modality=p.modality,
                    item_idx=p.item_idx,
                    start_idx=repl_orig_idxs[p.start_idx],
                    tokens=p.tokens,
                    is_embed=p.is_embed,
                )
                for p in placeholders
            ]
            for modality, placeholders in repls.items()
        }

435
436
437
438
439
440

class Gemma3MultiModalProjector(nn.Module):
    """Maps vision-tower features into the text model's embedding space.

    The square patch grid is average-pooled, RMS-normalized and multiplied by
    a learned (vision_hidden, text_hidden) projection matrix.
    """

    def __init__(self, config: Gemma3Config):
        super().__init__()

        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(
                config.vision_config.hidden_size, config.text_config.hidden_size
            )
        )

        self.mm_soft_emb_norm = GemmaRMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )

        # Number of patches along each side of the square patch grid.
        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )

        # Pool the grid down to tokens_per_side x tokens_per_side tokens
        # (assumes mm_tokens_per_image is a perfect square — TODO confirm).
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor):
        """Project vision features to text embeddings.

        Args:
            vision_outputs: Tensor of shape (batch, num_patches, hidden) where
                ``num_patches == patches_per_image ** 2`` (required by the
                reshape below).

        Returns:
            Tensor of shape (batch, pooled_tokens, text_hidden), cast to the
            dtype of ``vision_outputs``.
        """
        batch_size, _, hidden_size = vision_outputs.shape

        # (b, n, d) -> (b, d, n) -> (b, d, side, side) so that the patch grid
        # can be pooled spatially.
        patch_grid = (
            vision_outputs.transpose(1, 2)
            .reshape(
                batch_size,
                hidden_size,
                self.patches_per_image,
                self.patches_per_image,
            )
            .contiguous()
        )

        # (b, d, s, s) -> (b, d, s*s) -> (b, s*s, d)
        pooled = self.avg_pool(patch_grid).flatten(2).transpose(1, 2)

        normed = self.mm_soft_emb_norm(pooled)
        projected = torch.matmul(normed, self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)


480
481
482
483
484
485
486
487
@MULTIMODAL_REGISTRY.register_processor(
    Gemma3MultiModalProcessor,
    info=Gemma3ProcessingInfo,
    dummy_inputs=Gemma3DummyInputsBuilder,
)
class Gemma3ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
    """Gemma3 image+text model.

    A SigLIP vision tower plus ``Gemma3MultiModalProjector`` turn images into
    embeddings; a ``Gemma3ForCausalLM`` language model handles the text side.
    """

    merge_by_field_config = True

    # Fused vLLM modules -> checkpoint sub-modules packed into each of them.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an item of ``modality``.

        Raises:
            ValueError: If ``modality`` is not an image modality.
        """
        if modality.startswith("image"):
            return "<start_of_image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = Gemma3MultiModalProjector(config)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma3ForCausalLM"],
        )
        logit_scale = getattr(config, "logit_scale", 1.0)

        if hasattr(self.language_model, "logits_processor"):
            # The logits processor can be unset if we're using
            # automatic conversion to pooling model.
            self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @property
    def dtype(self):
        """Dtype of the model's first parameter."""
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[Gemma3ImageInputs]:
        """Pop image inputs from ``kwargs`` and wrap them in the schema.

        Returns ``None`` when no ``pixel_values`` were passed. The schema's
        ``h``/``w`` dims are bound to the vision config's ``image_size``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, "Gemma3 does not support image_embeds."
        if pixel_values is None:
            return None

        image_size = self.config.vision_config.image_size

        return Gemma3ImagePixelInputs(
            pixel_values=pixel_values,
            num_patches=num_patches,
            resolve_bindings={"h": image_size, "w": image_size},
        )

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower on the stacked pixel values."""
        return vision_tower(pixel_values)

    def _process_image_input(
        self,
        image_input: Gemma3ImageInputs,
    ) -> list[torch.Tensor]:
        """Encode images; return one tensor per image.

        Each image's tensor concatenates (via ``flatten(0, 1)``) the
        embeddings of its ``num_patches`` views (original + crops).
        """
        assert self.vision_tower is not None

        pixel_values = image_input["pixel_values"]
        num_patches = image_input["num_patches"]

        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )
        image_embeds = self.multi_modal_projector(image_features)

        return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` if no images were given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language-model backbone.

        Images are not encoded here (see ``get_multimodal_embeddings``). When
        ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        return hidden_states

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        """Build per-sequence attention masks for a flattened batch.

        Adds ``has_images``, ``seq_lens``, ``global_attn_masks`` and
        ``local_attn_masks`` to ``kwargs`` and returns it.

        Global masks are causal, except that image tokens attend to each other
        bidirectionally. Local masks additionally block keys that are
        ``sliding_window`` or more positions behind the query; they are only
        built when the text config sets ``sliding_window``.
        """
        kwargs["has_images"] = True
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        # NOTE(review): the mask loop below slices `input_ids` starting at
        # offset 0, so this presumably assumes positions[0] == 0 — confirm.
        seq_starts = (positions == 0).cpu().nonzero().flatten().tolist()
        seq_ends = seq_starts[1:] + [len(input_ids)]
        seq_lens = [end - start for start, end in zip(seq_starts, seq_ends)]
        kwargs["seq_lens"] = seq_lens

        # Loop-invariant config lookups.
        sliding_window = self.config.text_config.sliding_window
        image_token_index = self.config.image_token_index

        global_attn_masks = []
        local_attn_masks = []
        offset = 0
        for seq_len in seq_lens:
            input_token_ids = input_ids[offset : offset + seq_len]
            offset += seq_len

            # Global causal mask: -inf strictly above the diagonal, 0 on and
            # below it.
            global_attn_mask = torch.full(
                (1, 1, seq_len, seq_len),
                float("-inf"),
                dtype=mask_dtype,
                device=input_ids.device,
            ).triu(diagonal=1)

            # Consider the bidirectional attention between image tokens:
            # entries where both query and key are image tokens are unmasked.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            if sliding_window is not None:
                # Local mask: additionally mask keys at distance
                # >= sliding_window behind the query.
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
                local_attn_mask = torch.where(
                    local_attn_mask == 0, global_attn_mask, float("-inf")
                )
                local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming prefixes via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )