# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from itertools import tee
22
from typing import Annotated, Literal
23
24
25
26
27
28
29

import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
30
31
32
    find_supported_resolutions,
    get_best_fit,
)
33

34
35
36
37
from vllm.compilation.decorators import (
    should_torch_compile_mm_encoder,
    support_torch_compile,
)
38
from vllm.config import VllmConfig, set_current_vllm_config
39
from vllm.config.multimodal import BaseDummyOptions
40
from vllm.distributed import get_tensor_model_parallel_world_size
41
from vllm.inputs import MultiModalDataDict
42
from vllm.model_executor.layers.attention import MMEncoderAttention
43
from vllm.model_executor.layers.fused_moe import FusedMoE
44
45
46
47
48
49
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
50
51
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
52
from vllm.model_executor.model_loader.utils import initialize_model
53
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
54
from vllm.model_executor.models.module_mapping import MultiModelKeys
55
from vllm.multimodal import MULTIMODAL_REGISTRY
56
57
58
59
60
61
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
62
    BaseDummyInputsBuilder,
63
64
65
66
67
68
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
69
from vllm.sequence import IntermediateTensors
70
from vllm.utils.tensor_schema import TensorSchema, TensorShape
71

72
from .interfaces import (
73
    MixtureOfExperts,
74
75
    MultiModalEmbeddings,
    SupportsEagle3,
76
    SupportsLoRA,
77
78
79
    SupportsMultiModal,
    SupportsPP,
)
80
from .llama4 import Llama4ForCausalLM
81
from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix
82
from .vision import is_vit_use_data_parallel, run_dp_sharded_vision_model
83
84


85
class Llama4ImagePatchInputs(TensorSchema):
    """
    Chunked pixel inputs for Llama 4 images.

    Each image is split into one or more square chunks; the chunks of every
    image in the batch are concatenated along the first dimension of
    `pixel_values`.

    Dimensions:
        - batch_size: Number of images in the batch
        - total_num_chunks: Total number of chunks across all images
        - num_channels: Number of channels
        - image_size: Height and width of each (square) chunk
    """

    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
    ]

    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
    """
    Number of chunks contributed by each image in the batch.

    Used to split per-chunk outputs, whose leading dimension is flattened
    across images exactly like `pixel_values`, back into per-image groups.
    """

    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
    """
    Tile-grid shape of each image in the batch, as a pair
    (ratio_h, ratio_w): the number of tiles along the height and the width.
    """


class Llama4VisionMLP(nn.Module):
    """Two-layer GELU MLP used throughout the Llama 4 vision tower.

    `fc1` is a ColumnParallelLinear and `fc2` a RowParallelLinear; tensor
    parallelism is disabled for both when the vision encoder runs in
    data-parallel mode. When `output_activation` is set, GELU is applied
    to the MLP output as well.
    """

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        no_tp = is_vit_use_data_parallel()
        self.fc1 = ColumnParallelLinear(
            input_size=input_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=no_tp,
        )
        self.fc2 = RowParallelLinear(
            input_size=intermediate_size,
            output_size=output_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=no_tp,
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only output is used.
        hidden, _ = self.fc1(hidden_states)
        hidden, _ = self.fc2(self.activation_fn(hidden))
        if not self.output_activation:
            return hidden
        return self.activation_fn(hidden)


class Llama4MultiModalProjector(nn.Module):
    """Bias-free linear map from vision features to the text hidden size.

    Uses a ColumnParallelLinear with `gather_output=True`, so the returned
    features have the full `text_config.hidden_size` width.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        vision_width = config.vision_config.vision_output_dim
        text_width = config.text_config.hidden_size
        self.linear_1 = ColumnParallelLinear(
            input_size=vision_width,
            output_size=text_width,
            bias=False,
            quant_config=quant_config,
            gather_output=True,
            prefix=f"{prefix}.linear_1",
        )

    def forward(self, image_features):
        projected, _unused_bias = self.linear_1(image_features)
        return projected


def pixel_shuffle(input_tensor, shuffle_ratio):
    """Fold neighbouring patches of a square patch grid into the channel axis.

    Args:
        input_tensor: Tensor of shape [batch_size, num_patches, channels],
            where num_patches is a perfect square (the patches form a
            side x side grid in row-major order).
        shuffle_ratio: Spatial down-sampling factor (< 1). For example, 0.5
            merges each 2x2 block of patches into one token.

    Returns:
        Tensor of shape
        [batch_size, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2].
        With ratio 0.5, each output token is the concatenation of its 2x2
        block in the order (top-left, top-right, bottom-left, bottom-right).
    """
    n_batch, n_patches, _ = input_tensor.shape
    side = int(math.sqrt(n_patches))

    grid = input_tensor.view(n_batch, side, side, -1)
    n_batch, rows, cols, depth = grid.size()

    # Merge adjacent columns into channels, then transpose so the original
    # row axis becomes the inner spatial axis.
    folded = grid.view(
        n_batch, rows, int(cols * shuffle_ratio), int(depth / shuffle_ratio)
    )
    folded = folded.permute(0, 2, 1, 3).contiguous()

    # Merge adjacent rows into channels; rows == cols because the grid is
    # square, so the reduced extents are interchangeable here.
    folded = folded.view(
        n_batch,
        int(rows * shuffle_ratio),
        int(cols * shuffle_ratio),
        int(depth / (shuffle_ratio**2)),
    )
    # Transpose back to row-major order and flatten the grid into a sequence.
    folded = folded.permute(0, 2, 1, 3).contiguous()

    return folded.view(n_batch, -1, folded.shape[-1])


class Llama4VisionPixelShuffleMLP(nn.Module):
    """Vision adapter: pixel-shuffle the encoder output, then apply an MLP.

    The MLP is built with `output_activation=True`, so its output also
    goes through GELU.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)


