"vscode:/vscode.git/clone" did not exist on "2ac74d098ef7b8748db0cdaa255eeceb5cdd5366"
mllama.py 61.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
16
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
17
18
                    TypedDict, Union)

19
import numpy as np
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from torch import nn
from transformers.modeling_outputs import (BaseModelOutput,
                                           CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
    get_optimal_tiled_canvas)
30
31
from transformers.models.mllama.processing_mllama import (
    get_cross_attention_token_mask)
32
33
34

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
35
from vllm.attention.ops.paged_attn import PagedAttention
36
from vllm.attention.selector import _Backend
37
from vllm.config import VllmConfig
38
from vllm.distributed import get_tensor_model_parallel_world_size
39
40
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
                         InputContext, TokenInputs, token_inputs)
41
42
43
44
45
46
47
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
48
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
49
50
51
52
53
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
54
from vllm.sequence import SequenceData
55
from vllm.utils import is_list_of
56
57
58
59

from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP
60
from .utils import maybe_prefix
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = "<|image|>"


class MllamaImagePixelInputs(TypedDict):
    """Typed dictionary describing pixel-value image inputs for Mllama."""

    # Discriminator tag for this kind of input.
    type: Literal["pixel_values"]

    # Shape:
    # (batch_size, max_num_image, max_num_chunk, num_channel, height, width)
    data: torch.Tensor

    # Shape: `(batch_size, max_num_image)`
    aspect_ratio_ids: torch.Tensor

    # Shape: `(batch_size, max_num_image, max_num_tiles)`
    aspect_ratio_mask: torch.Tensor


# TODO: support LlamaImageEmbeddingInputs


