# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from itertools import tee
22
from typing import Annotated, Literal, Optional, Union
23
24
25
26
27
28
29

import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
30
31
32
    find_supported_resolutions,
    get_best_fit,
)
33
34
35

from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
36
from vllm.config.multimodal import BaseDummyOptions
37
from vllm.distributed import get_tensor_model_parallel_world_size
38
39
40
41
42
43
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
44
45
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
46
from vllm.model_executor.model_loader.utils import initialize_model
47
48
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
63
from vllm.multimodal.profiling import BaseDummyInputsBuilder
64
from vllm.sequence import IntermediateTensors
65
from vllm.utils.tensor_schema import TensorSchema, TensorShape
66
67
68

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
69
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
70
from .vision import run_dp_sharded_vision_model
71
72


73
class Llama4ImagePatchInputs(TensorSchema):
    """
    Pixel-value inputs for Llama4 images, where every image is split into
    one or more fixed-size chunks.

    Dimensions:
        - batch_size: Batch size
        - total_num_chunks: Batch size * number of chunks
        - num_channels: Number of channels
        - image_size: Size of each image
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Chunk pixel values for all images in the batch, concatenated along the
    # first dimension.
    flat_data: Annotated[
        torch.Tensor,
        TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
    ]

    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
    """
    The number of total patches for each image in the batch.

    This is used to split the embeddings, which have their first two
    dimensions flattened just like `flat_data`.
    """

    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
    """
    A list of aspect ratios corresponding to the number of tiles
    in each dimension that each image in the batch corresponds to.
    Each aspect ratio is a pair (ratio_h, ratio_w).
    """


class Llama4VisionMLP(nn.Module):
    """Two-layer GELU MLP used inside the Llama4 vision tower.

    Computes ``fc2(gelu(fc1(x)))``; when ``output_activation`` is True a
    second GELU is applied to the ``fc2`` output.

    Args:
        input_size: Width of the incoming features (input of ``fc1``).
        intermediate_size: Width produced by ``fc1`` and consumed by ``fc2``.
        output_size: Width produced by ``fc2``.
        bias: Whether both linear layers have a bias term.
        output_activation: Also apply GELU to the final output.
        quant_config: Optional quantization config passed to both layers.
        prefix: Parameter-name prefix; layers are registered as
            ``{prefix}.fc1`` and ``{prefix}.fc2``.
        use_data_parallel: Passed to both layers as ``disable_tp``.
    """

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            input_size=input_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        # Row-parallel fc2 consumes the column-sharded output of fc1.
        self.fc2 = RowParallelLinear(
            input_size=intermediate_size,
            output_size=output_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The vLLM linear layers return (output, bias); only output is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        if self.output_activation:
            return self.activation_fn(hidden_states)
        return hidden_states


class Llama4MultiModalProjector(nn.Module):
    """Map vision-tower features into the text model's hidden space.

    Wraps one bias-free ``ColumnParallelLinear`` (``gather_output=True``)
    from ``config.vision_config.vision_output_dim`` to
    ``config.text_config.hidden_size``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        in_features = config.vision_config.vision_output_dim
        out_features = config.text_config.hidden_size
        self.linear_1 = ColumnParallelLinear(
            input_size=in_features,
            output_size=out_features,
            bias=False,
            quant_config=quant_config,
            gather_output=True,
            prefix=f"{prefix}.linear_1",
        )

    def forward(self, image_features):
        # The layer yields (output, bias); the bias slot is discarded.
        projected, _unused_bias = self.linear_1(image_features)
        return projected