class Llama4VisionAttention(nn.Module):
    """Multi-head self-attention for the Llama 4 vision encoder.

    Queries and keys are rotated with an "mllama4" rotary embedding before
    attention. Projections are tensor-parallel (QKVParallelLinear /
    RowParallelLinear) unless the vision encoder runs data-parallel, in
    which case they are ReplicatedLinear layers holding every head.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        use_data_parallel = is_vit_use_data_parallel()
        # In data-parallel mode every rank keeps all heads (no head sharding).
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        assert self.num_heads % self.tp_size == 0
        self.num_local_heads = self.num_heads // self.tp_size
        # K and V use as many heads as Q, so all three share one local width.
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim**-0.5

        self.attn = MMEncoderAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            prefix=f"{prefix}.attn",
        )

        qkv_prefix = f"{prefix}.qkv_proj"
        o_prefix = f"{prefix}.o_proj"
        all_heads_width = self.num_heads * self.head_dim
        if not use_data_parallel:
            self.qkv_proj = QKVParallelLinear(
                self.embed_dim,
                self.head_dim,
                self.num_heads,
                bias=True,
                quant_config=quant_config,
                prefix=qkv_prefix,
            )
            self.o_proj = RowParallelLinear(
                all_heads_width,
                self.embed_dim,
                bias=True,
                input_is_parallel=True,
                quant_config=quant_config,
                prefix=o_prefix,
            )
        else:
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                self.q_size + 2 * self.kv_size,
                bias=True,
                quant_config=quant_config,
                prefix=qkv_prefix,
            )
            self.o_proj = ReplicatedLinear(
                all_heads_width,
                self.embed_dim,
                bias=True,
                quant_config=quant_config,
                prefix=o_prefix,
            )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            # One position per image patch.
            max_position=(config.image_size // config.patch_size) ** 2,
            rope_parameters={
                "rope_type": "mllama4",
                "rope_theta": config.rope_parameters["rope_theta"],
                "partial_rotary_factor": 0.5,
            },
            is_neox_style=False,
            # Marked "important" upstream; presumably the mllama4 rope works
            # with complex rotations - confirm before changing.
            dtype=torch.complex64,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        leading_shape = hidden_states.shape[:-1]

        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        # Rope is applied per head, so expose a (d0, d1, heads, head_dim) view.
        per_head = (self.num_local_heads, self.head_dim)
        query = query.view(query.shape[0], query.shape[1], *per_head)
        key = key.view(key.shape[0], key.shape[1], *per_head)
        query, key = self.rotary_emb(query, key)

        # Merge the heads back into a single feature axis for the attention op.
        query = query.view(query.shape[0], query.shape[1], -1)
        key = key.view(key.shape[0], key.shape[1], -1)

        context = self.attn(query, key, value)
        context = context.reshape(*leading_shape, -1).contiguous()
        projected, _ = self.o_proj(context)
        return projected


class Llama4VisionEncoderLayer(nn.Module):
    """Pre-norm transformer block of the Llama 4 vision encoder.

    LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add. `forward` returns a 1-tuple holding
    the updated hidden state.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = Llama4VisionAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
    ):
        attn_delta = self.self_attn(self.input_layernorm(hidden_state))
        hidden_state = hidden_state + attn_delta

        mlp_delta = self.mlp(self.post_attention_layernorm(hidden_state))
        hidden_state = hidden_state + mlp_delta

        return (hidden_state,)