81
82
83
84
85
86
87
88
89
90
def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
    """Count the image tokens in the last run of consecutive image tokens.

    The prompt is scanned from the end. Trailing non-image tokens are
    skipped; once an image token has been seen, counting stops at the first
    non-image token.

    Args:
        prompt_token_ids: Token ids of the prompt.

    Returns:
        Length of the last contiguous run of ``MLLAMA_IMAGE_TOKEN_ID``, or 0
        if the prompt contains no image token.
    """
    run_length = 0
    for token in reversed(prompt_token_ids):
        if token == MLLAMA_IMAGE_TOKEN_ID:
            run_length += 1
            continue
        # A non-image token ends the run only after the run has started.
        if run_length > 0:
            break
    return run_length


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def input_processor_for_mllama(
    ctx: InputContext,
    inputs: EncoderDecoderInputs,
) -> EncoderDecoderInputs:
    """Move the prompt to the decoder and build a placeholder encoder prompt.

    The incoming encoder inputs (prompt tokens plus optional image data)
    become the decoder inputs. When images are present, the encoder prompt is
    replaced by ``MLLAMA_IMAGE_TOKEN_ID`` repeated once per vision token of
    the images that belong to the last group of consecutive image tokens.

    Args:
        ctx: Input context; ``ctx.model_config.hf_config.vision_config`` is
            read for ``image_size``, ``max_num_tiles`` and ``patch_size``.
        inputs: Encoder/decoder inputs whose ``"encoder"`` entry carries the
            prompt and, optionally, ``multi_modal_data["image"]``.

    Returns:
        New encoder/decoder inputs. For text-only input the encoder prompt
        is empty and the decoder receives the original encoder inputs.

    Raises:
        ValueError: If the number of image tokens in the prompt differs from
            the number of images supplied.
    """
    # Example input to processor:
    # {
    #     'encoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    #     'decoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000],
    #     },
    # }

    # move encoder prompt to decoder
    dec_inputs = TokenInputs(**inputs["encoder"])

    multi_modal_data = dec_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        # text-only
        return EncoderDecoderInputs(
            encoder=token_inputs([]),
            decoder=dec_inputs,
        )

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        # Normalize a single image to a one-element list.
        image_data = [image_data]

    assert is_list_of(image_data, Image.Image)

    num_image_tokens = dec_inputs['prompt_token_ids'].count(
        MLLAMA_IMAGE_TOKEN_ID)
    if num_image_tokens != len(image_data):
        raise ValueError(
            f"The number of image tokens ({num_image_tokens}) must be"
            f" the same as the number of images ({len(image_data)})")

    # Since only the last group of consecutive images
    # are attended by the decoded tokens, we only need to
    # get the number of tiles for those images.
    num_decode_images = _get_num_image_in_last_group(
        dec_inputs["prompt_token_ids"])

    hf_config = ctx.model_config.hf_config
    vision_config = hf_config.vision_config

    # Loop-invariant: every image is tiled with the same tile size.
    tile_size = vision_config.image_size

    num_tiles = 0
    # Walk the images from the end so the last group is visited first.
    for image in reversed(image_data):
        width, height = image.size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height=height,
            image_width=width,
            max_image_tiles=vision_config.max_num_tiles,
            tile_size=tile_size,
        )
        num_tiles_height = canvas_height // tile_size
        num_tiles_width = canvas_width // tile_size
        num_tiles += num_tiles_height * num_tiles_width
        num_decode_images -= 1
        if num_decode_images == 0:
            break

    # Set encoder prompt length based on the number of tiles.
    # This tells the block manager to allocate correct number
    # of slots for encoder tokens.
    # The patch size comes from the vision config; 14 (the value that used to
    # be hard-coded here) is the fallback when the attribute is absent.
    patch_size = getattr(vision_config, "patch_size", 14)
    assert vision_config.image_size % patch_size == 0, \
        f"chunk size should be multiple of {patch_size}"
    # One token per patch plus one extra token per tile.
    token_per_chunk = (vision_config.image_size // patch_size)**2 + 1
    num_tokens = num_tiles * token_per_chunk

    # Example output from processor:
    # {
    #     'encoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128256, 128256, ..., 128256],
    #         'prompt': '<|image|><|image|>...<|image|>',
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    #     'decoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    # }
    return EncoderDecoderInputs(
        encoder=token_inputs(
            prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
            prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
            multi_modal_data=multi_modal_data,
        ),
        decoder=dec_inputs,
    )
190
191
192
193
194
195
196
197
198
199
200
201


def get_max_mllama_image_tokens(ctx: InputContext) -> int:
    """Return the largest number of vision tokens a single image can use.

    Each tile contributes ``(image_size // patch_size) ** 2 + 1`` tokens and
    an image may span up to ``max_num_tiles`` tiles.

    Args:
        ctx: Input context; ``ctx.model_config.hf_config.vision_config`` is
            read for ``image_size``, ``max_num_tiles`` and ``patch_size``.

    Returns:
        ``max_num_tiles`` multiplied by the per-tile token count.
    """
    vision_config = ctx.model_config.hf_config.vision_config
    # The patch size comes from the vision config; 14 (the value that used to
    # be hard-coded here) is the fallback when the attribute is absent.
    patch_size = getattr(vision_config, "patch_size", 14)
    token_per_chunk = (vision_config.image_size // patch_size)**2 + 1
    return vision_config.max_num_tiles * token_per_chunk


def dummy_decoder_seq_data(seq_len: int, num_images: int):
    """Build dummy decoder sequence data of length ``seq_len``.

    The sequence is ``num_images`` image tokens followed by token id 0
    repeated ``seq_len - num_images`` times.

    Args:
        seq_len: Total number of tokens in the dummy sequence.
        num_images: Number of leading image tokens.

    Returns:
        The result of ``SequenceData.from_prompt_token_counts`` for the
        layout above.
    """
    # <|image|> * num_images + 0 * (seq_len - num_images)
    assert seq_len >= num_images, \
        "seq_len should be greater than or equal to num_images"

    return SequenceData.from_prompt_token_counts(
        (MLLAMA_IMAGE_TOKEN_ID, num_images),
        (0, seq_len - num_images),
    )
207
208
209
210


def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
    """Build dummy encoder sequence data sized for ``num_images`` images.

    Args:
        ctx: Input context forwarded to ``get_max_mllama_image_tokens``.
        num_images: Number of images to reserve encoder tokens for.

    Returns:
        Sequence data made only of image tokens, with the per-image maximum
        token count multiplied by ``num_images``.
    """
    num_tokens = get_max_mllama_image_tokens(ctx) * num_images

    return SequenceData.from_prompt_token_counts(
        (MLLAMA_IMAGE_TOKEN_ID, num_tokens))
214
215
216
217
218
219
220
221
222
223
224


def dummy_image(num_images: int, ):
    """Create dummy multi-modal image data.

    Args:
        num_images: Number of images requested.

    Returns:
        A dict with key ``"image"`` holding a single black 1024x1024 RGB
        image when ``num_images`` is 1, otherwise a list that repeats the
        same image object ``num_images`` times.
    """
    side = 1024
    black = Image.new("RGB", (side, side), color=0)
    if num_images == 1:
        return {"image": black}
    return {"image": [black] * num_images}


def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Create dummy decoder data.

    Args:
        ctx: Input context (not used by this function).
        seq_len: Length of the dummy decoder sequence.
        mm_counts: Per-modality item counts; only ``"image"`` is read.

    Returns:
        ``DummyData`` wrapping the dummy decoder sequence.
    """
    num_images = mm_counts["image"]
    return DummyData(dummy_decoder_seq_data(seq_len, num_images))
226
227
228
229
230


def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Create dummy encoder data together with dummy images.

    Args:
        ctx: Input context used to size the encoder sequence.
        seq_len: Not used by this function.
        mm_counts: Per-modality item counts; only ``"image"`` is read.

    Returns:
        ``DummyData`` holding the dummy encoder sequence and the dummy image
        payload.
    """
    num_images = mm_counts["image"]
    return DummyData(dummy_encoder_seq_data(ctx, num_images),
                     dummy_image(num_images))
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374


def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a per-tile validity mask into an additive 4D attention mask.

    Args:
        aspect_ratio_mask: Tensor of shape ``(batch_size, max_num_tiles)``;
            1 marks a real tile and 0 a padding tile.
        num_patches: Number of real (non-padding) positions per tile.
        target_length: Per-tile length after padding.
        dtype: Floating-point dtype of the returned mask.

    Returns:
        Tensor of shape ``(batch_size, 1, max_num_tiles * target_length,
        max_num_tiles * target_length)``. An entry is ``torch.finfo(dtype).min``
        when both its row and its column position are masked (padding tile or
        padding patch) and zero otherwise.
    """
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
                                            1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches.
    # The guard matters: with no padding the slice would be `-0:`, which
    # selects every position and would zero the whole mask.
    pad_patches = target_length - num_patches
    if pad_patches > 0:
        attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
    attention_mask = attention_mask.reshape(batch_size,
                                            max_num_tiles * target_length, 1)
    # The outer product is 1 only where both positions are masked.
    attention_mask = attention_mask @ attention_mask.transpose(
        -1, -2) * torch.finfo(dtype).min
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask


class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.
    Column parallel over unfolded input.
    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel.
        stride (default 1): Stride for convolution.
        bias (default False): Use bias in Conv2d.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Normalize an integer kernel size to a (height, width) pair.
        kernel_hw = ((kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size)
        # Each unfolded patch holds in_channels * kh * kw values.
        patch_dim = in_channels * kernel_hw[0] * kernel_hw[1]
        self._unfold = torch.nn.Unfold(kernel_size=kernel_hw, stride=stride)
        self._linear = ColumnParallelLinear(
            patch_dim,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract patches from ``x`` and project each one linearly."""
        # Unfold output is (bsz, patch_dim, num_tokens); move tokens to dim 1.
        patches = self._unfold(x).permute(0, 2, 1)
        projected, _ = self._linear(patches)
        return projected


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a learned per-tile embedding selected by the aspect ratio id.

    When ``is_gated`` is true the embedding is scaled by ``tanh(gate)``,
    where ``gate`` is a learned scalar initialized to zero.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # One row per aspect ratio id; a row holds max_num_tiles vectors of
        # hidden_size values.
        num_ids = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_ids, row_width)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` plus the (optionally gated) tile embedding."""
        # The singleton third dimension lets the tile embedding broadcast
        # over that axis of hidden_state.
        tile_embeds = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            tile_embeds = tile_embeds * self.gate.tanh()

        return hidden_state + tile_embeds


class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds a gated mix of a patch position embedding and a tile position
    embedding to the hidden state.

    A single learned scalar ``gate`` (initialized to zero) weights the two:
    the patch position embedding is scaled by ``1 - tanh(gate)`` and the
    aspect-ratio-dependent tile embedding by ``tanh(gate)``.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # One position per patch plus one extra position.
        patches_per_side = config.image_size // config.patch_size
        self.num_patches = patches_per_side**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # position embedding
        initial_positions = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * initial_positions)

        # tile position embedding
        tile_row_width = (self.max_num_tiles * self.num_patches *
                          self.hidden_size)
        self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                           tile_row_width)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` with both gated embeddings added."""
        gate = self.gate.tanh()

        # position embeddings, broadcast over the first two dimensions
        patch_term = ((1 - gate) * self.embedding).view(
            1, 1, self.num_patches, self.hidden_size)
        hidden_state = hidden_state + patch_term

        # precomputed tile position embeddings
        num_rows = hidden_state.shape[0]
        tile_term = self.tile_embedding(aspect_ratio_ids).reshape(
            num_rows, self.max_num_tiles, self.num_patches, self.hidden_size)
        hidden_state = hidden_state + gate * tile_term

        return hidden_state


# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Self-attention for the vision encoder built on
    ``F.scaled_dot_product_attention``.

    Attention heads are divided by the tensor-parallel world size; each rank
    projects and attends over ``num_local_heads`` heads.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """
        Args:
            config: Vision config; ``hidden_size`` and ``attention_heads``
                are read.
            quant_config: Quantization config forwarded to the linear layers.
            prefix: Name prefix for the ``qkv_proj`` / ``o_proj`` sub-layers.
        """
        super().__init__()

        model_parallel_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // model_parallel_size
        # Query and key/value use the same local head count here.
        self.q_size = self.num_local_heads * self.head_dim
        self.kv_size = self.num_local_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_state``.

        Args:
            hidden_state: Input tensor; indexed below as
                ``(batch, seq, features)``.
            attention_mask: Optional mask passed to SDPA as ``attn_mask``.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(q.shape[0], q.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(k.shape[0], k.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(v.shape[0], v.shape[1], self.num_local_heads,
                   self.head_dim).transpose(1, 2)

        # TODO: remove padding in image encoder
        attn_output = F.scaled_dot_product_attention(q,
                                                     k,
                                                     v,
                                                     attn_mask=attention_mask,
                                                     dropout_p=0.0)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(attn_output.shape[0],
                                          attn_output.shape[1], -1)
        output, _ = self.o_proj(attn_output)
        return output


class MllamaVisionEncoderLayer(nn.Module):
    """One pre-norm transformer layer of the vision encoder.

    The layer applies self-attention and an MLP, each preceded by a
    ``LayerNorm`` and followed by a residual connection. When ``is_gated`` is
    true, each branch output is scaled by ``tanh`` of a learned scalar gate
    before it is added to the residual.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        is_gated: bool = False,
    ) -> None:
        """
        Args:
            config: Vision config; ``hidden_size``, ``attention_heads``,
                ``intermediate_size`` and ``norm_eps`` are read.
            quant_config: Quantization config forwarded to the sub-layers.
            prefix: Name prefix for the ``self_attn`` / ``mlp`` sub-layers.
            is_gated: Whether to create and apply the residual gates.
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # there used to be an if else here, no code path
        if is_gated:
            # Both gates start at pi / 4.
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Run the attention and feed-forward sub-blocks.

        Args:
            hidden_state: Layer input.
            attention_mask: Optional mask forwarded to the attention module.

        Returns:
            The layer output tensor.
        """
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state,
                                      attention_mask=attention_mask)
        # Ungated layers use a plain residual connection (gate of 1).
        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
        hidden_state = residual + gate_attn * hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
        hidden_state = residual + gate_ffn * hidden_state

        return hidden_state


