# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
19
from collections.abc import Iterable, Mapping, Sequence
20
from typing import Annotated, Literal, Optional, Union
21

22
import numpy as np
23
24
25
import torch
import torch.nn.functional as F
import transformers.models.mllama.configuration_mllama as config_mllama
26
from PIL.Image import Image
27
from torch import nn
28
from transformers import BatchFeature, MllamaConfig
29
30
31
32
from transformers.modeling_outputs import (BaseModelOutput,
                                           CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
    get_optimal_tiled_canvas)
33
from transformers.models.mllama.processing_mllama import (
34
    MllamaProcessor, get_cross_attention_token_mask)
35
36
37

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
38
from vllm.attention.ops.paged_attn import PagedAttention
39
from vllm.attention.selector import _Backend
40
from vllm.config import VllmConfig
41
42
from vllm.distributed import get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
43
44
45
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
46
                                               QKVCrossParallelLinear,
47
48
49
50
51
52
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
53
54
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
55
from vllm.model_executor.models.module_mapping import MultiModelKeys
56
57
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
58
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
59
60
                                    MultiModalFieldConfig,
                                    MultiModalKwargsItems)
61
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
62
                                   MultiModalDataItems)
63
64
from vllm.multimodal.processing import (BaseProcessingInfo,
                                        EncDecMultiModalProcessor,
65
                                        PromptReplacement, PromptUpdate)
66
from vllm.multimodal.profiling import BaseDummyInputsBuilder
67
from vllm.utils.tensor_schema import TensorSchema, TensorShape
68
69

from .clip import CLIPMLP
70
from .interfaces import SupportsMultiModal, SupportsV0Only
71
from .llama import LlamaDecoderLayer, LlamaMLP
72
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
73
74
75
76

logger = init_logger(__name__)


77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
class MllamaImagePixelInputs(TensorSchema):
    """
    Schema describing the pixel-value image inputs passed to Mllama.

    Each tensor field is annotated with a `TensorShape` naming its
    dimensions; the names below are the symbolic dimension names used in
    those annotations.

    Dimensions:
        - batch_size: Batch size
        - max_num_image: Max number of images
        - max_num_chunk: Max number of chunks
        - max_num_tiles: Max number of tiles per image
        - num_channel: Number of channels
        - height: Height
        - width: Width
    """

    # Discriminator tag identifying this input variant.
    type: Literal["pixel_values"] = "pixel_values"

    # Image pixels: one (num_channel, height, width) image per chunk.
    data: Annotated[torch.Tensor,
                    TensorShape("batch_size", "max_num_image", "max_num_chunk",
                                "num_channel", "height", "width")]

    # One aspect-ratio id per image.
    aspect_ratio_ids: Annotated[torch.Tensor,
                                TensorShape("batch_size", "max_num_image")]

    # Per-image mask over tiles.
    # NOTE(review): presumably 1 marks a real tile and 0 a padded one, and
    # "max_num_tiles" here spans the same axis as "max_num_chunk" in `data`
    # -- confirm against the HF image processor.
    aspect_ratio_mask: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image", "max_num_tiles")]
101
102
103
104
105


# TODO: support LlamaImageEmbeddingInputs


106
107
108
109
110
111
112
113
114
115
116
def calc_token_per_chunk(image_size: int, patch_size: int = 14) -> int:
    """Return the number of vision tokens produced for one square chunk.

    Args:
        image_size: Side length of the (square) chunk in pixels. Must be a
            multiple of ``patch_size``.
        patch_size: Side length of one square patch in pixels. Defaults to
            14, the value that was previously hard-coded here.

    Returns:
        ``(image_size // patch_size) ** 2 + 1``: one token per patch plus
        one extra token (presumably the class token that the vision model
        prepends to every chunk).

    Raises:
        AssertionError: If ``image_size`` is not a multiple of
            ``patch_size``.
    """
    assert image_size % patch_size == 0, (
        f"chunk size should be multiple of {patch_size}")
    patches_per_side = image_size // patch_size
    return patches_per_side**2 + 1


class MllamaProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Mllama, read from the HF config/processor."""

    def get_hf_config(self) -> MllamaConfig:
        """Return the HF model config, typed as `MllamaConfig`."""
        return self.ctx.get_hf_config(MllamaConfig)

    def get_hf_processor(self, **kwargs: object) -> MllamaProcessor:
        """Return the HF `MllamaProcessor`, forwarding any extra kwargs."""
        return self.ctx.get_hf_processor(MllamaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return per-modality item limits; `None` means no fixed limit."""
        return {"image": None}

    def get_token_per_chunk_from_config(self) -> int:
        """Return the token count of one chunk for the configured size."""
        image_size = self.get_hf_config().vision_config.image_size
        return calc_token_per_chunk(image_size)

    def get_num_tiles_per_image(self, image_height: int,
                                image_width: int) -> int:
        """Return how many tiles an image of the given size is split into.

        The tiled canvas size is chosen by the HF helper
        `get_optimal_tiled_canvas`; the tile count is the canvas area
        measured in `image_size`-sized tiles.
        """
        vision_config = self.get_hf_config().vision_config
        max_num_tiles = vision_config.max_num_tiles
        image_size = vision_config.image_size
        tiled_height, tiled_width = get_optimal_tiled_canvas(
            image_height,
            image_width,
            max_num_tiles,
            tile_size=image_size,
        )
        num_tiles_height = tiled_height // image_size
        num_tiles_width = tiled_width // image_size
        return num_tiles_height * num_tiles_width

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size intended to yield the most tiles."""
        vision_config = self.get_hf_config().vision_config
        image_size = vision_config.image_size
        max_num_tiles = vision_config.max_num_tiles
        # A one-tile-wide, max_num_tiles-tall image
        # (h:w = max_num_tiles:1) gives the max possible feature size.
        return ImageSize(height=max_num_tiles * image_size, width=image_size)


class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
    """Builds dummy text and image inputs for Mllama."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt made of one image token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most features.

        Args:
            seq_len: Unused here; part of the builder interface.
            mm_counts: Number of items requested per modality.
        """
        num_images = mm_counts.get("image", 0)

        # Read width/height by attribute rather than positional unpacking so
        # the result does not depend on the field order of `ImageSize`.
        target_size = self.info.get_image_size_with_most_features()

        return {
            "image":
            self._get_dummy_images(width=target_size.width,
                                   height=target_size.height,
                                   num_images=num_images)
        }


class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                                ):
    """Encoder-decoder multi-modal processor for Mllama image inputs."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
    ) -> MultiModalEncDecInputs:
        """Run the base processing, then size the encoder prompt.

        Raises:
            ValueError: If the number of image tokens in the decoder prompt
                differs from the number of images in ``mm_data``.
        """
        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                  tokenization_kwargs)

        image_token_id = self.info.get_hf_config().image_token_index
        # Check that the number of image tokens in the decoder prompt matches
        # the number of images provided in mm_data
        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
        image_data = mm_data.get("image", [])
        # A single PIL image counts as one; otherwise expect a sized sequence.
        num_images = 1 if isinstance(image_data, Image) else len(image_data)
        if num_image_tokens != num_images:
            raise ValueError(
                f"The number of image tokens ({num_image_tokens}) must be"
                f" the same as the number of images ({num_images})")

        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode)  # noqa: E501
        # P0 & P1 do cross attention with placeholder of <IMG0>
        # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
        # Example input to encoder and decoder:
        # {
        #     'encoder': {
        #         'type': 'token',
        #         'prompt_token_ids': [128256, 128256, ..., 128256],
        #         'prompt': '<|image|><|image|>...<|image|>',
        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
        #     },
        #     'decoder': {
        #         'type': 'token',
        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
        #     },
        # }

        if mm_data:
            hf_processor = self.info.get_hf_processor()
            image_token: str = hf_processor.image_token

            # Since only the last group of consecutive images
            # are attended by the decoded tokens, we only need to
            # get the number of tokens for those images.
            token_per_chunk = self.info.get_token_per_chunk_from_config()
            num_decode_images = self._get_num_image_in_last_group(
                mm_inputs["prompt_token_ids"])
            num_encode_images = num_images - num_decode_images

            # Set encoder prompt length based on the number of tiles.
            # This tells the block manager to allocate correct number
            # of slots for encoder tokens.
            num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"]
            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
            num_tokens = decode_tiles * token_per_chunk
            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
                                                     ] * num_tokens
            mm_inputs["encoder_prompt"] = image_token * num_tokens

        return mm_inputs

    def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int:
        """Count the image tokens in the last consecutive run of them.

        Scans from the end of the prompt, skipping trailing non-image
        tokens, and stops at the first non-image token after the run.
        """
        # Look the id up once instead of once per scanned token.
        image_token_id = self.info.get_hf_config().image_token_index
        num_images = 0
        for token_id in reversed(prompt_token_ids):
            if token_id == image_token_id:
                num_images += 1
            elif num_images > 0:
                break
        return num_images

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor and post-process its outputs.

        With images, this adds a ``num_tiles`` tensor, drops the leading
        dimension of the image tensors, and strips BOS tokens from the ends
        of ``input_ids``. Without images, the prompt is only tokenized.
        """
        tokenizer = self.info.get_tokenizer()
        if mm_data:
            num_tiles = [
                self.info.get_num_tiles_per_image(img.height, img.width)
                for img in mm_data["images"]
            ]
            processed_outputs = super()._call_hf_processor(
                prompt, mm_data, mm_kwargs, tok_kwargs)
            processed_outputs["num_tiles"] = torch.tensor(num_tiles)
            for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
                processed_outputs[k] = processed_outputs[k].squeeze(0)

            processed_token_ids = processed_outputs.pop("input_ids")
            start_idx, end_idx = 0, processed_token_ids.size(1)
            processed_prompt_text = tokenizer.decode(processed_token_ids[0])

            hf_processor = self.info.get_hf_processor()
            bos_token = hf_processor.bos_token
            # Remove the bos_token from the start of prompt,
            # because an image_token is expected to be present there.
            if processed_prompt_text.startswith(bos_token):
                start_idx += 1
            # Remove the bos_token from the end of prompt,
            # because text is empty in this case.
            if processed_prompt_text.endswith(bos_token):
                end_idx -= 1
            processed_outputs[
                "input_ids"] = processed_token_ids[:, start_idx:end_idx]
        else:
            processed_outputs = tokenizer(prompt,
                                          add_special_tokens=False,
                                          return_tensors="pt")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related field as batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            aspect_ratio_ids=MultiModalFieldConfig.batched("image"),
            aspect_ratio_mask=MultiModalFieldConfig.batched("image"),
            num_tiles=MultiModalFieldConfig.batched("image"),
        )

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return an encoder prompt of one image token id per image."""
        data = mm_data.get("image", [])
        num_images = 1 if isinstance(data, Image) else len(data)
        image_token_id = self.info.get_hf_config().image_token_index
        return [image_token_id] * num_images

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one token per tile position."""
        token_per_chunk = self.info.get_token_per_chunk_from_config()
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement_mllama(item_idx):
            # Expand to num_tiles * token_per_chunk copies of the image token.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            num_tile = self.info.get_num_tiles_per_image(
                image_height=image_size.height,
                image_width=image_size.width,
            )
            num_tokens = num_tile * token_per_chunk
            return [image_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_mllama,
            )
        ]
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391


def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
                                            1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches
    pad_patches = target_length - num_patches
    attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
    attention_mask = attention_mask.reshape(batch_size,
                                            max_num_tiles * target_length, 1)
    attention_mask = attention_mask @ attention_mask.transpose(
        -1, -2) * torch.finfo(dtype).min
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.
    Column parallel over unfolded input.

    The convolution is implemented as `torch.nn.Unfold` (patch extraction)
    followed by a `ColumnParallelLinear` projection of each flattened patch.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel (an int or a 2-tuple).
        stride: Stride for convolution (required, no default).
        bias (default False): Use bias in the projection.
    Input: (bsz, in_channels, height, width)
    Output: (bsz, num_tokens, out_channels)
        NOTE(review): presumably the last dim is sharded across
        tensor-parallel ranks until the caller gathers it -- confirm
        against `ColumnParallelLinear`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple[int, int]],
        stride: Union[int, tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Normalise to a pair so the flattened patch size can be computed.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(
            in_channels * kernel_size[0] * kernel_size[1],
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract patches from `x` and project each one."""
        # Unfold yields (bsz, C * kh * kw, num_tokens); move tokens to dim 1.
        x = self._unfold(x)
        x = x.permute(0, 2, 1)
        # The linear layer returns a tuple; only the first item is used.
        x, _ = self._linear(x)
        return x


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a learned per-tile embedding selected by aspect-ratio id.

    When `is_gated` is true the embedding is scaled by `tanh(gate)`, where
    `gate` is a learned scalar initialised to zero.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # One row per aspect-ratio id; each row holds an embedding for
        # every tile, flattened.
        num_ids = self.max_aspect_ratio_id + 1
        flat_dim = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_ids, flat_dim)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return `hidden_state` plus the (optionally gated) tile embedding."""
        # Unflatten to (-1, max_num_tiles, 1, hidden_size) so the embedding
        # broadcasts over the patch axis of `hidden_state`.
        tile_embeddings = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            tile_embeddings = tile_embeddings * self.gate.tanh()

        return hidden_state + tile_embeddings


class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds gated patch-position and tile-position embeddings.

    A single learned scalar `gate` blends the two terms: the shared patch
    position embedding is weighted by `1 - tanh(gate)` and the per-tile
    position embedding, selected by aspect-ratio id, by `tanh(gate)`.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one extra position.
        patches_per_side = config.image_size // config.patch_size
        self.num_patches = patches_per_side**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # Patch position embedding, shared by all tiles.
        initial_positions = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * initial_positions)

        # Tile position embedding: one flattened
        # (max_num_tiles, num_patches, hidden_size) table per aspect-ratio id.
        flat_tile_dim = (self.max_num_tiles * self.num_patches *
                         self.hidden_size)
        self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                           flat_tile_dim)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return `hidden_state` with both gated embeddings added."""
        gate = self.gate.tanh()

        # Shared position embedding, broadcast over batch and tile axes.
        position_term = ((1 - gate) * self.embedding).view(
            1, 1, self.num_patches, self.hidden_size)
        hidden_state = hidden_state + position_term

        # Precomputed tile position embedding for each sample's aspect ratio.
        num_samples = hidden_state.shape[0]
        tile_term = self.tile_embedding(aspect_ratio_ids).reshape(
            num_samples, self.max_num_tiles, self.num_patches,
            self.hidden_size)
        hidden_state = hidden_state + gate * tile_term

        return hidden_state


# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Multi-head self-attention for the vision encoder.

    Attention is computed with `F.scaled_dot_product_attention`. The heads
    are divided across the tensor-parallel group: each rank works on
    `num_heads // tp_world_size` heads.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        tensor_parallel_size = get_tp_group().world_size
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // tensor_parallel_size
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_local_heads * self.head_dim
        self.kv_size = self.num_local_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to a 3-D (batch, seq, embed) input.

        Args:
            hidden_state: Input activations.
            attention_mask: Optional mask passed as `attn_mask` to SDPA.

        Returns:
            The output of the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)

        # TODO: remove padding in image encoder
        attn_output = F.scaled_dot_product_attention(q,
                                                     k,
                                                     v,
                                                     attn_mask=attention_mask,
                                                     dropout_p=0.0)

        # Back to (batch, seq, heads * head_dim) for the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(attn_output.shape[0],
                                          attn_output.shape[1], -1)
        output, _ = self.o_proj(attn_output)
        return output


class MllamaVisionEncoderLayer(nn.Module):
    """One pre-norm vision transformer layer (self-attention then MLP).

    When `is_gated` is true, each residual branch is scaled by the tanh of
    a learned scalar gate; otherwise the branches are added unscaled.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        is_gated: bool = False,
    ) -> None:
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # Gate parameters exist only on gated layers; `forward` uses a
        # constant factor of 1 otherwise.
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer and return the updated hidden state.

        Args:
            hidden_state: Input activations.
            attention_mask: Optional mask forwarded to self-attention.
        """
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state,
                                      attention_mask=attention_mask)
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state


class MllamaVisionEncoder(nn.Module):
    """A stack of `MllamaVisionEncoderLayer`s that can record hidden states.

    `output_hidden_states` holds the layer indices whose *input* hidden
    state is recorded during `forward`; `None` records nothing.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        num_layers: int = 32,
        is_gated: bool = False,
        output_hidden_states=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            MllamaVisionEncoderLayer(config,
                                     quant_config=quant_config,
                                     is_gated=is_gated,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_layers)
        ])
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run every layer in order.

        Args:
            hidden_states: Input activations.
            attention_mask: Optional mask passed to every layer.

        Returns:
            A pair `(hidden_states, encoder_states)`: the output of the last
            layer and a tuple of the recorded intermediate states.
        """
        encoder_states = ()

        for i, encoder_layer in enumerate(self.layers):
            # Record the state *entering* layer i when i is requested.
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        # If the last layer index is requested, the final output is
        # recorded as well.
        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        return hidden_states, encoder_states


class MllamaVisionModel(nn.Module):