def pixel_shuffle(input_tensor, shuffle_ratio):
    """Fold a square patch grid into fewer, wider tokens.

    The patch sequence is viewed as a ``side x side`` grid
    (``side = int(sqrt(num_patches))``). Width and then height are scaled
    by ``shuffle_ratio`` while channels are divided by it, so the output has
    ``num_patches * shuffle_ratio**2`` tokens of
    ``channels / shuffle_ratio**2`` features each.

    Args:
        input_tensor: Tensor of shape ``[batch_size, num_patches, channels]``;
            ``num_patches`` must be a perfect square for the views to succeed.
        shuffle_ratio: Spatial scale factor (e.g. 0.5 halves each side).

    Returns:
        Tensor of shape
        ``[batch_size, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2]``.
    """
    batch_size, num_patches, _ = input_tensor.shape
    patch_size = int(math.sqrt(num_patches))

    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()

    # Merge neighbouring columns into the channel dimension.
    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    # Swap H and W so the next view merges neighbouring rows the same way.
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    # Restore (H, W) order before flattening back to a token sequence.
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    return output_tensor


class Llama4VisionPixelShuffleMLP(nn.Module):
    """Vision adapter: ``pixel_shuffle`` followed by a projection MLP.

    Args:
        config: Vision config. Reads ``pixel_shuffle_ratio``,
            ``projector_input_dim``, ``projector_output_dim``,
            ``intermediate_size`` and ``multi_modal_projector_bias``.
        quant_config: Optional quantization config for the MLP.
        prefix: Parameter-name prefix; the MLP is registered at
            ``{prefix}.mlp``.
        use_data_parallel: Forwarded to the MLP (disables TP on its layers).
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
        # Not read by forward() below; kept as a public attribute.
        self.inner_dim = int(
            config.projector_input_dim // (self.pixel_shuffle_ratio**2)
        )
        self.output_dim = config.projector_output_dim
        # NOTE(review): the MLP input width is config.intermediate_size, which
        # presumably equals the width pixel_shuffle produces
        # (hidden_size / pixel_shuffle_ratio**2) — confirm against the config.
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(encoded_patches)


class Llama4VisionAttention(nn.Module):
    """Multi-head self-attention for the Llama4 vision encoder.

    Q, K and V come from one fused projection. Rotary embeddings (configured
    with ``rope_type="mllama4"`` and complex64 dtype) are applied to Q and K
    before ``MultiHeadAttention``, and the result goes through ``o_proj``.

    With ``use_data_parallel`` both projections are ``ReplicatedLinear`` and
    the tensor-parallel size is treated as 1. Otherwise heads are split
    across the tensor-parallel group via ``QKVParallelLinear`` /
    ``RowParallelLinear``.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        assert self.num_heads % self.tp_size == 0
        self.num_local_heads = self.num_heads // self.tp_size
        # K/V use as many heads as Q (no grouped-query attention here).
        self.q_size = self.num_local_heads * self.head_dim
        self.kv_size = self.num_local_heads * self.head_dim
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim**-0.5

        self.attn = MultiHeadAttention(
            self.num_local_heads, self.head_dim, self.scaling
        )

        if use_data_parallel:
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                self.q_size + 2 * self.kv_size,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
            self.o_proj = ReplicatedLinear(
                self.num_heads * self.head_dim,
                self.embed_dim,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
            )
        else:
            self.qkv_proj = QKVParallelLinear(
                self.embed_dim,
                self.head_dim,
                self.num_heads,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
            self.o_proj = RowParallelLinear(
                self.num_heads * self.head_dim,
                self.embed_dim,
                bias=True,
                input_is_parallel=True,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
            )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=config.hidden_size // config.num_attention_heads // 2,
            # number of image patches
            max_position=(config.image_size // config.patch_size) ** 2,
            base=config.rope_theta,
            rope_scaling={"rope_type": "mllama4"},
            is_neox_style=False,
            dtype=torch.complex64,  # important
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # The views below index exactly two leading dims, so hidden_states is
        # assumed to be (num_tiles, seq_len, embed_dim) — TODO confirm.
        input_shape = hidden_states.shape[:-1]

        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Expose the head dimension so rotary embeddings act per head.
        q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim)
        k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim)
        q, k = self.rotary_emb(q, k)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)

        attn_output = self.attn(q, k, v)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output, _ = self.o_proj(attn_output)

        return attn_output