class Llama4VisionEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` Llama4VisionEncoderLayer blocks."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Run the hidden states through every encoder layer in order.

        Args:
            hidden_states: Input tensor of shape
                (batch_size, sequence_length, hidden_size), e.g. the
                pre-normalized patch embeddings (with CLS token) of each
                image chunk.

        Returns:
            The hidden states produced by the last layer.
        """
        for encoder_layer in self.layers:
            # Each layer returns a 1-tuple; index 0 is the updated state.
            hidden_states = encoder_layer(hidden_states)[0]

        return hidden_states


class Llama4UnfoldConvolution(nn.Module):
    """Patch embedding built from `nn.Unfold` plus a bias-free linear layer.

    Non-overlapping patch_size x patch_size patches (stride == patch_size)
    are flattened and projected to `config.hidden_size`. This behaves like a
    bias-free strided convolution.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        kernel = config.patch_size
        if isinstance(kernel, int):
            kernel = (kernel, kernel)
        self.unfold = torch.nn.Unfold(kernel_size=kernel, stride=config.patch_size)
        use_data_parallel = is_vit_use_data_parallel()
        patch_features = config.num_channels * kernel[0] * kernel[1]
        self.linear = ColumnParallelLinear(
            input_size=patch_features,
            output_size=config.hidden_size,
            bias=False,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # nn.Unfold yields (N, C*kh*kw, L); move the patch axis L before the
        # flattened patch features so the linear layer acts per patch.
        patches = self.unfold(hidden_states).permute(0, 2, 1)
        embedded, _ = self.linear(patches)
        return embedded


455
@support_torch_compile(
    dynamic_arg_dims={"images_flattened": 0},
    enable_if=should_torch_compile_mm_encoder,
    is_encoder=True,
)
class Llama4VisionModel(nn.Module):
    """Llama 4 vision tower: patch embedding, transformer encoder, adapter.

    Input is a batch of image chunks. Each chunk is embedded into patches, a
    CLS token is appended at the end of the sequence, positional embeddings
    are added, and the sequence is encoded. The CLS position is then
    dropped, and the patches are pixel-shuffled and projected by the vision
    adapter.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        # Patches per chunk plus one slot for the CLS token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
        )

        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
        )

    def forward(
        self,
        images_flattened: torch.Tensor,
    ) -> torch.Tensor:
        tokens = self.patch_embedding(images_flattened)
        num_tiles, num_patches, hidden_dim = tokens.shape

        # Append one CLS token at the END of every chunk's patch sequence.
        cls_tokens = self.class_embedding.expand(num_tiles, 1, hidden_dim)
        tokens = torch.cat([tokens, cls_tokens], dim=1)
        seq_len = num_patches + 1

        pos_embed = self.positional_embedding_vlm.to(
            dtype=tokens.dtype, device=tokens.device
        )
        tokens = tokens.reshape(num_tiles, 1, seq_len, hidden_dim) + pos_embed
        tokens = self.layernorm_pre(tokens).view(num_tiles, -1, hidden_dim)

        tokens = self.layernorm_post(self.model(tokens))

        # The CLS token is the last position; drop it before pixel shuffle.
        tokens = tokens[:, :-1, :]

        return self.vision_adapter(tokens)


class Mllama4ProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and token-count helpers for Llama 4 images."""

    def get_hf_config(self) -> Llama4Config:
        return self.ctx.get_hf_config(Llama4Config)

    def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
        # Default to the fast image processor unless the caller overrides it.
        return self.ctx.get_hf_processor(
            Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Although vLLM can support more images from an infra capability
        # perspective, we do not recommend using >10 images in practice.
        return {"image": None}

    @staticmethod
    def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
        """Return the number of patch tokens one image chunk yields.

        A chunk of `image_size` pixels is cut into
        `(image_size // patch_size) ** 2` patches. Pixel shuffle then reduces
        that count by `1 / pixel_shuffle_ratio**2`.

        Raises:
            AssertionError: if `image_size` is not a multiple of `patch_size`.
        """
        image_size = vision_config.image_size
        patch_size = vision_config.patch_size

        # Both f-strings must sit inside the parentheses so they concatenate
        # into one message; previously the second half was a dead statement.
        assert image_size % patch_size == 0, (
            f"chunk size {image_size} should be multiple of "
            f"patch_size {patch_size}"
        )

        ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
        return (image_size // patch_size) ** 2 // ds_ratio

    def get_max_num_tiles(self) -> int:
        """Maximum number of tiles, read from the HF image processor."""
        image_processor = self.get_hf_processor().image_processor
        return image_processor.max_patches

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_config = self.get_hf_config().vision_config
        image_size = vision_config.image_size
        # A single column of `max_num_tiles` tiles (h:w = max_num_tiles:1,
        # e.g. 16:1 when max_patches is 16) results in the max possible
        # feature size.
        return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size)
583
584


585
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]):
    """Multimodal processor for Llama 4.

    Wraps the HF processor and adds the per-image tiling metadata
    (`aspect_ratios`, `patches_per_image`) that is used to split the flat
    chunk batch back into images and to expand image placeholders.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then attach per-image tiling metadata.

        When the HF output contains `pixel_values`, the tiling chosen for
        each image is recomputed with the HF helpers
        `find_supported_resolutions` / `get_best_fit`, and stored as:

        - `aspect_ratios`: (num_images, 2) tensor of (ratio_h, ratio_w).
        - `patches_per_image`: (num_images,) tensor; 1 for a single-tile
          image, otherwise the tile count plus one extra chunk (presumably
          the global view added by the HF image processor - confirm there).
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if processed_outputs.get("pixel_values") is None:
            # No image chunks were produced; nothing to annotate.
            return processed_outputs

        assert "images" in mm_data, (
            "images expected to be in mm_data when pixel_values is present"
        )

        image_processor = self.info.get_hf_processor(**mm_kwargs).image_processor
        tile_size = self.info.get_hf_config().vision_config.image_size

        mm_items = self.info.parse_mm_data(
            {"image": mm_data["images"]}, validate=False
        )
        parsed_images = mm_items.get_items("image", ImageProcessorItems)

        # The candidate canvases depend only on the tile size and tile budget,
        # so build the tensor once instead of once per image.
        possible_resolutions = torch.tensor(
            find_supported_resolutions(
                max_num_chunks=self.info.get_max_num_tiles(),
                patch_size=SizeDict(height=tile_size, width=tile_size),
            )
        )
        best_fit_sizes = [
            get_best_fit(
                # Assumes PIL-style `size == (width, height)`; pass (h, w).
                (image.size[1], image.size[0]),
                possible_resolutions,
                resize_to_max_canvas=image_processor.resize_to_max_canvas,
            )
            for image in parsed_images
        ]
        # TODO tile height/width do not necessarily need to match
        aspect_ratios = [
            (image_size[0] // tile_size, image_size[1] // tile_size)
            for image_size in best_fit_sizes
        ]
        patches_per_image = [
            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
        ]

        processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
        processed_outputs["patches_per_image"] = torch.tensor(patches_per_image)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `pixel_values` is flat across images; `patches_per_image` gives
        # the number of chunks belonging to each image.
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            patches_per_image=MultiModalFieldConfig.batched("image"),
            aspect_ratios=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        """Replace each image token with the HF processor's tile expansion."""
        config = self.info.get_hf_config()
        vision_config = config.vision_config

        num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        img_patch_token = hf_processor.img_patch_token

        def get_replacement(item_idx: int):
            out_item = out_mm_kwargs["image"][item_idx]
            aspect_ratio = out_item["aspect_ratios"].data

            # Delegates to a private HF helper to build the placeholder text.
            repl = hf_processor._prompt_split_image(
                aspect_ratio=aspect_ratio,
                num_patches_per_chunk=num_patches_per_chunk,
            )

            # Only the patch tokens receive image embeddings.
            return PromptUpdateDetails.select_text(repl, img_patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement,
            )
        ]


class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
    """Builds dummy prompts and images for Llama 4 memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One fake image token per requested image.
        fake_token = self.info.get_hf_processor().fake_image_token
        return fake_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        # Use the image size that produces the most features as worst case.
        largest = self.info.get_image_size_with_most_features()
        (target_width, target_height) = largest

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


@MULTIMODAL_REGISTRY.register_processor(
    Mllama4MultiModalProcessor,
    info=Mllama4ProcessingInfo,
    dummy_inputs=Mllama4DummyInputsBuilder,
)
724
class Llama4ForConditionalGeneration(
725
726
727
728
729
730
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    MixtureOfExperts,
    SupportsEagle3,
    SupportsLoRA,
731
):
732
733
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
734
        "gate_up_proj": ["gate_proj", "up_proj"],
735
736
    }

737
738
    supports_encoder_tp_data = True

739
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Any modality not starting with ``"image"`` is rejected.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<|image|>"

746
747
748
749
750
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, the projector and the Llama 4 text model.

        Args:
            vllm_config: Engine config. ``model_config.hf_config`` supplies
                ``vision_config`` and ``text_config``.
            prefix: Parameter-name prefix for the created submodules.
        """
        super().__init__()
        model_cfg = vllm_config.model_config
        hf_config = model_cfg.hf_config
        mm_config = model_cfg.multimodal_config

        self.vllm_config = vllm_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = mm_config
        # In "data" mode, images are routed through run_dp_sharded_vision_model
        # (see _process_image_input).
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"

        # Vision tower and projector are created with quant_config=None.
        with self._mark_tower_model(vllm_config, "image"):
            with set_current_vllm_config(vllm_config):
                self.vision_model = Llama4VisionModel(
                    config=hf_config.vision_config,
                    quant_config=None,
                    prefix=maybe_prefix(prefix, "vision_model"),
                )
            self.multi_modal_projector = Llama4MultiModalProjector(
                config=hf_config,
                quant_config=None,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = initialize_model(
                vllm_config=vllm_config.with_hf_config(
                    hf_config.text_config, ["LlamaForCausalLM"]
                ),
                prefix=maybe_prefix(prefix, "language_model"),
                model_class=Llama4ForCausalLM,
            )

        text_model = self.language_model
        self.make_empty_intermediate_tensors = (
            text_model.make_empty_intermediate_tensors
        )

        # Mirror the text model's MoE hyperparameters on this wrapper.
        self.num_expert_groups = 1
        self.num_logical_experts = text_model.num_logical_experts
        self.num_physical_experts = text_model.num_physical_experts
        self.num_local_physical_experts = text_model.num_local_physical_experts
        self.num_routed_experts = text_model.num_routed_experts
        self.num_shared_experts = text_model.num_shared_experts
        self.num_redundant_experts = text_model.num_redundant_experts
        self.moe_layers = text_model.moe_layers
        self.num_moe_layers = len(self.moe_layers)

796
797
798
799
800
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Pass the auxiliary hidden-state layer selection to the text model."""
        text_model = self.language_model
        assert hasattr(text_model, "set_aux_hidden_state_layers")
        text_model.set_aux_hidden_state_layers(layers)

801
    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the text model's default EAGLE-3 auxiliary layer indices."""
        text_model = self.language_model
        assert hasattr(text_model, "get_eagle3_default_aux_hidden_state_layers")
        return text_model.get_eagle3_default_aux_hidden_state_layers()
807

808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ):
        """Hand EPLB state to the text model and mirror its ``expert_weights``."""
        text_model = self.language_model
        text_model.set_eplb_state(
            expert_load_view,
            logical_to_physical_map,
            logical_replica_count,
        )
        self.expert_weights = text_model.expert_weights

    def update_physical_experts_metadata(
        self, num_physical_experts: int, num_local_physical_experts: int
    ):
        """Forward new physical-expert counts to the text model.

        ``__init__`` copies the text model's expert counts onto this wrapper,
        and ``get_expert_mapping`` reads ``self.num_redundant_experts``. After
        delegating, the mirrored counts are therefore re-read from the text
        model so the wrapper does not keep stale values.
        """
        text_model = self.language_model
        text_model.update_physical_experts_metadata(
            num_physical_experts, num_local_physical_experts
        )
        # Keep the wrapper's mirrored copies in sync with the text model.
        self.num_physical_experts = text_model.num_physical_experts
        self.num_local_physical_experts = text_model.num_local_physical_experts
        self.num_redundant_experts = text_model.num_redundant_experts

826
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Llama4ImagePatchInputs | None:
        """Pop image tensors out of ``kwargs`` and wrap them as patch inputs.

        Returns ``None`` when ``pixel_values`` is absent. Otherwise
        ``patches_per_image`` and ``aspect_ratios`` must be present too; a
        missing key raises ``KeyError``.
        """
        # pixel_values layout (per original note):
        # num_images, 1, num_chunks, channel, image_size, image_size
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        return Llama4ImagePatchInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            patches_per_image=kwargs.pop("patches_per_image"),
            aspect_ratios=kwargs.pop("aspect_ratios"),
        )

    def _process_image_input(
        self, image_input: Llama4ImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Encode and project image chunks, returning one tensor per image.

        Chunk-level features are split by ``patches_per_image`` and the first
        two dimensions of each image's group are flattened together.
        """
        assert self.vision_model and self.multi_modal_projector
        pixel_values = image_input["pixel_values"]
        chunk_counts = image_input["patches_per_image"].tolist()

        if self.use_data_parallel:
            # Data-parallel encoder path.
            features = run_dp_sharded_vision_model(pixel_values, self.vision_model)
        else:
            features = self.vision_model(pixel_values)
        features = self.multi_modal_projector(features)

        per_image = features.split(chunk_counts, dim=0)
        return [chunks.flatten(0, 1) for chunks in per_image]
865

866
    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
        """Return image embeddings, or ``[]`` when no image data was passed."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)
872
873
874

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped text model; extra keyword arguments are ignored.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        given (presumably a non-first pipeline stage).
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.language_model(
            input_ids, positions, intermediate_tensors, embeds
        )
887
888
889
890

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped text model."""
        text_model = self.language_model
        return text_model.compute_logits(hidden_states)
893
894
895

    def separate_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        prefix: str,
    ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
        """Lazily split ``weights`` by whether the name starts with ``prefix``.

        Returns ``(matching, rest)`` as two generators, each reading its own
        ``itertools.tee`` copy of ``weights``.
        """
        matching_src, rest_src = tee(weights, 2)
        matching = (
            (name, data) for name, data in matching_src if name.startswith(prefix)
        )
        rest = (
            (name, data) for name, data in rest_src if not name.startswith(prefix)
        )
        return matching, rest

913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
    def _consolidate_qkv_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Fuse separate q/k/v projection tensors into ``qkv_proj`` tensors.

        Weights that are not q/k/v projections are yielded unchanged as they
        are encountered. q/k/v pieces are buffered under their fused
        ``.self_attn.qkv_proj`` name. After ``weights`` is exhausted, each
        group is yielded as ``torch.cat([q, k, v], dim=0)``.

        Raises:
            ValueError: if a fused name is missing any of its q/k/v parts
                (previously this surfaced as an opaque ``torch.cat`` error).
        """
        qkv_idx_mappings = {
            ".self_attn.q_proj": 0,
            ".self_attn.k_proj": 1,
            ".self_attn.v_proj": 2,
        }
        shard_labels = ("q_proj", "k_proj", "v_proj")
        qkv_weights: dict[str, list[torch.Tensor | None]] = {}
        for name, loaded_weight in weights:
            for weight_name, idx in qkv_idx_mappings.items():
                if weight_name not in name:
                    continue
                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
                if new_name not in qkv_weights:
                    qkv_weights[new_name] = [None] * 3
                qkv_weights[new_name][idx] = loaded_weight
                break
            else:
                yield name, loaded_weight

        for key, shards in qkv_weights.items():
            missing = [label for label, s in zip(shard_labels, shards) if s is None]
            if missing:
                raise ValueError(
                    f"Cannot build {key!r}: missing {', '.join(missing)} weight(s)"
                )
            yield key, torch.cat(shards, dim=0)

937
938
939
    def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
        """Translate a checkpoint weight name (incl. ModelOpt llama4 fp8) to
        the vLLM parameter name.

        - ``model.*`` becomes ``language_model.model.*``. Names already under
          ``language_model.model.`` keep their prefix.
        - Flat expert scale names ``down_proj_{input,weight}_scale`` and
          ``gate_up_proj_{input,weight}_scale`` become ``w2_*`` / ``w13_*``.
        - ``.k_proj.k_scale`` / ``.v_proj.v_scale`` move to ``.attn.*``.
        - ``lm_head.weight*`` moves under ``language_model.``.

        Any other name is returned unchanged.
        """
        expert_scale_renames = (
            ("down_proj_input_scale", "w2_input_scale"),
            ("down_proj_weight_scale", "w2_weight_scale"),
            ("gate_up_proj_input_scale", "w13_input_scale"),
            ("gate_up_proj_weight_scale", "w13_weight_scale"),
        )
        kv_scale_renames = (
            (".k_proj.k_scale", ".attn.k_scale"),
            (".v_proj.v_scale", ".attn.v_scale"),
        )

        if name.startswith("model."):
            renamed = name.replace("model.", "language_model.model.", 1)
        elif name.startswith("language_model.model."):
            renamed = name
        elif name.startswith("lm_head.weight"):
            return name.replace("lm_head.weight", "language_model.lm_head.weight")
        else:
            return name

        if "feed_forward.experts." in name and (
            "_input_scale" in name or "_weight_scale" in name
        ):
            rename_table = expert_scale_renames
        elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name):
            rename_table = kv_scale_renames
        else:
            return renamed

        # First matching entry wins, mirroring an if/elif chain.
        for old, new in rename_table:
            if old in renamed:
                return renamed.replace(old, new)
        return renamed

    def _separate_and_rename_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
        """Rename every weight and bucket it by destination submodule.

        A weight is dropped when the top-level attribute it targets is a
        ``StageMissingLayer``. The rest are returned as
        ``(language_model_weights, other_weights)``, keyed on a
        ``language_model.`` prefix of the renamed name.
        """
        lm_bucket: list[tuple[str, torch.Tensor]] = []
        other_bucket: list[tuple[str, torch.Tensor]] = []
        for original_name, tensor in weights:
            new_name = self._rename_weight_for_modelopt_checkpoint(original_name)
            top_level = new_name.split(".", 1)[0]
            if isinstance(getattr(self, top_level), StageMissingLayer):
                continue
            target = (
                lm_bucket if new_name.startswith("language_model.") else other_bucket
            )
            target.append((new_name, tensor))
        return lm_bucket, other_bucket

    def _handle_expert_scale_broadcasting(
        self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
    ) -> tuple[
        list[tuple[str, torch.Tensor]],
        list[tuple[str, torch.Tensor]],
        set[str],
    ]:
        """Broadcast scalar MoE expert scales into per-expert parameters.

        ModelOpt checkpoints use a single value tensor scalar for BMM style
        experts, while vLLM expects the scale to be broadcast across all
        experts.

        A weight counts as a routed expert scale when its name contains
        ``feed_forward.experts.`` and ``scale`` but not ``.shared_expert``.
        If such a weight has one element and its parameter has more than one,
        the parameter is filled in place with that value.

        Returns:
            A 3-tuple ``(regular_weights, expert_scale_weights,
            updated_params)``:
            - weights that are not routed expert scales;
            - routed expert scales that were not broadcast and still need
              loading;
            - names of parameters filled in place.
        """
        regular_weights: list[tuple[str, torch.Tensor]] = []
        expert_scale_weights: list[tuple[str, torch.Tensor]] = []
        updated_params: set[str] = set()

        for name, weight in weights:
            is_routed_expert_scale = (
                "feed_forward.experts." in name
                and "scale" in name
                and ".shared_expert" not in name
            )
            if not is_routed_expert_scale:
                regular_weights.append((name, weight))
                continue

            param = params_dict.get(name)
            if (
                param is not None
                and hasattr(param, "data")
                and param.data.numel() > 1
                and weight.numel() == 1
            ):
                # Broadcast the single checkpoint value to every expert.
                param.data.fill_(weight.item())
                updated_params.add(name)
                continue

            expert_scale_weights.append((name, weight))

        return regular_weights, expert_scale_weights, updated_params