644
645
646
647
648
649
    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        """Build the vision tower: patch embedding, embeddings and encoders.

        Args:
            config: Vision config supplying sizes and layer counts.
            quant_config: Quantization config passed to both encoders.
            prefix: Name prefix used for the sub-module prefixes.
        """
        super().__init__()

        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile plus one for the class embedding.
        self.num_patches = (self.image_size // self.patch_size)**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale *
                                            torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)

        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        # The local encoder is ungated and records the configured
        # intermediate layers; the global encoder is gated.
        self.transformer = MllamaVisionEncoder(
            config,
            quant_config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
            prefix=f"{prefix}.transformer",
        )
        self.global_transformer = MllamaVisionEncoder(
            config,
            quant_config,
            config.num_global_layers,
            is_gated=True,
            prefix=f"{prefix}.global_transformer",
        )
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding to every sequence.

        Takes a 3-D `(batch, seq, hidden)` tensor and returns a
        `(batch, seq + 1, hidden)` tensor whose first position along dim 1
        is the learned class embedding.
        """
        num_sequences, _, embed_dim = hidden_state.shape
        # `expand` broadcasts the single class vector without copying it.
        cls_tokens = self.class_embedding.expand(num_sequences, 1, embed_dim)
        return torch.cat([cls_tokens, hidden_state], dim=1)

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        """Encode tiled images into per-patch vision features.

        Args:
            pixel_values: 6-D tensor unpacked as (batch, num_concurrent_media,
                num_tiles, channels, height, width).
            aspect_ratio_ids: reshaped to one row per media item and passed to
                the tile/position embedding modules.
            aspect_ratio_mask: reshaped to one row per media item and turned
                into the attention mask shared by both encoders.

        Returns:
            Tensor of shape (batch, num_concurrent_media, num_tiles,
            num_patches + 1, -1): the global-encoder output concatenated on
            the last dim with the stacked intermediate hidden states of the
            local encoder (padding tokens removed).
        """
        (batch_size, num_concurrent_media, num_tiles, num_channels, height,
         width) = pixel_values.shape
        # Batch and media axes are always folded together below.
        num_media = batch_size * num_concurrent_media
        param_dtype = self.layernorm_pre.weight.dtype

        pixel_values = pixel_values.reshape(num_media * num_tiles,
                                            num_channels, height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # Patch embedding, then gather the result across the TP group.
        hidden_state = self.patch_embedding(pixel_values.to(param_dtype))
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # Tile embeddings operate on a (media, tile, patch, dim) layout.
        _, num_patches, dim = hidden_state.shape
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state.reshape(num_media, num_tiles, -1, dim),
            aspect_ratio_ids)

        # The class token is prepended per tile, adding one patch slot.
        hidden_state = self.apply_class_embedding(
            hidden_state.reshape(num_media * num_tiles, num_patches, dim))
        num_patches += 1

        # Position embeddings, back in the (media, tile, patch, dim) layout.
        hidden_state = self.gated_positional_embedding(
            hidden_state.reshape(num_media, num_tiles, num_patches, dim),
            aspect_ratio_ids)

        hidden_state = self.layernorm_pre(hidden_state)

        # Pad the patch axis (dim -2) up to the next multiple of 8. F.pad
        # takes pairs starting from the last dim: (0, 0) leaves dim -1 alone,
        # (0, num_padding_patches) appends zeros at the end of dim -2.
        num_padding_patches = (-hidden_state.shape[-2]) % 8
        hidden_state = F.pad(hidden_state, (0, 0, 0, num_padding_patches),
                             mode="constant",
                             value=0)
        # ``None`` keeps the full axis when no padding was added.
        slice_index = None
        if num_padding_patches > 0:
            slice_index = -num_padding_patches
        padded_patches = num_patches + num_padding_patches

        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=aspect_ratio_mask.reshape(num_media, -1),
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=param_dtype,
        )

        # Local encoder over the flattened (tile * patch) sequence.
        encoder_out = self.transformer(
            hidden_state.view(num_media, -1, dim),
            attention_mask=attention_mask,
        )
        hidden_state = encoder_out[0]
        intermediate_hidden_states = torch.stack(encoder_out[1], dim=-1)

        # Global encoder.
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state.reshape(num_media, num_tiles, padded_patches, dim),
            aspect_ratio_ids)
        hidden_state = self.global_transformer(
            hidden_state.reshape(num_media, num_tiles * padded_patches, dim),
            attention_mask=attention_mask)[0]

        # Drop the padding tokens and restore the batch/media split.
        hidden_state = hidden_state.reshape(num_media, num_tiles,
                                            padded_patches, dim)
        hidden_state = hidden_state[:, :, :slice_index]
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)

        # Same un-padding for the stacked intermediate layer outputs.
        intermediate = intermediate_hidden_states.reshape(
            num_media, num_tiles, padded_patches, -1)
        intermediate = intermediate[:, :, :slice_index]
        intermediate = intermediate.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, -1)

        return torch.cat([hidden_state, intermediate], dim=-1)

811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint entries are
        routed into the fused ``qkv_proj`` parameter with the matching shard
        id; every other entry is loaded by name.

        Args:
            weights: iterable of (checkpoint name, tensor) pairs.

        Returns:
            The set of parameter names that received a weight.
        """
        # (fused parameter suffix, checkpoint shard suffix, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: set[str] = set()

        for name, loaded_weight in weights:
            if 'patch_embedding._linear.weight' in name:
                # Flatten everything after the first dim into one axis.
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)

            # First mapping entry whose shard suffix occurs in the name.
            matched = next(
                (entry
                 for entry in stacked_params_mapping if entry[1] in name),
                None)
            if matched is not None:
                fused_suffix, shard_suffix, shard_id = matched
                name = name.replace(shard_suffix, fused_suffix)
                param = params_dict[name]
                updated_params.add(name)
                param.weight_loader(param, loaded_weight, shard_id)
                continue

            # Plain parameter: removed from the dict as it is consumed.
            param = params_dict.pop(name)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            updated_params.add(name)

        return updated_params

841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870

class MllamaTextRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale.

    The statistics are computed in float32; the normalized tensor is cast
    back to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        MllamaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        # Scale starts at one.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last dimension."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # Mean of squares over the last dim (no mean subtraction).
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for ``repr(module)``."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries are projected from the text hidden states and keys/values from
    the cross-attention states, both through one fused ``qkv_proj``. When an
    explicit attention mask is supplied the attention is computed with
    ``F.scaled_dot_product_attention``; otherwise the call is delegated to
    the ``Attention`` layer.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, q/k norms and the attention layer.

        Args:
            config: text config; despite the ``None`` default it is
                dereferenced unconditionally, so it must be provided.
            layer_idx: stored on the module as ``layer_idx``.
            quant_config: forwarded to the parallel linear layers.
            prefix: parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.config = config
        self.pipeline_parallel_rank = get_pp_group().rank_in_group
        self.tensor_parallel_size = get_tp_group().world_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads

        # Heads handled by this tensor-parallel rank.
        self.num_local_heads = self.num_heads // self.tensor_parallel_size
        self.num_local_key_value_heads = \
            self.num_key_value_heads // self.tensor_parallel_size
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads

        self.layer_idx = layer_idx
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        self.qkv_proj = QKVCrossParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_DECODER,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        cross_attention_states: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run cross-attention and the output projection.

        Args:
            hidden_states: text hidden states the queries are projected from.
            attention_mask: custom cross-attention mask; when not ``None``
                the masked SDPA path is used instead of ``self.attn``.
            kv_range_for_decode: (start, end) row ranges of k/v to write to
                the KV cache; only used on the masked path.
            cross_attention_states: states the keys/values are projected
                from. When ``None``, k and v are passed on exactly as
                ``qkv_proj`` returned them (no reshape, no k-norm).

        Returns:
            The first element of ``o_proj``'s output.
        """
        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
        if cross_attention_states is not None:
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)

        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        if attention_mask is not None:
            output = self._attention_with_mask(q, k, v, attention_mask,
                                               kv_range_for_decode)
        else:
            output = self.attn(
                q.view(-1, self.num_local_heads * self.head_dim), k, v)
        out, _ = self.o_proj(output)
        return out

    def _attention_with_mask(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_range_for_decode: list[tuple[int, int]],
    ) -> torch.Tensor:
        """Write k/v to the KV cache and compute masked attention with SDPA.

        Args:
            q: queries viewed as (q_len, num_local_heads, head_dim).
            k: keys viewed as (kv_len, num_local_key_value_heads, head_dim).
            v: values, same layout as ``k``.
            attention_mask: mask viewable as (1, 1, q_len, kv_len).
            kv_range_for_decode: (start, end) row ranges of ``k``/``v`` that
                are concatenated and written to the cache.

        Returns:
            Tensor of shape (q_len, num_local_heads * head_dim).

        Raises:
            ValueError: if the attention backend is not one of the backends
                handled below.
        """
        kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
        # Skip writing kv-cache for the initial profiling run.
        # TODO (NickLucche) replace with custom attn bias and use standard attn
        if len(kv_cache.shape) > 1:
            # Passed for the two trailing scale arguments of the cache-write
            # ops below.
            i = torch.ones(1, dtype=torch.float32)
            if self.attn.backend in (_Backend.FLASH_ATTN,
                                     _Backend.FLASH_ATTN_VLLM_V1):
                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    cached_k,
                    cached_v,
                    kv_cache[0],
                    kv_cache[1],
                    attn_metadata.
                    cross_slot_mapping,  # type: ignore[union-attr]
                    "auto",
                    i,
                    i,
                )
            elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
                                       _Backend.TORCH_SDPA):
                key_cache, value_cache = PagedAttention.split_kv_cache(
                    kv_cache, self.num_local_key_value_heads, self.head_dim)
                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                PagedAttention.write_to_paged_cache(
                    cached_k, cached_v, key_cache, value_cache,
                    attn_metadata.cross_slot_mapping, "auto", i, i)
            else:
                raise ValueError(
                    f"Unsupported Attention backend {self.attn.backend} "
                    "enum found. Expected the Attention backend to be "
                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
                    "XFORMERS, ROCM_FLASH or TORCH_SDPA.")

        # We have to call torch.sdpa for prefill when using a
        # custom cross-attention mask. Because the mask is not a
        # standard causal mask, neither a block diagonal mask which
        # can be optimized by xformers.BlockDiagonalMask.
        # The mask is specially calculated for supporting multi
        # images and interleaved images.
        q_len = q.shape[0]
        kv_len = k.shape[0]
        # Group query heads by the kv head they share:
        # (kv_heads, groups, q_len, head_dim).
        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                   self.num_key_value_groups, q_len,
                                   self.head_dim).contiguous()
        # Repeat each kv head across its query group so k/v match q's layout.
        k = k.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len,
                                              self.head_dim).contiguous()
        v = v.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len,
                                              self.head_dim).contiguous()
        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
        output = F.scaled_dot_product_attention(q,
                                                k,
                                                v,
                                                attn_mask=attention_mask,
                                                is_causal=False)
        # Back to token-major layout with all local heads flattened.
        output = output.permute(2, 0, 1, 3).reshape(
            q_len, self.num_local_heads * self.head_dim)
        return output

1029
1030
1031
1032
1033

class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
    and feedforward."""

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        """Build the cross-attention, MLP, norms and the two scalar gates.

        Args:
            config: text config supplying hidden/intermediate sizes, the
                activation name and the RMSNorm epsilon.
            layer_idx: stored on the module and forwarded to the attention.
            quant_config: forwarded to the attention and MLP sub-modules.
            prefix: parameter-name prefix for the sub-modules.
        """
        super().__init__()

        self.layer_idx = layer_idx
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
            prefix=f"{prefix}.cross_attn",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        # Gates are initialised to zero, so tanh(gate) starts at 0.
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gated cross-attention followed by a gated MLP.

        Args:
            hidden_states: input hidden states; also the first residual.
            cross_attention_states: forwarded to the cross-attention module.
            cross_attention_mask: forwarded as its ``attention_mask``.
            kv_range_for_decode: forwarded to the cross-attention module.
            full_text_row_masked_out_mask: multiplied into the attention and
                MLP outputs before each residual add.

        Returns:
            The hidden states after both gated residual updates.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            cross_attention_states=cross_attention_states,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh(
        ) * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh(
        ) * hidden_states
        return hidden_states