class MllamaVisionEncoder(nn.Module):
    """Stack of ``MllamaVisionEncoderLayer`` modules that can also collect
    intermediate hidden states."""

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        num_layers: int = 32,
        is_gated: bool = False,
        output_hidden_states=None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Vision config forwarded to every layer.
            quant_config: Quantization config forwarded to every layer.
            num_layers: Number of encoder layers to create.
            is_gated: Whether the layers use gated residuals.
            output_hidden_states: Layer indices whose hidden states should be
                collected by ``forward``; ``None`` collects nothing.
            prefix: Name prefix; layers are named ``{prefix}.layers.{i}``.
        """
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            MllamaVisionEncoderLayer(config,
                                     quant_config=quant_config,
                                     is_gated=is_gated,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_layers)
        ])
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Run all layers in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Optional mask forwarded to every layer.

        Returns:
            A tuple ``(hidden_states, encoder_states)``: the output of the
            last layer and a tuple of the collected intermediate states.
        """
        encoder_states = ()

        for i, encoder_layer in enumerate(self.layers):
            # The state recorded for index i is the INPUT to layer i.
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        # If the last layer index is requested, the final output is appended
        # as well (in addition to that layer's input recorded in the loop).
        # NOTE(review): confirm that recording both is intended.
        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        return hidden_states, encoder_states