class Llama4VisionEncoderLayer(nn.Module):
    """One pre-norm transformer block of the Llama4 vision encoder.

    ``x = x + attn(ln1(x))`` followed by ``x = x + mlp(ln2(x))``. The MLP
    has biases and no output activation.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = Llama4VisionAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
    ):
        """Run the block; returns a 1-tuple ``(hidden_state,)``."""
        # Self Attention
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)
        hidden_state = self.self_attn(hidden_state)
        hidden_state = residual + hidden_state

        # Feed forward
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        # Tuple return is kept for callers that index layer_outputs[0].
        outputs = (hidden_state,)
        return outputs


class Llama4VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``Llama4VisionEncoderLayer``s.

    Layers are registered under ``{prefix}.layers.{layer_idx}``.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            hidden_states: Input tensor of shape
                (batch_size, sequence_length, hidden_size), i.e. the
                embedded patch sequence to encode.

        Returns:
            The hidden states after the last encoder layer.
        """

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states)
            hidden_states = layer_outputs[0]

        return hidden_states


class Llama4UnfoldConvolution(nn.Module):
    """Patch embedding implemented as ``nn.Unfold`` plus a linear layer.

    Windows of ``patch_size`` pixels are extracted with stride equal to the
    kernel size, so the windows do not overlap. Each window is flattened to
    ``num_channels * kh * kw`` values and projected to ``hidden_size`` by a
    bias-free linear layer. This is equivalent to a bias-free Conv2d whose
    stride equals its kernel.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        kernel_size = config.patch_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        self.linear = ColumnParallelLinear(
            input_size=config.num_channels * kernel_size[0] * kernel_size[1],
            output_size=config.hidden_size,
            bias=False,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Presumably (N, C, H, W) pixels -> Unfold -> (N, C*kh*kw, L).
        hidden_states = self.unfold(hidden_states)
        # -> (N, L, C*kh*kw): one row per patch for the linear projection.
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states, _ = self.linear(hidden_states)
        return hidden_states


class Llama4VisionModel(nn.Module):
    """Llama4 vision tower: patch embedding, transformer encoder, adapter.

    For each image chunk (tile), ``forward`` does the following:

    1. Embeds patches.
    2. Appends a learned class token at the END of the sequence.
    3. Adds learned positional embeddings and applies ``layernorm_pre``.
    4. Runs the encoder and applies ``layernorm_post``.
    5. Drops the class-token position.
    6. Applies the pixel-shuffle adapter.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        # Patch grid plus one slot for the class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
            use_data_parallel=use_data_parallel,
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        # encoders
        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
            use_data_parallel=use_data_parallel,
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        images_flattened: torch.Tensor,
    ) -> torch.Tensor:
        """Encode a batch of image chunks.

        Args:
            images_flattened: Chunk pixels; presumably shaped
                (num_tiles, num_channels, image_size, image_size) — confirm
                against the caller.

        Returns:
            Adapter output, one sequence of tokens per tile.
        """
        # Patch embedding
        hidden_state = self.patch_embedding(images_flattened)
        num_tiles, num_patches, hidden_dim = hidden_state.shape

        # Add cls token (appended after the patches, removed again below).
        class_embedding = self.class_embedding.expand(
            hidden_state.shape[0], 1, hidden_state.shape[-1]
        )
        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
        num_patches += 1

        # Position embeddings
        hidden_state = hidden_state.reshape(
            num_tiles,
            1,
            num_patches,
            hidden_dim,
        )
        positional_embedding = self.positional_embedding_vlm.to(
            dtype=hidden_state.dtype, device=hidden_state.device
        )
        hidden_state = hidden_state + positional_embedding
        hidden_state = self.layernorm_pre(hidden_state)
        hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)

        # Apply encoder
        hidden_state = self.model(hidden_state)
        hidden_state = self.layernorm_post(hidden_state)

        # Remove CLS token output
        hidden_state = hidden_state[:, :-1, :]

        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
        hidden_state = self.vision_adapter(hidden_state)

        return hidden_state