class MllamaTextModel(nn.Module):
    """Text decoder stack mixing self-attention and cross-attention layers.

    Layers whose index is listed in ``config.cross_attention_layers`` are
    ``MllamaCrossAttentionDecoderLayer``; all others are
    ``LlamaDecoderLayer``.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final norm.

        Args:
            vllm_config: supplies the HF text config, cache config and
                quantization config.
            prefix: parameter-name prefix for the decoder layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        # The embedding table has 8 more rows than the configured vocabulary.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config,
                        layer_idx,
                        quant_config=quant_config,
                        prefix=f"{prefix}.layers.{layer_idx}",
                    ))
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(
                        config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        prefix=f"{prefix}.layers.{layer_idx}",
                    ))

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
                                                      torch.Tensor]],
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed the tokens and run them through every decoder layer.

        Args:
            input_ids: token ids to embed.
            positions: forwarded to the self-attention layers.
            cross_attention_states: forwarded to the cross-attention layers.
            cross_attention_mask: forwarded to the cross-attention layers.
            kv_range_for_decode: forwarded to the cross-attention layers.
            full_text_row_masked_out_mask: forwarded to the cross-attention
                layers.
            skip_cross_attention: when true, cross-attention layers are not
                called and the hidden states pass through them unchanged.

        Returns:
            The hidden states after the final norm.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            if idx in self.cross_attention_layers:
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        kv_range_for_decode=kv_range_for_decode,
                        full_text_row_masked_out_mask=
                        full_text_row_masked_out_mask,
                    )
            else:
                # The self-attention layer is called with residual=None and
                # returns (hidden_states, residual); the two are summed here
                # before moving on to the next layer.
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    residual=None,
                )
                hidden_states = hidden_states + residual
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MllamaForCausalLM(nn.Module):
    """Mllama text model paired with a language-model head.

    ``forward`` returns the hidden states of the text model; ``lm_head`` is
    constructed here but is not applied inside ``forward``.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text model and the LM head.

        Args:
            vllm_config: supplies the HF text config and quantization config.
            prefix: parameter-name prefix for the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.text_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(vllm_config=vllm_config,
                                     prefix=f"{prefix}.model")
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=f"{prefix}.lm_head",
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
                                                      torch.Tensor]],
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Forward every argument to the text model and return its output."""
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Handles, in order: renaming/flattening of ``patch_embedding.weight``,
        KV-cache quantization scales, shards of fused parameters
        (``qkv_proj``, ``gate_up_proj``), and finally plain parameters.

        Args:
            weights: iterable of (checkpoint name, tensor) pairs.

        Returns:
            The set of parameter names that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: set[str] = set()
        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                # Flatten everything after the first dim into one axis.
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Use the scalar itself, or the first element otherwise.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                updated_params.add(scale_name)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                orig_name = name
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    # NOTE(review): ``name`` is None on this path, so the
                    # first placeholder always logs "None"; ``orig_name``
                    # carries the useful information.
                    logger.debug("Missing name %s, orig name %s", name,
                                 orig_name)
                    continue

                # Plain parameter: removed from the dict as it is consumed.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                updated_params.add(name)
        return updated_params

1271

1272
1273
1274
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                        info=MllamaProcessingInfo,
                                        dummy_inputs=MllamaDummyInputsBuilder)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsV0Only):
    """Mllama model for conditional generation.

    Combines a vision model, a linear multi-modal projector and a causal
    language model.
    """

    # Fused projection name -> names of the individual projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Renames checkpoint weight names to the module layout used by this class.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.vision_model.": "vision_model.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        },
        orig_to_new_suffix={
            "patch_embedding.weight": "patch_embedding._linear.weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|image|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision model, language model and multi-modal projector.

        Args:
            vllm_config: Engine configuration; the HF model config and the
                quantization config are read from it.
            prefix: Module-name prefix combined with each submodule's name
                through ``maybe_prefix``.
        """
        super().__init__()
        config: MllamaConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Frequently used values copied out of the nested HF config.
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        # Fall back to -1 when the config does not define a pad token.
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size
        self.image_token_id = config.image_token_index

        self.vision_model = MllamaVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_model"))
        self.language_model = MllamaForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Projects vision features to the text model's hidden size.
        self.multi_modal_projector = ColumnParallelLinear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            gather_output=True,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        # NOTE(review): `config.output_hidden_states` is passed as the first
        # positional argument here -- confirm this matches the
        # LogitsProcessor constructor's intended parameters.
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                config.text_config.vocab_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by applying the logits processor to the LM head.

        Args:
            hidden_states: Hidden states forwarded to the logits processor.
            sampling_metadata: Sampling metadata forwarded unchanged.

        Returns:
            Whatever ``self.logits_processor`` returns for these inputs.
        """
        processor = self.logits_processor
        lm_head = self.language_model.lm_head
        return processor(lm_head, hidden_states, sampling_metadata)

    def unpack_data(self,
                    image_data: Union[list[torch.Tensor], torch.Tensor],
                    padding_value=0) -> torch.Tensor:
        """Stack per-sample tensors into one batch tensor, padding dim 0.

        Args:
            image_data: Either an already-batched tensor, which is returned
                unchanged, or a list of tensors that agree on every
                dimension except the first.
            padding_value: Fill value for the padded tail of shorter entries.

        Returns:
            The input tensor itself, or a new tensor of shape
            ``(len(image_data), max_len, *trailing_dims)`` with the dtype and
            device of the first list entry.
        """
        if isinstance(image_data, torch.Tensor):
            # Already a single batched tensor.
            return image_data

        # list[torch.Tensor]
        assert isinstance(
            image_data[0],
            torch.Tensor), "Image data is not properly batched."
        first = image_data[0]
        trailing_dims = first.shape[1:]
        assert all(t.shape[1:] == trailing_dims for t in image_data), \
            "All entries must share the same trailing dimensions."

        bsz = len(image_data)
        max_length = max(t.size(0) for t in image_data)
        output_tensor = torch.full((bsz, max_length, *trailing_dims),
                                   padding_value,
                                   dtype=first.dtype,
                                   device=first.device)
        for i, t in enumerate(image_data):
            # Copy the real rows; the remainder keeps `padding_value`.
            output_tensor[i, :t.size(0)] = t
        return output_tensor

    def _parse_and_validate_image_input(self, **kwargs: object):
        # tensor with the same shape will be batched together by
1372
        # MultiModalKwargs.batch, so pixel_values here can be:
1373
        #   - list[torch.Tensor]:
1374
1375
1376
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
1377
1378
        pixel_values: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
1379
1380
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
1381
1382
        image_embeds: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
1383
1384
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
1385
1386
        aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]],
                                         list[torch.Tensor],
1387
1388
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
1389
1390
        aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]],
                                          list[torch.Tensor],
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            assert aspect_ratio_ids is not None
            assert aspect_ratio_mask is not None

            return MllamaImagePixelInputs(
                type="pixel_values",
1407
1408
1409
                data=self.unpack_data(pixel_values),
                aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
                aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
1410
1411
1412
1413
1414
1415

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def _get_and_validate_encoder_lens(
        self,
1418
1419
        encoder_seq_lens: list[int],
        num_tiles: list[list[int]],
1420
        num_tokens_per_tile: int,
1421
    ) -> list[int]:
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
        # Get the actual number of encoder tokens for each sample.
        # Because attn_metadata.encoder_seq_lens only counts the last
        # group of images for each sample, which is used to cheat the
        # block manager to allocate blocks for those images only.
        # See MllamaMultiModalProcessor for more details.
        actual_encoder_seq_lens = [
            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
        ]

        # remove 0 encoder len entries for text-only requests for these
        # assertions
        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
                                              attn_metadata_lens):
            assert actual_len >= last_group_len

        return actual_encoder_seq_lens

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata,
                            actual_encoder_seq_lens: list[int]):
        """Concatenate each sample's leading encoder tokens into one tensor.

        Args:
            cross_attention_states: Encoder outputs iterated along dim 0
                (one entry per sample), with features on the last dim.
            attn_metadata: Not used by this method.
            actual_encoder_seq_lens: Number of leading tokens to keep from
                each sample.

        Returns:
            Tensor of shape ``(sum(actual_encoder_seq_lens), feature_dim)``
            on the same device and with the same dtype as the input.
        """
        cross_attention_states_flat = torch.zeros(
            sum(actual_encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
                                                  cross_attention_states):
            end_pos = start_pos + seq_len
            # Drop the padded tail of this sample and append the rest.
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        return cross_attention_states_flat

    def get_language_model(self) -> torch.nn.Module:
        """Return the language model submodule (``self.language_model``)."""
        language_model = self.language_model
        return language_model

    def get_cross_attention_states(
        self,
        image_inputs: MllamaImagePixelInputs,
        attn_metadata: AttentionMetadata,
        actual_encoder_seq_lens: list[int],
    ) -> torch.Tensor:
        """Encode the images and flatten the result for cross-attention.

        Args:
            image_inputs: Mapping with the ``data``, ``aspect_ratio_ids``
                and ``aspect_ratio_mask`` entries fed to the vision model.
            attn_metadata: Forwarded to ``flat_encoder_result``.
            actual_encoder_seq_lens: Number of encoder tokens to keep for
                each sample.

        Returns:
            The projected vision features as returned by
            ``flat_encoder_result``.
        """
        # NOTE: llama's reference implementation runs vision model on CPU
        pixel_values = image_inputs['data']
        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
        cross_attention_states = self.vision_model(pixel_values,
                                                   aspect_ratio_ids,
                                                   aspect_ratio_mask)
        # The projector returns a pair; only the first element is used.
        cross_attention_states, _ = self.multi_modal_projector(
            cross_attention_states)

        # Collapse the three middle dims of the 5-D output into one token
        # axis per sample.
        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
        cross_attention_states = cross_attention_states.view(
            bsz, -1, image_token_dim)

        cross_attention_states = self.flat_encoder_result(
            cross_attention_states, attn_metadata, actual_encoder_seq_lens)

        return cross_attention_states

    def get_cross_attention_mask(
        self,
        input_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        num_tiles: list[list[int]],
        num_tokens_per_tile: int,
        dtype: torch.dtype,
    ) -> tuple[Optional[torch.Tensor], Optional[list[list[int]]]]:
        """Build the cross-attention mask and the per-sequence KV ranges.

        Args:
            input_ids: Flattened token ids of all sequences in the batch.
            attn_metadata: Provides ``seq_lens`` used to split ``input_ids``
                back into sequences.
            num_tiles: Tile counts per image for each sequence with images.
            num_tokens_per_tile: Number of encoder tokens per tile.
            dtype: Floating-point dtype of the returned mask.

        Returns:
            ``(None, None)`` when no mask is needed; otherwise the mask
            tensor and a list of ``[start, end]`` encoder-token ranges, one
            per sequence with images.
        """
        # Split the flattened ids into one token list per sequence.
        token_ids = input_ids.tolist()
        start = 0
        batch_token_ids = []
        for seq_len in attn_metadata.seq_lens:
            batch_token_ids.append(token_ids[start:start + seq_len])
            start += seq_len
        sparse_mask = [
            get_cross_attention_token_mask(t, self.image_token_id)
            for t in batch_token_ids
        ]

        # Skip generating cross-attention mask if all samples
        # are text-only or have only 1 leading image.
        if skip_attention_mask(sparse_mask):
            return None, None

        dense_mask, tile_range_for_decode = \
            convert_sparse_cross_attention_mask_to_dense(
                sparse_mask, num_tiles, attn_metadata.seq_lens)
        cross_attention_mask = \
            convert_dense_cross_attention_mask_to_tensor(
                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
        # Convert tile ranges to encoder-token ranges.
        kv_range_for_decode = [[
            t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
        ] for t in tile_range_for_decode]

        return cross_attention_mask, kv_range_for_decode

    def get_full_text_row_masked_out_mask(
        self,
        attn_metadata: AttentionMetadata,
        device: torch.device,
    ) -> torch.Tensor:
        """Build a per-token flag that is False for text-only sequences.

        Args:
            attn_metadata: Provides ``num_prefill_tokens``, ``seq_lens`` and
                ``encoder_seq_lens``.
            device: Device the resulting mask is moved to.

        Returns:
            Boolean tensor of shape ``(num_prefill_tokens, 1)``; rows of
            sequences whose encoder length is 0 are ``False``, all other
            rows are ``True``.
        """
        # Filled in before the move to `device`.
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
                                            attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            device)
        return full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        **kwargs: object,
    ) -> Union[CausalLMOutputWithPast]:
        """Run the language model, adding image cross-attention if present.

        Args:
            input_ids: Flattened token ids of the batch.
            positions: Token positions forwarded to the language model.
            **kwargs: Multimodal inputs; ``num_tiles`` must be present
                whenever pixel values are supplied.

        Returns:
            The output of ``self.language_model``.

        Raises:
            ValueError: If the batch contains both prefill and decode
                tokens.
        """
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        cross_attention_states = None
        cross_attention_mask = None
        kv_range_for_decode = None

        # For 1) text-only prefill and decode, 2) image-present decode.
        if image_inputs is None:
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor
                != 0).reshape(-1, 1).to(input_ids.device)
            # Cross-attention is skipped when no sequence has encoder tokens.
            skip_cross_attention = attn_metadata.max_encoder_seq_len == 0

        # For image-present prefill.
        else:
            skip_cross_attention = False

            num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
            num_tokens_per_tile = calc_token_per_chunk(self.image_size)

            actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
                attn_metadata.encoder_seq_lens,
                num_tiles,
                num_tokens_per_tile,
            )

            cross_attention_states = self.get_cross_attention_states(
                image_inputs, attn_metadata, actual_encoder_seq_lens)

            full_text_row_masked_out_mask = \
                self.get_full_text_row_masked_out_mask(
                    attn_metadata, input_ids.device)

            cross_attention_mask, kv_range_for_decode = \
                self.get_cross_attention_mask(
                    input_ids, attn_metadata, num_tiles,
                    num_tokens_per_tile, cross_attention_states.dtype)

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Loading is delegated to ``AutoWeightsLoader``, with weight names
        remapped through ``hf_to_vllm_mapper``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The result of the loader's ``load_weights`` call.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module-name prefixes of the multimodal components.

        The language model, the connector and the vision tower are each
        identified by the attribute name under which they are registered.
        """
        module_prefixes = {
            "language_model": "language_model",
            "connector": "multi_modal_projector",
            "tower_model": "vision_model",
        }
        return MultiModelKeys.from_string_field(**module_prefixes)