class MllamaVisionModel(nn.Module):
    """Vision tower: patch embedding, tile/position embeddings, a local
    transformer and a gated global transformer.

    The output concatenates the final hidden state with the intermediate
    hidden states collected from the local transformer along the last
    dimension.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Vision config describing image/patch sizes, hidden size,
                layer counts and the intermediate layer indices to export.
            quant_config: Quantization config forwarded to both encoders.
            prefix: Name prefix for the ``transformer`` and
                ``global_transformer`` sub-modules.
        """
        super().__init__()

        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile plus one for the class token.
        self.num_patches = (self.image_size // self.patch_size)**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale *
                                            torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)

        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.transformer = MllamaVisionEncoder(
            config,
            quant_config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
            prefix=f"{prefix}.transformer",
        )
        self.global_transformer = MllamaVisionEncoder(
            config,
            quant_config,
            config.num_global_layers,
            is_gated=True,
            prefix=f"{prefix}.global_transformer",
        )

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dimension 1.

        Args:
            hidden_state: Tensor of shape ``(batch, seq, hidden)``.

        Returns:
            Tensor of shape ``(batch, seq + 1, hidden)``.
        """
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1,
                                                      hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: Tensor of shape ``(batch_size, num_concurrent_media,
                num_tiles, num_channels, height, width)``.
            aspect_ratio_ids: Aspect ratio ids; reshaped to
                ``(batch_size * num_concurrent_media, -1)``.
            aspect_ratio_mask: Tile validity mask; reshaped to
                ``(batch_size * num_concurrent_media, -1)``.

        Returns:
            Tensor of shape ``(batch_size, num_concurrent_media, num_tiles,
            num_patches, D)`` where ``D`` is the hidden size plus the stacked
            intermediate hidden state features.
        """
        batch_size, num_concurrent_media, num_tiles, num_channels, \
            height, width = pixel_values.shape

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels,
            height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1)

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype))
        hidden_state = patch_embeds
        # Gather the tensor-parallel shards of the patch embedding.
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state,
                                                       aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        # (rounds the per-tile length up to a multiple of 8)
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0, 0, 0, num_padding_patches
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # `None` keeps everything when no padding was added; a plain `-0`
        # would produce an empty slice.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1)
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
                                         dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states,
                                                 dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches), dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask)[0]
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        # Drop the padding positions again.
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media, num_tiles,
            num_patches + num_padding_patches, -1)
        intermediate_hidden_states = intermediate_hidden_states[:, :, :
                                                                slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1)
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
                                 dim=-1)
        return hidden_state


class MllamaTextRMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-feature scale.

    The statistics are computed in float32 and the normalized values are cast
    back to the input dtype before the weight is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        MllamaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last dimension."""
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # Mean of squares over the last dimension (no mean subtraction).
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normed = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Cross-attention variant: queries are projected from the decoder hidden
    states, keys and values from ``cross_attention_states``. Head counts are
    divided by the tensor-parallel world size, so each rank handles
    ``num_local_heads`` query heads and ``num_local_key_value_heads`` KV heads.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, the q/k norms and the attention op.

        Args:
            config: Text config. Although the default is ``None``, its
                attributes are read unconditionally, so it must be provided.
            layer_idx: Index of this layer; only stored on the module.
            quant_config: Quantization config forwarded to the linear layers.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = \
            self.num_key_value_heads // self.model_parallel_size
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        # TODO: change to Q/KV separate linear after #7448 is merged
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_DECODER,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run cross-attention and the output projection.

        Args:
            hidden_states: Decoder-side activations; source of the queries.
            attention_mask: Optional explicit cross-attention mask. When
                given, the masked SDPA path is used instead of ``self.attn``.
            kv_range_for_decode: ``(start, end)`` row ranges of k/v to write
                to the KV cache; only used on the masked path.
            cross_attention_states: Encoder-side activations; source of the
                keys/values. When ``None``, ``k`` and ``v`` are passed on as
                ``None``.
            kv_cache: KV cache tensor for this layer.
            attn_metadata: Attention metadata forwarded to the attention op.

        Returns:
            The output of ``o_proj``.
        """
        # The fused projection is applied to both inputs; only the q slice
        # of the decoder projection is kept here.
        qkv_dec, _ = self.qkv_proj(hidden_states)
        q, _, _ = qkv_dec.split(
            [self.q_local_size, self.kv_local_size, self.kv_local_size],
            dim=-1)
        if cross_attention_states is None:
            # NOTE(review): presumably the attention backend then reads K/V
            # from the cache (decode / text-only) -- confirm with the backend.
            k = None
            v = None
        else:
            # ... and only the k/v slices of the encoder projection.
            qkv_enc, _ = self.qkv_proj(cross_attention_states)
            _, k, v = qkv_enc.split(
                [self.q_local_size, self.kv_local_size, self.kv_local_size],
                dim=-1)
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        if attention_mask is not None:
            output = self._attention_with_mask(q, k, v, kv_cache,
                                               attention_mask,
                                               kv_range_for_decode,
                                               attn_metadata)
        else:
            output = self.attn(
                q.view(-1, self.num_local_heads * self.head_dim), k, v,
                kv_cache, attn_metadata)
        out, _ = self.o_proj(output)
        return out

    def _attention_with_mask(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kv_cache: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_range_for_decode: List[Tuple[int, int]],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attention with an explicit mask via ``F.scaled_dot_product_attention``.

        The k/v rows selected by ``kv_range_for_decode`` are first written to
        the KV cache (unless ``kv_cache`` has at most one dimension), then
        attention over the full ``k``/``v`` is computed with the given mask.

        Args:
            q: Queries, ``(q_len, num_local_heads, head_dim)``.
            k: Keys, ``(kv_len, num_local_key_value_heads, head_dim)``.
            v: Values, same layout as ``k``.
            kv_cache: KV cache tensor for this layer.
            attention_mask: Mask viewable as ``(1, 1, q_len, kv_len)``.
            kv_range_for_decode: ``(start, end)`` row ranges of k/v to cache.
            attn_metadata: Provides ``cross_slot_mapping`` for the cache write.

        Returns:
            Tensor of shape ``(q_len, num_local_heads * head_dim)``.

        Raises:
            ValueError: If the attention backend is not one of FLASH_ATTN,
                FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.
        """
        # Skip writing kv-cache for the initial profiling run.
        if len(kv_cache.shape) > 1:
            # Passed twice as the trailing arguments of the cache-write ops
            # (presumably the k/v scales -- confirm against the op signature).
            i = torch.ones(1, dtype=torch.float32)
            if self.attn.backend in (_Backend.FLASH_ATTN,
                                     _Backend.FLASH_ATTN_VLLM_V1):
                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    cached_k,
                    cached_v,
                    kv_cache[0],
                    kv_cache[1],
                    attn_metadata.
                    cross_slot_mapping,  # type: ignore[union-attr]
                    "auto",
                    i,
                    i,
                )
            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                key_cache, value_cache = PagedAttention.split_kv_cache(
                    kv_cache, self.num_local_key_value_heads, self.head_dim)
                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                PagedAttention.write_to_paged_cache(
                    cached_k, cached_v, key_cache, value_cache,
                    attn_metadata.cross_slot_mapping, "auto", i, i)
            else:
                raise ValueError(
                    f"Unsupported Attention backend {self.attn.backend} "
                    "enum found. Expected the Attention backend to be "
                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")

        # We have to call torch.sdpa for prefill when using a
        # custom cross-attention mask. Because the mask is not a
        # standard causal mask, neither a block diagonal mask which
        # can be optimized by xformers.BlockDiagonalMask.
        # The mask is specially calculated for supporting multi
        # images and interleaved images.
        q_len = q.shape[0]
        kv_len = k.shape[0]
        # q: (kv_heads, groups, q_len, head_dim)
        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                   self.num_key_value_groups, q_len,
                                   self.head_dim).contiguous()
        # k/v: each KV head is repeated across its query group so the
        # layout matches q: (kv_heads, groups, kv_len, head_dim).
        k = k.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len,
                                              self.head_dim).contiguous()
        v = v.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len,
                                              self.head_dim).contiguous()
        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
        output = F.scaled_dot_product_attention(q,
                                                k,
                                                v,
                                                attn_mask=attention_mask,
                                                is_causal=False)
        # Back to token-major layout with all local heads flattened.
        output = output.permute(2, 0, 1, 3).reshape(
            q_len, self.num_local_heads * self.head_dim)
        return output

897
898
899
900
901

class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention
    and feedforward."""

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        """Build the cross-attention, MLP, norms and the two scalar gates.

        Args:
            config: Text config supplying sizes, activation and norm epsilon.
            layer_idx: Index of this layer in the decoder stack.
            quant_config: Quantization config forwarded to the sub-modules.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        self.layer_idx = layer_idx
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
            prefix=f"{prefix}.cross_attn",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        # Gates are initialized to zero, so tanh(gate) == 0 until weights
        # are loaded or trained.
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor],
        cross_attention_mask: Optional[torch.Tensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply gated cross-attention followed by a gated MLP.

        Both sub-layer outputs are multiplied by
        ``full_text_row_masked_out_mask`` and by ``tanh`` of the matching
        gate parameter before being added to the residual.

        Args:
            hidden_states: Decoder activations entering the block.
            cross_attention_states: Encoder activations, or ``None``.
            cross_attention_mask: Explicit cross-attention mask, or ``None``.
            kv_range_for_decode: k/v row ranges to write to the KV cache.
            full_text_row_masked_out_mask: Per-row multiplier applied to both
                sub-layer outputs.
            kv_cache: KV cache tensor for this layer.
            attn_metadata: Attention metadata forwarded to the attention op.

        Returns:
            The block output, same layout as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_attn_gate.tanh(
        ) * hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.cross_attn_mlp_gate.tanh(
        ) * hidden_states
        return hidden_states


class MllamaTextModel(nn.Module):
    """Decoder stack mixing self-attention and cross-attention layers.

    Layers whose index is listed in ``config.cross_attention_layers`` are
    ``MllamaCrossAttentionDecoderLayer``; all others are ``LlamaDecoderLayer``.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Engine config; the text config, cache config and
                quantization config are read from it.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # NOTE(review): the embedding table has 8 rows more than
        # config.vocab_size -- presumably extra special tokens; confirm.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config,
                        layer_idx,
                        quant_config=quant_config,
                        prefix=f"{prefix}.layers.{layer_idx}",
                    ))
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(
                        config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        prefix=f"{prefix}.layers.{layer_idx}",
                    ))

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, passed to the self-attention layers.
            cross_attention_states: Encoder activations for cross-attention.
            cross_attention_mask: Explicit cross-attention mask, or ``None``.
            kv_range_for_decode: k/v row ranges to write to the KV cache.
            full_text_row_masked_out_mask: Per-row multiplier forwarded to
                the cross-attention layers.
            kv_caches: One KV cache per layer, indexed by layer position.
            attn_metadata: Attention metadata forwarded to every layer.
            skip_cross_attention: When true, cross-attention layers are
                bypassed entirely (hidden states pass through unchanged).

        Returns:
            Hidden states after the final norm.

        Raises:
            ValueError: If a layer is neither a cross-attention layer nor a
                ``LlamaDecoderLayer``.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for idx, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if not skip_cross_attention:
                    hidden_states = decoder_layer(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
                        kv_range_for_decode=kv_range_for_decode,
                        full_text_row_masked_out_mask=
                        full_text_row_masked_out_mask,
                        kv_cache=kv_caches[idx],
                        attn_metadata=attn_metadata,
                    )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                # The Llama layer returns (hidden_states, residual); with
                # residual=None on input the two are summed here so every
                # layer starts from a plain hidden-state tensor.
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    kv_cache=kv_caches[idx],
                    attn_metadata=attn_metadata,
                    residual=None,
                )
                hidden_states = hidden_states + residual
            else:
                raise ValueError(
                    f"Unknown decoder layer type {type(decoder_layer)}")
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MllamaForCausalLM(nn.Module):
    """Text decoder (``MllamaTextModel``) bundled with its LM head.

    ``forward`` returns hidden states only; the ``lm_head`` is held here for
    use by the caller.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text model and the LM head.

        Args:
            vllm_config: Engine config; the text config and quantization
                config are read from it.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config.text_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(vllm_config=vllm_config,
                                     prefix=f"{prefix}.model")
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=f"{prefix}.lm_head",
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[List[Tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Forward all arguments to the text model and return its output.

        The LM head is not applied here.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Mllama vision-language model: vision encoder, projector and decoder.

    Images are encoded by ``vision_model``, projected to the text hidden size
    by ``multi_modal_projector`` and consumed by the cross-attention layers
    of ``language_model``.
    """
    # Fused parameter name -> names of the checkpoint shards packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision model, language model, projector and sampler.

        Args:
            vllm_config: Engine config; the HF model config and quantization
                config are read from it.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_model"))
        self.language_model = MllamaForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.multi_modal_projector = ColumnParallelLinear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            gather_output=True,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                config.text_config.vocab_size)
        self.sampler = get_sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states using the language-model head."""
        logits = self.logits_processor(self.language_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Collect image kwargs into one zero-padded batch.

        Reads ``pixel_values``, ``image_embeds``, ``aspect_ratio_ids`` and
        ``aspect_ratio_mask`` from ``kwargs``. Images are padded with zeros
        to the largest image count and tile count found in the batch.

        Returns:
            ``None`` when neither pixel values nor image embeds are given,
            otherwise an ``MllamaImagePixelInputs``.

        Raises:
            ValueError: If both pixel values and image embeds are given, or
                if no sample contains an image.
            NotImplementedError: If only image embeds are given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalKwargs.batch, so pixel_values here can be:
        #   - List[List[torch.Tensor]]:
        #       with shape (num_tiles, 3, image_res, image_res)
        #   - List[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
                                         List[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
                                          List[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            assert aspect_ratio_ids is not None
            assert aspect_ratio_mask is not None
            # Each batch entry is indexed with [0] before its images are
            # read (see the layouts listed above).
            max_num_images = max(len(x[0]) for x in pixel_values)
            if max_num_images == 0:
                raise ValueError("No images provided.")
            max_num_tiles = max(
                max(len(x) for x in y[0]) for y in pixel_values)
            device = next(self.multi_modal_projector.parameters()).device
            bsz = len(pixel_values)
            out_images = torch.zeros(
                bsz,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
                device=device,
            )
            # Padding slots keep aspect-ratio id 1 and an all-zero tile mask.
            out_ar_ids = torch.ones(bsz,
                                    max_num_images,
                                    dtype=torch.int64,
                                    device=device)
            out_ar_mask = torch.zeros(bsz,
                                      max_num_images,
                                      max_num_tiles,
                                      dtype=torch.int64,
                                      device=device)
            for b, sample in enumerate(pixel_values):
                for i, img in enumerate(sample[0]):
                    # Only the first img.shape[0] tile slots are filled.
                    out_images[b, i, :img.shape[0]] = img
                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=out_images,
                aspect_ratio_ids=out_ar_ids,
                aspect_ratio_mask=out_ar_mask,
            )

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata,
                            actual_encoder_seq_lens: List[int]):
        """Pack the valid encoder tokens of every sample into one 2-D tensor.

        For each batch entry, the first ``seq_len`` rows are copied into a
        tensor of shape ``(sum(actual_encoder_seq_lens), hidden)``, dropping
        the padding rows. ``attn_metadata`` is accepted but not used.
        """
        cross_attention_states_flat = torch.zeros(
            sum(actual_encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
                                                  cross_attention_states):
            end_pos = start_pos + seq_len
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat
        return cross_attention_states

    def get_cross_attention_states(
        self,
        image_inputs: MllamaImagePixelInputs,
        attn_metadata: AttentionMetadata,
        actual_encoder_seq_lens: List[int],
    ) -> torch.Tensor:
        """Encode the images and return flattened cross-attention states.

        Runs the vision model and the projector, merges every axis between
        the batch axis and the feature axis into one token axis, and drops
        padding tokens via ``flat_encoder_result``.
        """
        # NOTE: llama's reference implementation runs vision model on CPU
        pixel_values = image_inputs['data']
        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
        cross_attention_states = self.vision_model(pixel_values,
                                                   aspect_ratio_ids,
                                                   aspect_ratio_mask)
        cross_attention_states, _ = self.multi_modal_projector(
            cross_attention_states)

        # The projector output is 5-D; keep the batch and feature axes.
        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
        cross_attention_states = cross_attention_states.view(
            bsz, -1, image_token_dim)

        cross_attention_states = self.flat_encoder_result(
            cross_attention_states, attn_metadata, actual_encoder_seq_lens)

        return cross_attention_states

    def get_cross_attention_mask(
        self,
        input_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        num_tiles: List[List[int]],
        num_tokens_per_tile: int,
        dtype: torch.dtype,
    ) -> Tuple[Optional[torch.Tensor], Optional[List[Tuple[int, int]]]]:
        """Build the cross-attention mask and the KV ranges to cache.

        Args:
            input_ids: Flattened token ids of all sequences in the batch.
            attn_metadata: Provides ``seq_lens`` used to split ``input_ids``.
            num_tiles: Tile count of every image, per sample.
            num_tokens_per_tile: Number of encoder tokens per tile.
            dtype: Dtype of the returned mask tensor.

        Returns:
            ``(cross_attention_mask, kv_range_for_decode)``, or
            ``(None, None)`` when ``skip_attention_mask`` reports that no
            explicit mask is needed.
        """
        token_ids = input_ids.tolist()
        start = 0
        batch_token_ids = []
        for seq_len in attn_metadata.seq_lens:
            batch_token_ids.append(token_ids[start:start + seq_len])
            start += seq_len
        sparse_mask = [
            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
            for t in batch_token_ids
        ]

        # Skip generating cross-attention mask if all samples
        # are text-only or have only 1 leading image.
        if skip_attention_mask(sparse_mask):
            return None, None

        dense_mask, tile_range_for_decode = \
            convert_sparse_cross_attention_mask_to_dense(
                sparse_mask, num_tiles, attn_metadata.seq_lens)
        cross_attention_mask = \
            convert_dense_cross_attention_mask_to_tensor(
                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
        # Convert tile ranges into encoder-token ranges.
        kv_range_for_decode = [
            (t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile)
            for t in tile_range_for_decode
        ]

        return cross_attention_mask, kv_range_for_decode

    def get_full_text_row_masked_out_mask(
        self,
        attn_metadata: AttentionMetadata,
        device: torch.device,
    ) -> torch.Tensor:
        """Return a ``(num_prefill_tokens, 1)`` bool mask on ``device``.

        Rows belonging to sequences whose encoder length is 0 are ``False``;
        all other rows are ``True``.
        """
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
                                            attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            device)
        return full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the model for one prefill or decode step.

        Args:
            input_ids: Flattened token ids of the batch.
            positions: Token positions.
            kv_caches: One KV cache per decoder layer.
            attn_metadata: Attention metadata for this step.
            **kwargs: Multimodal inputs (``pixel_values``,
                ``aspect_ratio_ids``, ``aspect_ratio_mask``, ``num_tiles``).

        Returns:
            The output of ``language_model``.

        Raises:
            ValueError: If the batch mixes prefill and decode tokens.
        """
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        cross_attention_states = None
        cross_attention_mask = None
        kv_range_for_decode = None

        # For 1) text-only prefill and decode, 2) image-present decode.
        if image_inputs is None:
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0

        # For image-present prefill.
        else:
            skip_cross_attention = False

            # Get the actual number of encoder tokens for each sample.
            # Because attn_metadata.encoder_seq_lens only counts the last
            # group of images for each sample, which is used to cheat the
            # block manager to allocate blocks for those images only.
            # See input_processor_for_mllama() for more details.
            num_tiles_tensor = kwargs.pop("num_tiles")
            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
            # NOTE(review): 14 is presumably the vision patch size and the
            # +1 an extra per-tile token -- confirm against the vision config.
            num_tokens_per_tile = (self.image_size // 14)**2 + 1
            actual_encoder_seq_lens = [
                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
            ]
            for actual_len, last_group_len in zip(
                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
                assert actual_len >= last_group_len

            cross_attention_states = self.get_cross_attention_states(
                image_inputs, attn_metadata, actual_encoder_seq_lens)

            full_text_row_masked_out_mask = \
                self.get_full_text_row_masked_out_mask(
                    attn_metadata, input_ids.device)

            cross_attention_mask, kv_range_for_decode = \
                self.get_cross_attention_mask(
                    input_ids, attn_metadata, num_tiles,
                    num_tokens_per_tile, cross_attention_states.dtype)

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: Set[str] = set()
        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                # The patch embedding is stored under `_linear`, so the
                # checkpoint weight is flattened to 2-D.
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                updated_params.add(scale_name)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Checkpoint shard -> fused parameter plus shard id.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load into the same-named parameter.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                updated_params.add(name)
        return updated_params
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510


def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
    """Return True when no sample in the batch needs a cross-attention mask.

    Each entry of ``sparse_mask`` is one sample. The code indexes
    ``sample[0][0]`` and ``sample[0][1]``, so every sample is in practice a
    list of ``[start, end]`` pairs (one nesting level deeper than the
    annotation suggests).

    The mask can be skipped only if every sample is either empty or holds
    exactly one pair equal to ``[0, -1]``.
    """

    def _requires_mask(sample) -> bool:
        entry_count = len(sample)
        # A sample with no entries never forces a mask.
        if entry_count == 0:
            return False
        # Two or more entries in one sample: the mask is required.
        if entry_count != 1:
            return True
        # A single entry is skippable only when it starts at 0 and its end
        # is the -1 sentinel; `and` keeps the original short-circuiting.
        span = sample[0]
        return not (span[0] == 0 and span[1] == -1)

    return not any(_requires_mask(sample) for sample in sparse_mask)


def convert_sparse_cross_attention_mask_to_dense(
    sparse_mask: List[List[List[int]]],
    num_tiles: List[List[int]],
    lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Expand per-sample ``[start, end]`` spans into one dense 0/1 mask.

    Args:
        sparse_mask: For each sample, a list of ``[start, end]`` pairs (one
            per tile group). ``end == -1`` means "up to the end of the
            sample"; an ``end`` larger than the sample length is clamped.
            Entries whose length is not 2 are ignored.
        num_tiles: For each sample, the tile count of each entry in
            ``sparse_mask``, paired positionally with the spans.
        lengths: Length of each sample; rows of the dense mask are laid out
            sample after sample in this order.

    Returns:
        A tuple ``(dense_mask, tile_range_for_decode)`` where ``dense_mask``
        is an ``int64`` array of shape ``(sum(lengths), total_tiles)`` with 1
        where a row falls inside a span for that tile column, and
        ``tile_range_for_decode[i]`` is the half-open column range
        ``(first, first + count)`` of the tiles of sample ``i`` whose span
        reaches the end of the sample.

    Raises:
        AssertionError: If a sample has no span that reaches its end, or the
            spans that do reach it cover zero tiles.
    """
    total_length = sum(lengths)
    total_tiles = sum(sum(tiles) for tiles in num_tiles)
    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
    # tile_range_for_decode[i] = (start, end): the half-open range of tile
    # columns that sample i uses for cross-attention decoding.
    tile_range_for_decode: List[Tuple[int, int]] = []

    seq_start = 0
    tile_start = 0
    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
        # ts: first tile column whose span reaches the end of the sample
        # (-1 until one is found); td: number of such tile columns.
        ts, td = -1, 0
        for mask, tile in zip(masks, tiles):
            if len(mask) != 2:
                continue
            start, end = mask
            # Clamp to the sample, then resolve the -1 "to the end" sentinel.
            end = min(end, length)
            if end == -1:
                end = length
            if end == length:
                if ts == -1:
                    ts = tile_start
                td += tile
            dense_mask[seq_start + start:seq_start + end,
                       tile_start:tile_start + tile] = 1
            tile_start += tile
        assert ts != -1, "no span reaches the end of the sample"
        assert td != 0, "spans reaching the end of the sample cover no tiles"
        tile_range_for_decode.append((ts, ts + td))
        seq_start += length

    return dense_mask, tile_range_for_decode


def convert_dense_cross_attention_mask_to_tensor(
    cross_attention_token_mask: np.ndarray,
    num_tokens_per_tile: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a 0/1 per-tile mask into an additive per-token attention mask.

    Every column of ``cross_attention_token_mask`` is repeated
    ``num_tokens_per_tile`` times along dim 1. Positions that were 1 become
    0 and positions that were 0 become ``torch.finfo(dtype).min``. Rows in
    which every position ended up at that minimum are multiplied by 0, so
    they carry no masking at all.

    Returns:
        A tensor of the requested ``dtype`` on ``device``.
    """
    tile_mask = torch.tensor(cross_attention_token_mask,
                             dtype=dtype,
                             device=device)
    token_mask = tile_mask.repeat_interleave(num_tokens_per_tile, dim=1)

    # Flip so that 1 marks a blocked position, then swap it for the most
    # negative finite value of the dtype.
    blocked = 1.0 - token_mask
    lowest = torch.finfo(dtype).min
    additive = blocked.masked_fill(blocked.to(torch.bool), lowest)

    # 1 for rows with at least one unblocked position, 0 for rows that are
    # blocked everywhere; the latter get zeroed out entirely.
    row_has_visible = (additive != lowest).any(dim=-1)
    keep_row = row_has_visible.type_as(additive)[..., None]
    additive *= keep_row
    # (num_prompt_tokens, num_encoder_tokens)
    return additive