class Mllama4ProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and size helpers for Llama4 multimodal."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__(ctx)

    def get_hf_config(self) -> Llama4Config:
        return self.ctx.get_hf_config(Llama4Config)

    def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
        # Default to the fast processor unless the caller overrides use_fast.
        return self.ctx.get_hf_processor(
            Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Although vLLM can support more images from an infra capability
        # perspective, we do not recommend using >10 images in practice.
        return {"image": None}

    @staticmethod
    def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
        """Return the number of patch tokens one chunk yields after shuffle.

        A chunk of ``image_size`` pixels contains
        ``(image_size // patch_size) ** 2`` patches. Pixel shuffle then merges
        ``round(1 / pixel_shuffle_ratio**2)`` of them into one token.

        Raises:
            AssertionError: if ``image_size`` is not a multiple of
                ``patch_size``.
        """
        image_size = vision_config.image_size
        patch_size = vision_config.patch_size

        # Both f-string parts must belong to the assert message; previously
        # the second half was a stray no-op statement after the assert.
        assert image_size % patch_size == 0, (
            f"chunk size {image_size} should be multiple of "
            f"patch_size {patch_size}"
        )

        ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
        return (image_size // patch_size) ** 2 // ds_ratio

    def get_max_num_tiles(self) -> int:
        image_processor = self.get_hf_processor().image_processor
        return image_processor.max_patches

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_config = self.get_hf_config().vision_config
        image_size = vision_config.image_size
        # Result in the max possible feature size (h:w = max_num_tiles:1)
        return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size)
572
573


574
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]):
    """Multimodal processor that adds per-image tiling metadata for Llama4."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach ``aspect_ratios`` and
        ``patches_per_image`` tensors whenever ``pixel_values`` is produced.
        """
        tokenizer = self.info.get_tokenizer()

        if mm_data is None:
            return tokenizer(prompt, add_special_tokens=False)  # exclude bos
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        processor = self.info.get_hf_processor(**mm_kwargs)
        image_processor = processor.image_processor
        vision_config = self.info.get_hf_config().vision_config

        if processed_outputs.get("pixel_values") is not None:
            assert "images" in mm_data, (
                "images expected to be in mm_data when pixel_values is present"
            )

            images = mm_data["images"]
            parsed_images = (
                self._get_data_parser()
                .parse_mm_data({"image": images})
                .get_items("image", ImageProcessorItems)
            )

            tile_size = vision_config.image_size
            possible_resolutions = find_supported_resolutions(
                max_num_chunks=self.info.get_max_num_tiles(),
                patch_size=SizeDict(height=tile_size, width=tile_size),
            )
            # Loop-invariant: build the candidate tensor and read the flag
            # once instead of once per image.
            resolutions_tensor = torch.tensor(possible_resolutions)
            resize_to_max_canvas = image_processor.resize_to_max_canvas
            # image.size is (width, height) for PIL images (assumed here), so
            # it is swapped to the (height, width) order get_best_fit takes.
            best_fit_sizes = [
                get_best_fit(
                    (image.size[1], image.size[0]),
                    resolutions_tensor,
                    resize_to_max_canvas=resize_to_max_canvas,
                )
                for image in parsed_images
            ]
            # TODO tile height/width do not necessarily need to match
            aspect_ratios = [
                (image_size[0] // tile_size, image_size[1] // tile_size)
                for image_size in best_fit_sizes
            ]
            # A single-tile image is one chunk. Multi-tile images get one
            # extra chunk, presumably the global thumbnail the HF processor
            # appends — confirm.
            patches_per_image = [
                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
            ]

            processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
            processed_outputs["patches_per_image"] = torch.tensor(patches_per_image)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # pixel_values is flat over chunks; split it per image using
        # patches_per_image. The other two fields have one entry per image.
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            patches_per_image=MultiModalFieldConfig.batched("image"),
            aspect_ratios=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        """Replace each image token with the HF processor's tiled expansion."""
        config = self.info.get_hf_config()
        vision_config = config.vision_config

        num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        img_patch_token = hf_processor.img_patch_token

        def get_replacement(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            aspect_ratio = out_item["aspect_ratios"].data

            repl = hf_processor._prompt_split_image(
                aspect_ratio=aspect_ratio,
                num_patches_per_chunk=num_patches_per_chunk,
            )

            # Only the patch tokens inside the replacement receive embeddings.
            return PromptUpdateDetails.select_text(repl, img_patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement,
            )
        ]


class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
    """Builds dummy prompts and images used for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one fake image token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.fake_image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return ``num_images`` dummy images at the largest-feature size."""
        num_images = mm_counts.get("image", 0)

        # ImageSize unpacks as (width, height).
        (target_width, target_height) = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


@MULTIMODAL_REGISTRY.register_processor(
    Mllama4MultiModalProcessor,
    info=Mllama4ProcessingInfo,
    dummy_inputs=Mllama4DummyInputsBuilder,
)
720
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
721
722
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
723
        "gate_up_proj": ["gate_proj", "up_proj"],
724
725
    }