def skip_attention_mask(sparse_mask: list[list[list[int]]]) -> bool:
    """Tell whether building a cross-attention mask can be skipped.

    Args:
        sparse_mask: One list per sequence, holding a ``[start, end]`` pair
            for each image in that sequence (empty for text-only ones).

    Returns:
        ``True`` when every sequence is either text-only or has exactly one
        image whose range is ``[0, -1]``; ``False`` otherwise.
    """
    for mask in sparse_mask:
        # Skip text-only samples.
        if len(mask) == 0:
            continue
        # If the sample contains more than 1 images,
        # we can't skip mask.
        if len(mask) != 1:
            return False
        # If the sample contains only 1 image,
        # but the image is not the leading one,
        # we can't skip mask.
        if mask[0][0] != 0 or mask[0][1] != -1:
            return False
    return True


def convert_sparse_cross_attention_mask_to_dense(
    sparse_mask: list[list[list[int]]],
    num_tiles: list[list[int]],
    lengths: list[int],
) -> tuple[np.ndarray, list[tuple[int, int]]]:
    """Expand per-image token ranges into a dense token-by-tile mask.

    Args:
        sparse_mask: One list per sequence, holding a ``[start, end]`` token
            range for each image; ``end == -1`` means "to the end of the
            sequence". Text-only sequences have an empty list.
        num_tiles: Tile counts per image, with one inner list for each
            sequence that has images (text-only sequences have no entry).
        lengths: Token length of every sequence, aligned with
            ``sparse_mask``.

    Returns:
        A pair ``(dense_mask, tile_range_for_decode)``. ``dense_mask`` is an
        int64 array of shape ``(sum(lengths), total_tiles)`` holding 1 where
        a token attends to a tile. ``tile_range_for_decode`` has one
        ``(start, end)`` tile range per sequence with images, covering the
        images whose range reaches the end of that sequence.

    Raises:
        AssertionError: If a sequence with images has no image reaching its
            end, or ``num_tiles`` has unused entries.
    """
    total_length = sum(lengths)
    total_tiles = sum(sum(tiles) for tiles in num_tiles)
    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
    # A list of ranges, range[i] = (start, end) means that the i-th sequence
    # with images will use tiles[start:end] for cross-attention decoding.
    tile_range_for_decode = []

    seq_start = 0
    tile_start = 0

    # sparse_mask has an [] entry for each sequence that does not have images,
    # but num_tiles does not have these entries...
    num_tiles_idx = 0
    for masks, length in zip(sparse_mask, lengths):
        if len(masks) == 0:
            # Text only: nothing to mark, but its tokens still occupy rows
            # of the mask, so the row offset must move past them.
            seq_start += length
            continue

        tiles = num_tiles[num_tiles_idx]
        num_tiles_idx += 1
        # ts: first tile usable for decoding; td: number of such tiles.
        ts, td = -1, 0
        for mask, tile in zip(masks, tiles):
            # NOTE(review): a malformed entry is skipped without advancing
            # tile_start -- confirm such entries cannot occur in practice.
            if len(mask) != 2:
                continue
            start, end = mask
            end = min(end, length)
            if end == -1:
                end = length
            if end == length:
                # This image stays visible until the end of the sequence.
                if ts == -1:
                    ts = tile_start
                td += tile
            dense_mask[seq_start + start:seq_start + end,
                       tile_start:tile_start + tile] = 1
            tile_start += tile
        assert ts != -1
        assert td != 0
        tile_range_for_decode.append((ts, ts + td))
        seq_start += length
    assert num_tiles_idx == len(num_tiles)

    return dense_mask, tile_range_for_decode


def convert_dense_cross_attention_mask_to_tensor(
    cross_attention_token_mask: np.ndarray,
    num_tokens_per_tile: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a 0/1 token-by-tile mask into an additive attention mask.

    Args:
        cross_attention_token_mask: Array with 1 where a token attends to a
            tile and 0 elsewhere.
        num_tokens_per_tile: How many times each tile column is repeated.
        device: Device of the resulting tensor.
        dtype: Floating-point dtype of the resulting tensor.

    Returns:
        Tensor holding 0 at allowed positions and ``torch.finfo(dtype).min``
        at blocked ones; rows that are blocked everywhere are all zeros.
    """
    ninf = torch.finfo(dtype).min
    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)

    # Invert so that allowed positions become 0 and blocked ones `ninf`.
    mask = 1.0 - mask
    mask = mask.masked_fill(mask.to(torch.bool), ninf)

    # Zero out rows in which every position is blocked.
    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
    mask *= full_text_mask
    # (num_prompt_tokens, num_encoder_tokens)
    return mask