1040
1041
1042
1043
1044
1045
    def _load_other_weights(
        self,
        other_weights: Iterable[tuple[str, torch.Tensor]],
        params_dict: dict,
        stacked_params_mapping: list,
    ) -> set[str]:
        """Load the non-language-model weights and return the names written.

        When ``use_data_parallel`` is False, a checkpoint name containing a
        ``shard_name`` from ``stacked_params_mapping`` is loaded as that shard
        of the stacked parameter.

        When it is True, q/k/v tensors are first fused by
        ``_consolidate_qkv_weights`` and no stacked mapping is applied.

        Every other weight goes through the parameter's ``weight_loader``,
        falling back to ``default_weight_loader``.
        """
        written: set[str] = set()
        if self.use_data_parallel:
            other_weights = self._consolidate_qkv_weights(other_weights)
            active_mapping = []
        else:
            active_mapping = stacked_params_mapping

        for name, tensor in other_weights:
            hit = next(
                (
                    (weight_name, param_name, shard_id)
                    for param_name, weight_name, shard_id in active_mapping
                    if weight_name in name
                ),
                None,
            )
            if hit is None:
                param = params_dict[name]
                load = getattr(param, "weight_loader", default_weight_loader)
                load(param, tensor)
                written.add(name)
                continue

            # Shard of a stacked parameter (e.g. q_proj -> qkv_proj).
            weight_name, param_name, shard_id = hit
            stacked_name = name.replace(weight_name, param_name)
            param = params_dict[stacked_name]
            written.add(stacked_name)
            param.weight_loader(param, tensor, shard_id)
        return written