726
727
    supports_encoder_tp_data = True

728
729
730
731
732
733
734
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|image|>"

        raise ValueError("Only image modality is supported")

735
736
737
738
739
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the optional vision tower/projector and the text decoder.

        Args:
            vllm_config: engine config. Its ``model_config.hf_config`` is the
                Llama4 HF config; ``text_config`` is used for the decoder and
                ``vision_config`` for the vision tower.
            prefix: parameter-name prefix for submodules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config
        mm_config = model_config.multimodal_config

        # "data" mode makes _process_image_input shard images across ranks
        # with run_dp_sharded_vision_model instead of calling the model directly.
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"

        self.config = hf_config
        self.quant_config = quant_config
        self.multimodal_config = mm_config

        if not mm_config.get_limit_per_prompt("image"):
            # No images allowed per prompt: skip building the vision path.
            self.vision_model = None
            self.multi_modal_projector = None
        else:
            # The second positional argument (presumably the quant config) is
            # None for both modules -- NOTE(review): confirm they are meant to
            # stay unquantized.
            self.vision_model = Llama4VisionModel(
                hf_config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_data_parallel=self.use_data_parallel,
            )
            self.multi_modal_projector = Llama4MultiModalProjector(
                self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")
            )

        text_vllm_config = vllm_config.with_hf_config(
            hf_config.text_config, ["LlamaForCausalLM"]
        )
        self.language_model = initialize_model(
            vllm_config=text_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            model_class=Llama4ForCausalLM,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
769
770

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[Llama4ImagePatchInputs]:
        """Pull the image tensors out of the multimodal kwargs.

        Pops ``pixel_values``, ``patches_per_image`` and ``aspect_ratios``
        from ``kwargs``. Returns ``None`` when no ``pixel_values`` were given.
        Otherwise both other keys are required, and a missing one raises
        ``KeyError``.
        """
        # num_images, 1, num_chunks, channel, image_size, image_size
        raw_pixels = kwargs.pop("pixel_values", None)
        if raw_pixels is None:
            return None

        # num_images x num_chunks, channel, image_size, image_size
        # TODO: confirm handling for variable lengths
        stacked_chunks = flatten_bn(raw_pixels, concat=True)
        chunk_counts = flatten_bn(kwargs.pop("patches_per_image"))

        ratios = kwargs.pop("aspect_ratios")
        # A 3-D ratios tensor has its dim 1 squeezed away.
        ratios = ratios.squeeze(1) if ratios.ndim == 3 else ratios

        return Llama4ImagePatchInputs(
            type="pixel_values",
            flat_data=stacked_chunks,
            patches_per_image=chunk_counts,
            aspect_ratios=ratios,
        )

    def _process_image_input(
        self, image_input: Llama4ImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Encode and project image chunks, returning one tensor per image.

        All chunks are encoded in one pass and projected. The result is then
        split back into per-image groups using ``patches_per_image``. Each
        group's first two dimensions are merged with ``flatten(0, 1)``; the
        second dimension is presumably the per-chunk token dim, to be confirmed.
        """
        assert self.vision_model and self.multi_modal_projector

        pixel_chunks = image_input["flat_data"]
        split_sizes = image_input["patches_per_image"].tolist()

        if not self.use_data_parallel:
            encoded = self.vision_model(pixel_chunks)
        else:
            # Shard the image input across ranks.
            encoded = run_dp_sharded_vision_model(pixel_chunks, self.vision_model)

        projected = self.multi_modal_projector(encoded)
        per_image = projected.split(split_sizes, dim=0)
        return [group.flatten(0, 1) for group in per_image]
814

815
816
817
    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

818
    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
819
820
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
821
            return []
822

823
        return self._process_image_input(image_input)
824
825
826
827
828
829
830
831
832
833
834
835

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

836
837
838
        return self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
839
840
841
842
843

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
844
        return self.language_model.compute_logits(hidden_states)
845
846
847

    def separate_weights(
        self,
848
        weights: Iterable[tuple[str, torch.Tensor]],
849
        prefix: str,
850
    ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
851
852
        weights1, weights2 = tee(weights, 2)

853
        def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
854
855
856
857
            for name, data in weights1:
                if name.startswith(prefix):
                    yield (name, data)

858
        def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]:
859
860
861
862
863
864
            for name, data in weights2:
                if not name.startswith(prefix):
                    yield (name, data)

        return get_prefix_weights(), get_other_weights()

865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
    def _consolidate_qkv_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        qkv_idx_mappings = {
            ".self_attn.q_proj": 0,
            ".self_attn.k_proj": 1,
            ".self_attn.v_proj": 2,
        }
        qkv_weights = {}
        for name, loaded_weight in weights:
            for weight_name, idx in qkv_idx_mappings.items():
                if weight_name not in name:
                    continue
                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
                if new_name not in qkv_weights:
                    qkv_weights[new_name] = [None] * 3
                qkv_weights[new_name][idx] = loaded_weight
                break
            else:
                yield name, loaded_weight
        for key, weight in qkv_weights.items():
            qkv_weight = torch.cat(weight, dim=0)
            yield key, qkv_weight

889
890
891
    def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
        """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
        format."""
892
893
894
895
896
897
        if name.startswith("model.") or name.startswith("language_model.model."):
            renamed = (
                name.replace("model.", "language_model.model.", 1)
                if name.startswith("model.")
                else name
            )
898
            # Handle expert scale parameters with flat naming
899
900
901
            if "feed_forward.experts." in name and (
                "_input_scale" in name or "_weight_scale" in name
            ):
902
903
                # Map checkpoint naming to vLLM's expected naming
                if "down_proj_input_scale" in renamed:
904
                    return renamed.replace("down_proj_input_scale", "w2_input_scale")
905
                elif "down_proj_weight_scale" in renamed:
906
                    return renamed.replace("down_proj_weight_scale", "w2_weight_scale")
907
                elif "gate_up_proj_input_scale" in renamed:
908
909
910
                    return renamed.replace(
                        "gate_up_proj_input_scale", "w13_input_scale"
                    )
911
                elif "gate_up_proj_weight_scale" in renamed:
912
913
914
                    return renamed.replace(
                        "gate_up_proj_weight_scale", "w13_weight_scale"
                    )
915
916
917
                return renamed

            # Handle attention scale parameters
918
            elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name):
919
920
921
922
923
924
925
                if ".k_proj.k_scale" in renamed:
                    return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
                elif ".v_proj.v_scale" in renamed:
                    return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
                return renamed

            # Standard model.* to language_model.model.* renaming
926
            return renamed
927
928

        elif name.startswith("lm_head.weight"):
929
            return name.replace("lm_head.weight", "language_model.lm_head.weight")
930
931
932
933
934
935
936
937
938
939

        return name

    def _separate_and_rename_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
        """Rename weights and separate them into language_model and other
        weights."""
        language_model_weights = []
        other_weights = []
940

941
942
        for name, weight in weights:
            renamed = self._rename_weight_for_modelopt_checkpoint(name)
943

944
945
946
947
948
949
950
951
            if renamed.startswith("language_model."):
                language_model_weights.append((renamed, weight))
            else:
                other_weights.append((renamed, weight))

        return language_model_weights, other_weights

    def _handle_expert_scale_broadcasting(
952
        self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
953
954
955
956
957
958
959
960
961
962
963
964
    ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
        """Handle expert scale parameters that need broadcasting.

        ModelOpt checkpoints use a single value tensor scalar for BMM style
        experts, vLLM expects the scale to be broadcasted across all experts.
        """
        regular_weights = []
        expert_scale_weights = []
        updated_params = set()

        for name, weight in weights:
            # Check if this is an expert scale parameter that needs broadcasting
965
966
967
968
969
            if (
                "feed_forward.experts." in name
                and "scale" in name
                and ".shared_expert" not in name
            ):
970
971
                if name in params_dict:
                    param = params_dict[name]
972
973
974
975
976
                    if (
                        hasattr(param, "data")
                        and param.data.numel() > 1
                        and weight.numel() == 1
                    ):
977
978
979
980
981
982
983
984
985
986
987
                        # Broadcast single value to all experts
                        param.data.fill_(weight.item())
                        updated_params.add(name)
                        continue

                expert_scale_weights.append((name, weight))
            else:
                regular_weights.append((name, weight))

        return regular_weights, expert_scale_weights, updated_params

988
989
990
991
992
993
    def _load_other_weights(
        self,
        other_weights: Iterable[tuple[str, torch.Tensor]],
        params_dict: dict,
        stacked_params_mapping: list,
    ) -> set[str]:
994
995
        """Load non-language-model weights with stacking support."""
        updated_params = set()
996

997
998
999
        if self.use_data_parallel:
            other_weights = self._consolidate_qkv_weights(other_weights)

1000
        for name, loaded_weight in other_weights:
1001
            # Try stacked parameter mapping first
1002
            for param_name, weight_name, shard_id in stacked_params_mapping:
1003
                if weight_name not in name or self.use_data_parallel:
1004
1005
1006
1007
1008
1009
1010
1011
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
1012
                # Use regular weight loading
1013
                param = params_dict[name]
1014
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
1015
1016
                weight_loader(param, loaded_weight)
                updated_params.add(name)
1017
1018
1019

        return updated_params

1020
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model.

        Steps, in order:
          1. Rename checkpoint names and split them into ``language_model.*``
             weights and everything else.
          2. Drop the non-language-model weights when neither
             ``vision_model`` nor ``multi_modal_projector`` was built.
          3. Broadcast scalar routed-expert scales straight into their
             parameters.
          4. Load the remaining language-model weights with
             ``AutoWeightsLoader``: regular weights first, then any expert
             scales that were not broadcast.
          5. Load the non-language-model weights (presumably vision tower and
             projector) with stacked q/k/v and gate/up shard support.

        Returns:
            Names of all parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            # Shared expert gate_up_proj stacking
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            # Feed forward gate_up_proj stacking (for non-MoE layers if any)
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_names: set[str] = set()

        lm_weights, other_weights = self._separate_and_rename_weights(weights)

        if self.vision_model is None and self.multi_modal_projector is None:
            # Nothing was built to receive these weights.
            other_weights = []

        regular_weights, pending_scales, broadcast_names = (
            self._handle_expert_scale_broadcasting(lm_weights, params_dict)
        )
        loaded_names.update(broadcast_names)

        loader = AutoWeightsLoader(self)
        lm_loaded = loader.load_weights(regular_weights)
        assert lm_loaded is not None
        loaded_names.update(lm_loaded)

        if pending_scales:
            scales_loaded = loader.load_weights(pending_scales)
            if scales_loaded:
                loaded_names.update(scales_loaded)

        loaded_names.update(
            self._load_other_weights(other_weights, params_dict, stacked_params_mapping)
        )
        return loaded_names