1072
1073
1074
1075
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the FusedMoE checkpoint-to-parameter expert mapping.

        Entries are ``(param_name, weight_name, expert_id, shard_id)`` and
        cover weights, fp8 weight scales and fp8 activation scales.
        """
        text_config = self.config.text_config
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=text_config.num_local_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

1084
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` and return the names of parameters set.

        Steps:
        1. Rename weights and split them into ``language_model.*`` weights and
           everything else.
        2. Fill per-expert scale parameters in place from scalar checkpoint
           scales.
        3. Load the remaining language-model weights, then any leftover
           expert scales, through one ``AutoWeightsLoader``.
        4. Load the non-language-model weights with q/k/v and gate/up
           stacking support.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            # Shared expert gate_up_proj stacking
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            # Feed forward gate_up_proj stacking (for non-MoE layers if any)
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        lm_weights, non_lm_weights = self._separate_and_rename_weights(weights)

        regular_weights, leftover_scales, broadcast_names = (
            self._handle_expert_scale_broadcasting(lm_weights, params_dict)
        )
        loaded: set[str] = set(broadcast_names)

        loader = AutoWeightsLoader(self)
        lm_loaded = loader.load_weights(regular_weights)
        assert lm_loaded is not None
        loaded.update(lm_loaded)

        if leftover_scales:
            scale_loaded = loader.load_weights(leftover_scales)
            if scale_loaded:
                loaded.update(scale_loaded)

        other_loaded = self._load_other_weights(
            non_lm_weights, params_dict, stacked_params_mapping
        )
        loaded.update(other_loaded)
        return loaded
1126
1127
1128
1129
1130
1131
1132

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the language model, the connector
        (projector + vision adapter) and the vision tower."""
        connector_prefixes = [
            "multi_modal_projector.",
            "vision_model.vision_adapter.",
        ]
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=connector_prefixes,
            tower_model="vision_model.",
        )
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Convert a count of image prompt tokens into vision-encoder tokens.

        The number of chunks is ``num_image_tokens // patches_per_chunk``.
        Each chunk costs every raw patch plus one CLS token in the encoder.
        Non-positive inputs (or chunk sizes) give 0.
        """
        vision_cfg = self.config.vision_config
        tokens_per_chunk = Mllama4ProcessingInfo.get_patch_per_chunk(vision_cfg)
        if num_image_tokens <= 0 or tokens_per_chunk <= 0:
            return 0
        patches_per_side = vision_cfg.image_size // vision_cfg.patch_size
        encoder_tokens_per_chunk = patches_per_side**2 + 1  # +1 for CLS
        return (num_image_tokens // tokens_per_chunk) * encoder_tokens_per_chunk

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Convert a count of vision-encoder tokens into connector tokens.

        Each chunk holds every raw patch plus one CLS token on the encoder
        side and ``patches_per_chunk`` tokens on the connector side.
        Non-positive inputs give 0.
        """
        vision_cfg = self.config.vision_config
        patches_per_side = vision_cfg.image_size // vision_cfg.patch_size
        if num_vision_tokens <= 0:
            return 0
        num_chunks = num_vision_tokens // (patches_per_side**2 + 1)
        return num_chunks * Mllama4ProcessingInfo.get_patch_per_chunk(vision_cfg)