"vllm/vscode:/vscode.git/clone" did not exist on "e50c45467215f96068d95736b08d8a25f624e67d"
mllama4.py 41.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from itertools import tee
22
from typing import Annotated, Literal
23
24
25
26
27
28
29

import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
30
31
32
    find_supported_resolutions,
    get_best_fit,
)
33

34
35
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, set_current_vllm_config
36
from vllm.config.multimodal import BaseDummyOptions
37
from vllm.distributed import get_tensor_model_parallel_world_size
38
from vllm.forward_context import set_forward_context
39
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
40
from vllm.model_executor.layers.fused_moe import FusedMoE
41
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
47
48
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
49
from vllm.model_executor.model_loader.utils import initialize_model
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.model_executor.models.module_mapping import MultiModelKeys
52
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
53
from vllm.multimodal import MULTIMODAL_REGISTRY
54
55
56
57
58
59
60
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
61
    BaseDummyInputsBuilder,
62
63
64
65
66
67
68
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
69
from vllm.sequence import IntermediateTensors
70
from vllm.utils.tensor_schema import TensorSchema, TensorShape
71

72
from .interfaces import (
73
    LMMissingLayer,
74
    MixtureOfExperts,
75
76
    MultiModalEmbeddings,
    SupportsEagle3,
77
    SupportsLoRA,
78
79
    SupportsMultiModal,
    SupportsPP,
80
    TowerMissingLayer,
81
)
82
from .llama4 import Llama4ForCausalLM
83
84
85
86
from .utils import (
    AutoWeightsLoader,
    maybe_prefix,
)
87
from .vision import run_dp_sharded_vision_model
88
89


90
class Llama4ImagePatchInputs(TensorSchema):
    """
    Pixel inputs for Llama4 images, with the chunks of every image in the
    batch concatenated along one leading dimension.

    Dimensions:
        - batch_size: Batch size
        - total_num_chunks: Batch size * number of chunks
        - num_channels: Number of channels
        - image_size: Size of each image
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Pixels of all chunks of all images, flattened into the first dimension.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
    ]

    # Total number of patches (chunks) contributed by each image in the batch.
    # Used to split the embeddings, whose first two dimensions are flattened
    # the same way as `pixel_values`.
    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]

    # Per-image tile grid as a (ratio_h, ratio_w) pair: the number of tiles
    # along each spatial dimension for that image.
    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]


class Llama4VisionMLP(nn.Module):
    """Two-layer GELU MLP: ``fc1 -> GELU -> fc2``, optionally followed by a
    second GELU on the output when ``output_activation`` is True.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``;
    both receive ``disable_tp=use_data_parallel``.
    """

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # Keyword arguments common to both projections.
        common_kwargs = dict(
            bias=bias,
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        self.fc1 = ColumnParallelLinear(
            input_size=input_size,
            output_size=intermediate_size,
            prefix=f"{prefix}.fc1",
            **common_kwargs,
        )
        self.fc2 = RowParallelLinear(
            input_size=intermediate_size,
            output_size=output_size,
            prefix=f"{prefix}.fc2",
            **common_kwargs,
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return (output, bias); only the output
        # is used.
        projected, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(projected))
        return self.activation_fn(projected) if self.output_activation else projected


class Llama4MultiModalProjector(nn.Module):
    """Projects vision features into the text model's hidden size.

    A single bias-free ``ColumnParallelLinear`` mapping
    ``config.vision_config.vision_output_dim`` to
    ``config.text_config.hidden_size``, with ``gather_output=True``.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        vision_dim = config.vision_config.vision_output_dim
        text_dim = config.text_config.hidden_size
        self.linear_1 = ColumnParallelLinear(
            input_size=vision_dim,
            output_size=text_dim,
            bias=False,
            quant_config=quant_config,
            gather_output=True,
            prefix=f"{prefix}.linear_1",
        )

    def forward(self, image_features):
        # Drop the (output, bias) tuple's second element.
        projected, _ = self.linear_1(image_features)
        return projected


def pixel_shuffle(input_tensor, shuffle_ratio):
    """Spatially rearrange a square grid of patch tokens.

    The token axis of ``input_tensor`` is interpreted as a square
    ``side x side`` grid (``side = sqrt(num_patches)``). The grid is rescaled
    by ``shuffle_ratio`` along both spatial axes, and the channel dimension is
    scaled by ``1 / shuffle_ratio**2``, so the element count per sample is
    unchanged. With ``shuffle_ratio=0.5``, each output token concatenates the
    channels of a 2x2 neighbourhood of input tokens.

    Args:
        input_tensor: Tensor of shape ``[batch_size, num_patches, channels]``.
            ``num_patches`` must be a perfect square.
        shuffle_ratio: Spatial scale factor, e.g. ``0.5``.

    Returns:
        Tensor of shape
        ``[batch_size, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2]``.

    Raises:
        ValueError: If ``num_patches`` is not a perfect square. Without this
            check the first ``view`` below could still succeed, e.g. 8 tokens
            x 2 channels viewed as a 2x2 grid of 4 channels. That would
            silently fold spatial positions into the channel dimension.
    """
    batch_size, num_patches, channels = input_tensor.shape
    patch_size = int(math.sqrt(num_patches))
    if patch_size * patch_size != num_patches:
        raise ValueError(
            f"pixel_shuffle expects a square patch grid, got num_patches={num_patches}"
        )

    # [B, N, C] -> [B, H, W, C] with H == W == patch_size.
    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()

    # Fold the width axis into channels: [B, H, W*r, C/r].
    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    # Swap spatial axes so the height axis can be folded the same way.
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    # Fold the height axis into channels: [B, H*r, W*r, C/r^2].
    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    # Restore the original spatial axis order.
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    # Flatten the grid back into a token axis: [B, N*r^2, C/r^2].
    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    return output_tensor


class Llama4VisionPixelShuffleMLP(nn.Module):
    """Vision adapter: ``pixel_shuffle`` by ``config.pixel_shuffle_ratio``,
    then a ``Llama4VisionMLP`` with ``output_activation=True``, so GELU is
    also applied to the MLP output."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        # Stored as attributes; not read by this class's forward().
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)


class Llama4VisionAttention(nn.Module):
    """Multi-head self-attention for the Llama4 vision encoder.

    Q/K/V come from one fused projection. Rotary embeddings (rope type
    ``mllama4``, ``partial_rotary_factor=0.5``) are applied to queries and
    keys before ``MMEncoderAttention``. K/V use as many heads as Q.

    With ``use_data_parallel=True`` the projections are ``ReplicatedLinear``
    and all heads are local (``tp_size == 1``). Otherwise heads are divided
    across the tensor-parallel world size via ``QKVParallelLinear`` /
    ``RowParallelLinear``.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        if use_data_parallel:
            self.tp_size = 1
        else:
            self.tp_size = get_tensor_model_parallel_world_size()

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        assert self.num_heads % self.tp_size == 0
        self.num_local_heads = self.num_heads // self.tp_size
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim**-0.5

        self.attn = MMEncoderAttention(
            self.num_local_heads, self.head_dim, self.scaling
        )

        fused_qkv_width = self.q_size + 2 * self.kv_size
        all_heads_width = self.num_heads * self.head_dim
        if not use_data_parallel:
            self.qkv_proj = QKVParallelLinear(
                self.embed_dim,
                self.head_dim,
                self.num_heads,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
            self.o_proj = RowParallelLinear(
                all_heads_width,
                self.embed_dim,
                bias=True,
                input_is_parallel=True,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
            )
        else:
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                fused_qkv_width,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
            self.o_proj = ReplicatedLinear(
                all_heads_width,
                self.embed_dim,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
            )

        rope_cfg = {
            "rope_type": "mllama4",
            "rope_theta": config.rope_parameters["rope_theta"],
            "partial_rotary_factor": 0.5,
        }
        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            # One position per image patch.
            max_position=(config.image_size // config.patch_size) ** 2,
            rope_parameters=rope_cfg,
            is_neox_style=False,
            # important: get_rope is given a complex dtype here; keep as-is.
            dtype=torch.complex64,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        leading_shape = hidden_states.shape[:-1]

        fused, _ = self.qkv_proj(hidden_states)
        query, key, value = fused.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        # Split the last dim into (heads, head_dim) for the rotary embedding.
        per_head = (self.num_local_heads, self.head_dim)
        query = query.view(query.shape[0], query.shape[1], *per_head)
        key = key.view(key.shape[0], key.shape[1], *per_head)
        query, key = self.rotary_emb(query, key)

        # Merge heads back before the attention kernel.
        query = query.view(query.shape[0], query.shape[1], -1)
        key = key.view(key.shape[0], key.shape[1], -1)

        context = self.attn(query, key, value)
        context = context.reshape(*leading_shape, -1).contiguous()
        context, _ = self.o_proj(context)
        return context


class Llama4VisionEncoderLayer(nn.Module):
    """Pre-norm transformer block.

    The block applies ``LayerNorm -> self-attention -> residual add``, then
    ``LayerNorm -> MLP -> residual add``. ``forward`` returns a 1-tuple
    holding the new hidden state.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = Llama4VisionAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Llama4VisionMLP(
            input_size=width,
            intermediate_size=config.intermediate_size,
            output_size=width,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(width)
        self.post_attention_layernorm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_state: torch.Tensor,
    ):
        # Attention sub-block with residual connection.
        hidden_state = hidden_state + self.self_attn(
            self.input_layernorm(hidden_state)
        )
        # Feed-forward sub-block with residual connection.
        hidden_state = hidden_state + self.mlp(
            self.post_attention_layernorm(hidden_state)
        )
        return (hidden_state,)


class Llama4VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``Llama4VisionEncoderLayer``
    modules, applied one after another."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        r"""Run every encoder layer in order.

        Args:
            hidden_states: Input tensor of shape
                ``(batch_size, sequence_length, hidden_size)``, the embedded
                patch sequence to encode.

        Returns:
            The hidden states produced by the last layer.
        """
        for encoder_layer in self.layers:
            # Each layer returns a tuple whose first element is the new state.
            layer_outputs = encoder_layer(hidden_states)
            hidden_states = layer_outputs[0]

        return hidden_states


class Llama4UnfoldConvolution(nn.Module):
    """Patch embedding implemented as ``nn.Unfold`` followed by a bias-free
    linear projection.

    Patches of ``config.patch_size`` are extracted with a stride equal to the
    patch size. Each patch is flattened to ``num_channels * kh * kw``
    features and projected to ``config.hidden_size``.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        patch = config.patch_size
        kernel_size = (patch, patch) if isinstance(patch, int) else patch
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        patch_features = config.num_channels * kernel_size[0] * kernel_size[1]
        self.linear = ColumnParallelLinear(
            input_size=patch_features,
            output_size=config.hidden_size,
            bias=False,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # nn.Unfold yields (N, C*kh*kw, L); move the patch axis L to dim 1 so
        # the linear layer projects each patch's features.
        patches = self.unfold(hidden_states).permute(0, 2, 1)
        embedded, _ = self.linear(patches)
        return embedded


464
465
466
@support_torch_compile(
    dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit
)
class Llama4VisionModel(nn.Module):
    """Llama4 vision tower.

    ``forward`` runs these stages in order:

    1. Unfold-convolution patch embedding.
    2. Append a CLS token at the end of each tile.
    3. Add the learned positional embedding.
    4. Pre-LayerNorm.
    5. Transformer encoder.
    6. Post-LayerNorm.
    7. Drop the CLS token.
    8. Pixel-shuffle adapter MLP.

    Compilation is handled by ``support_torch_compile``: dim 0 of
    ``images_flattened`` is marked dynamic, and compilation is gated by
    ``should_torch_compile_mm_vit``.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        # Patches per tile, plus one position for the CLS token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
            use_data_parallel=use_data_parallel,
        )

        # Random init here; presumably overwritten by checkpoint weights.
        # Creation order is kept (class first) so RNG consumption matches.
        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
            use_data_parallel=use_data_parallel,
        )

        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        images_flattened: torch.Tensor,
    ) -> torch.Tensor:
        embeddings = self.patch_embedding(images_flattened)
        num_tiles, num_patches, hidden_dim = embeddings.shape

        # The CLS token is appended at the END of each tile's sequence.
        cls_tokens = self.class_embedding.expand(num_tiles, 1, hidden_dim)
        embeddings = torch.cat([embeddings, cls_tokens], dim=1)
        num_patches += 1

        # Positional embedding broadcasts over tiles via the singleton dim.
        embeddings = embeddings.reshape(num_tiles, 1, num_patches, hidden_dim)
        pos_embed = self.positional_embedding_vlm.to(
            dtype=embeddings.dtype, device=embeddings.device
        )
        embeddings = self.layernorm_pre(embeddings + pos_embed)
        embeddings = embeddings.view(num_tiles, -1, hidden_dim)

        encoded = self.layernorm_post(self.model(embeddings))

        # Drop the trailing CLS token before the adapter.
        encoded = encoded[:, :-1, :]

        return self.vision_adapter(encoded)


class Mllama4ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Llama4 multimodal inputs.

    Provides HF config/processor access plus tile and per-chunk token-count
    helpers.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__(ctx)

    def get_hf_config(self) -> Llama4Config:
        """Return the model's HF config as a ``Llama4Config``."""
        return self.ctx.get_hf_config(Llama4Config)

    def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
        """Return the HF ``Llama4Processor``.

        ``use_fast`` defaults to True unless the caller passes it explicitly.
        """
        return self.ctx.get_hf_processor(
            Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return per-modality item limits; images have no hard limit."""
        # Although vLLM can support more images from an infra capability
        # perspective, we do not recommend using >10 images in practice.
        return {"image": None}

    @staticmethod
    def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
        """Return the number of tokens produced per image chunk.

        The tile's patch grid has ``(image_size // patch_size) ** 2`` patches.
        That count is divided by ``round(1 / pixel_shuffle_ratio**2)``, the
        pixel-shuffle downsampling factor.

        Raises:
            AssertionError: If ``image_size`` is not a multiple of
                ``patch_size``.
        """
        image_size = vision_config.image_size
        patch_size = vision_config.patch_size

        # Both f-strings must sit inside the parentheses. Otherwise the second
        # is a dead statement and the message is truncated.
        assert image_size % patch_size == 0, (
            f"chunk size {image_size} should be multiple of "
            f"patch_size {patch_size}"
        )

        ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
        return (image_size // patch_size) ** 2 // ds_ratio

    def get_max_num_tiles(self) -> int:
        """Return the HF image processor's ``max_patches`` (maximum tiles)."""
        image_processor = self.get_hf_processor().image_processor
        return image_processor.max_patches

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size that yields the most image features."""
        vision_config = self.get_hf_config().vision_config
        image_size = vision_config.image_size
        # A single column of max_num_tiles tiles gives the max possible
        # feature size (h:w = max_num_tiles:1, e.g. 16:1 when max_patches=16).
        return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size)
597
598


599
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]):
    """Multimodal processor for Llama4.

    Runs the HF processor and attaches per-image tiling metadata
    (``aspect_ratios``, ``patches_per_image``). Each ``<|image|>`` token is
    expanded into the processor's split-image prompt.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach tiling metadata for images.

        When ``pixel_values`` is produced, two tensors are added:

        - ``aspect_ratios``: per image, the (tiles_h, tiles_w) grid of the
          best-fit canvas.
        - ``patches_per_image``: 1 for a single-tile image, otherwise the tile
          count plus one extra chunk (presumably a global view; confirm
          against the HF processor).
        """
        tokenizer = self.info.get_tokenizer()

        if mm_data is None:
            # Text only: tokenize without special tokens, so no BOS.
            return tokenizer(prompt, add_special_tokens=False)

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_processor = self.info.get_hf_processor(**mm_kwargs).image_processor
        vision_config = self.info.get_hf_config().vision_config

        if processed_outputs.get("pixel_values") is None:
            return processed_outputs

        assert "images" in mm_data, (
            "images expected to be in mm_data when pixel_values is present"
        )
        parsed_images = (
            self._get_data_parser()
            .parse_mm_data({"image": mm_data["images"]})
            .get_items("image", ImageProcessorItems)
        )

        tile_size = vision_config.image_size
        possible_resolutions = find_supported_resolutions(
            max_num_chunks=self.info.get_max_num_tiles(),
            patch_size=SizeDict(height=tile_size, width=tile_size),
        )

        # TODO tile height/width do not necessarily need to match
        aspect_ratios = []
        patches_per_image = []
        for image in parsed_images:
            # get_best_fit takes (height, width); image.size is (width, height).
            canvas = get_best_fit(
                (image.size[1], image.size[0]),
                torch.tensor(possible_resolutions),
                resize_to_max_canvas=image_processor.resize_to_max_canvas,
            )
            grid = (canvas[0] // tile_size, canvas[1] // tile_size)
            aspect_ratios.append(grid)
            num_tiles = grid[0] * grid[1]
            patches_per_image.append(1 if num_tiles == 1 else 1 + num_tiles)

        processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
        processed_outputs["patches_per_image"] = torch.tensor(patches_per_image)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # pixel_values is flattened across images; split it back per image
        # using the chunk counts computed in _call_hf_processor.
        chunk_counts = hf_inputs.get("patches_per_image", torch.empty(0))
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", chunk_counts
            ),
            "patches_per_image": MultiModalFieldConfig.batched("image"),
            "aspect_ratios": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        vision_config = self.info.get_hf_config().vision_config
        num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        img_patch_token = hf_processor.img_patch_token

        def get_replacement(item_idx: int):
            # Build this image's expanded prompt from its tile grid; only the
            # patch tokens are marked as embedding positions.
            item = out_mm_kwargs["image"][item_idx]
            expanded = hf_processor._prompt_split_image(
                aspect_ratio=item["aspect_ratios"].data,
                num_patches_per_chunk=num_patches_per_chunk,
            )
            return PromptUpdateDetails.select_text(expanded, img_patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement,
            )
        ]


class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
    """Builds dummy text and image inputs for Llama4.

    Dummy images use the size returned by
    ``get_image_size_with_most_features``.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's fake image token once per requested image."""
        fake_token = self.info.get_hf_processor().fake_image_token
        return fake_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``{"image": [...]}`` with the requested number of images."""
        num_images = mm_counts.get("image", 0)
        target_width, target_height = self.info.get_image_size_with_most_features()
        image_overrides = mm_options.get("image") if mm_options else None

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=num_images,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


@MULTIMODAL_REGISTRY.register_processor(
    Mllama4MultiModalProcessor,
    info=Mllama4ProcessingInfo,
    dummy_inputs=Mllama4DummyInputsBuilder,
)
745
class Llama4ForConditionalGeneration(
746
747
748
749
750
751
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    MixtureOfExperts,
    SupportsEagle3,
    SupportsLoRA,
752
):
753
754
    # Fused vLLM projection name -> checkpoint projection names it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Class-level capability flag; presumably tells the engine this model can
    # run its multimodal encoder in TP "data" mode (see ``use_data_parallel``
    # in ``__init__``) -- TODO confirm against the interface definition.
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the ``i``-th item of ``modality``.

        Every image uses the same ``<|image|>`` placeholder, so ``i`` is not
        read.

        Raises:
            ValueError: if ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<|image|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, the projector and the Llama 4 text model.

        Args:
            vllm_config: Engine configuration. ``model_config.hf_config``
                supplies ``vision_config`` and ``text_config``.
            prefix: Parameter-name prefix for this module's submodules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        mm_config = model_config.multimodal_config

        # "data" mode makes _process_image_input shard images across ranks.
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"
        self.vllm_config = vllm_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = mm_config

        with self._mark_tower_model(vllm_config, "image"):
            from vllm.compilation.backends import set_model_tag

            # The vision tower is tagged separately for compilation and is
            # built without quantization (quant_config=None).
            with (
                set_current_vllm_config(vllm_config),
                set_model_tag("Llama4VisionModel", is_encoder=True),
            ):
                self.vision_model = Llama4VisionModel(
                    config=hf_config.vision_config,
                    quant_config=None,
                    prefix=maybe_prefix(prefix, "vision_model"),
                    use_data_parallel=self.use_data_parallel,
                )

            self.multi_modal_projector = Llama4MultiModalProjector(
                config=self.config,
                quant_config=None,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            text_vllm_config = vllm_config.with_hf_config(
                hf_config.text_config, ["LlamaForCausalLM"]
            )
            self.language_model = initialize_model(
                vllm_config=text_vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                model_class=Llama4ForCausalLM,
            )

        text_model = self.language_model
        self.make_empty_intermediate_tensors = (
            text_model.make_empty_intermediate_tensors
        )

        # Mirror the MoE hyperparameters of the wrapped text model.
        self.num_expert_groups = 1
        self.num_logical_experts = text_model.num_logical_experts
        self.num_physical_experts = text_model.num_physical_experts
        self.num_local_physical_experts = text_model.num_local_physical_experts
        self.num_routed_experts = text_model.num_routed_experts
        self.num_shared_experts = text_model.num_shared_experts
        self.num_redundant_experts = text_model.num_redundant_experts
        self.moe_layers = text_model.moe_layers
        self.num_moe_layers = len(self.moe_layers)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layers that emit EAGLE3 auxiliary hidden states.

        The choice is forwarded unchanged to the wrapped text model.
        """
        text_model = self.language_model
        assert hasattr(text_model, "set_aux_hidden_state_layers")
        text_model.set_aux_hidden_state_layers(layers)

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the EAGLE3 auxiliary hidden-state layers of the text model.

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        text_model = self.language_model
        assert hasattr(text_model, "get_eagle3_aux_hidden_state_layers")
        return text_model.get_eagle3_aux_hidden_state_layers()

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ):
        """Hand the expert-load-balancing tensors to the text model.

        Afterwards, the text model's ``expert_weights`` is re-exposed on this
        wrapper.
        """
        text_model = self.language_model
        text_model.set_eplb_state(
            expert_load_view, logical_to_physical_map, logical_replica_count
        )
        self.expert_weights = text_model.expert_weights

    def update_physical_experts_metadata(
        self, num_physical_experts: int, num_local_physical_experts: int
    ):
        """Propagate a new physical-expert layout to the text model.

        ``__init__`` copies the expert counters from ``self.language_model``
        onto this wrapper. After delegating, re-copy the ones tied to the
        physical layout so the wrapper does not keep reporting stale values.
        This mirrors how ``set_eplb_state`` re-exposes ``expert_weights``. If
        the text model leaves these counters unchanged, the re-copy is a no-op.

        Args:
            num_physical_experts: Total number of physical experts.
            num_local_physical_experts: Number of physical experts on this rank.
        """
        text_model = self.language_model
        text_model.update_physical_experts_metadata(
            num_physical_experts, num_local_physical_experts
        )
        self.num_physical_experts = text_model.num_physical_experts
        self.num_local_physical_experts = text_model.num_local_physical_experts
        self.num_redundant_experts = text_model.num_redundant_experts

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Llama4ImagePatchInputs | None:
        """Pop the image tensors out of ``kwargs`` into a patch-input record.

        Returns ``None`` when ``pixel_values`` is absent or ``None``. Otherwise
        ``patches_per_image`` and ``aspect_ratios`` must also be present, or a
        ``KeyError`` is raised.
        """
        # num_images, 1, num_chunks, channel, image_size, image_size
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        # Keyword arguments are evaluated left to right, so the pops happen
        # in the same order as before: patches_per_image, then aspect_ratios.
        return Llama4ImagePatchInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            patches_per_image=kwargs.pop("patches_per_image"),
            aspect_ratios=kwargs.pop("aspect_ratios"),
        )

    def _process_image_input(
        self, image_input: Llama4ImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Encode and project image chunks, then regroup them per image.

        The projected chunk embeddings are split along dim 0 into groups of
        ``patches_per_image[i]`` chunks. Each group is then flattened over its
        first two dims, giving one embedding tensor per image.
        """
        assert self.vision_model and self.multi_modal_projector
        pixel_values = image_input["pixel_values"]
        chunk_counts = image_input["patches_per_image"].tolist()

        if self.use_data_parallel:
            # Shard the image input across ranks.
            chunk_features = run_dp_sharded_vision_model(
                pixel_values, self.vision_model
            )
        else:
            chunk_features = self.vision_model(pixel_values)

        projected = self.multi_modal_projector(chunk_features)
        per_image_chunks = projected.split(chunk_counts, dim=0)
        return [chunks.flatten(0, 1) for chunks in per_image_chunks]

    def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
        """Compute per-image embeddings from the multimodal ``kwargs``.

        Returns an empty list when no ``pixel_values`` are supplied. Otherwise
        the encoder runs inside ``set_forward_context``, with ``None`` as its
        first argument (presumably no attention metadata -- TODO confirm).
        """
        parsed = self._parse_and_validate_image_input(**kwargs)
        if parsed is None:
            return []

        with set_forward_context(None, self.vllm_config):
            embeddings = self._process_image_input(parsed)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text model; extra ``kwargs`` are accepted and ignored.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is dropped
        and only the intermediate tensors are forwarded.
        """
        embeds = inputs_embeds if intermediate_tensors is None else None
        return self.language_model(
            input_ids, positions, intermediate_tensors, embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits using the wrapped text model."""
        text_model = self.language_model
        return text_model.compute_logits(hidden_states)

    def separate_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
        prefix: str,
    ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
        """Lazily split ``weights`` by whether each name starts with ``prefix``.

        ``weights`` is duplicated with ``itertools.tee``, so each returned
        iterator can be consumed independently. Returns
        ``(matching, non_matching)`` generators of ``(name, tensor)`` pairs.
        """
        matching_src, other_src = tee(weights, 2)
        matching = (
            (name, data) for name, data in matching_src if name.startswith(prefix)
        )
        non_matching = (
            (name, data) for name, data in other_src if not name.startswith(prefix)
        )
        return matching, non_matching

    def _consolidate_qkv_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Fuse separate q/k/v projection tensors into ``qkv_proj`` tensors.

        ``_load_other_weights`` uses this in data-parallel encoder mode, where
        it bypasses the stacked-parameter mapping.

        Tensors that are not q/k/v projections are yielded unchanged as they
        stream past. Fused tensors are yielded only after ``weights`` is
        exhausted, concatenated along dim 0 in q, k, v order.

        Raises:
            ValueError: if a fused projection is missing any q/k/v part.
                Previously ``torch.cat`` received ``None`` and failed with an
                unhelpful TypeError.
        """
        qkv_idx_mappings = {
            ".self_attn.q_proj": 0,
            ".self_attn.k_proj": 1,
            ".self_attn.v_proj": 2,
        }
        shard_labels = ("q_proj", "k_proj", "v_proj")
        qkv_weights: dict[str, list[torch.Tensor | None]] = {}
        for name, loaded_weight in weights:
            for weight_name, idx in qkv_idx_mappings.items():
                if weight_name not in name:
                    continue
                fused_name = name.replace(weight_name, ".self_attn.qkv_proj")
                qkv_weights.setdefault(fused_name, [None] * 3)[idx] = loaded_weight
                break
            else:
                yield name, loaded_weight

        for fused_name, shards in qkv_weights.items():
            missing = [
                shard_labels[i] for i, shard in enumerate(shards) if shard is None
            ]
            if missing:
                raise ValueError(
                    f"Cannot build {fused_name!r}: checkpoint is missing the "
                    f"{', '.join(missing)} shard(s)"
                )
            yield fused_name, torch.cat(shards, dim=0)

    def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
        """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
        format.

        - ``model.*`` becomes ``language_model.model.*``; names already under
          ``language_model.model.`` keep their prefix.
        - Routed-expert input/weight scales are renamed to the fused
          ``w2_*`` / ``w13_*`` names.
        - Attention ``k_proj.k_scale`` / ``v_proj.v_scale`` move under
          ``.attn.``.
        - ``lm_head.weight`` moves under ``language_model.``.
        - Any other name is returned unchanged.
        """
        if name.startswith("model."):
            renamed = name.replace("model.", "language_model.model.", 1)
        elif name.startswith("language_model.model."):
            renamed = name
        elif name.startswith("lm_head.weight"):
            return name.replace("lm_head.weight", "language_model.lm_head.weight")
        else:
            return name

        # Pick the replacement table that applies. The first key found in the
        # renamed string wins.
        if "feed_forward.experts." in name and (
            "_input_scale" in name or "_weight_scale" in name
        ):
            replacements = (
                ("down_proj_input_scale", "w2_input_scale"),
                ("down_proj_weight_scale", "w2_weight_scale"),
                ("gate_up_proj_input_scale", "w13_input_scale"),
                ("gate_up_proj_weight_scale", "w13_weight_scale"),
            )
        elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name):
            replacements = (
                (".k_proj.k_scale", ".attn.k_scale"),
                (".v_proj.v_scale", ".attn.v_scale"),
            )
        else:
            # Plain model.* -> language_model.model.* rename.
            return renamed

        for ckpt_key, vllm_key in replacements:
            if ckpt_key in renamed:
                return renamed.replace(ckpt_key, vllm_key)
        return renamed

    def _separate_and_rename_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
        """Rename checkpoint weights and bucket them by owning submodule.

        Each name goes through ``_rename_weight_for_modelopt_checkpoint``.
        Weights whose top-level attribute on ``self`` is an ``LMMissingLayer``
        or ``TowerMissingLayer`` are dropped (presumably submodules not built
        in this process -- TODO confirm).

        Returns:
            ``(language_model_weights, other_weights)``, where the first list
            holds names starting with ``language_model.``.
        """
        language_model_weights: list[tuple[str, torch.Tensor]] = []
        other_weights: list[tuple[str, torch.Tensor]] = []
        placeholder_types = (LMMissingLayer, TowerMissingLayer)

        for name, weight in weights:
            new_name = self._rename_weight_for_modelopt_checkpoint(name)
            owner_attr = new_name.partition(".")[0]
            if isinstance(getattr(self, owner_attr), placeholder_types):
                continue

            if new_name.startswith("language_model."):
                bucket = language_model_weights
            else:
                bucket = other_weights
            bucket.append((new_name, weight))

        return language_model_weights, other_weights

    def _handle_expert_scale_broadcasting(
        self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
    ) -> tuple[
        list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]], set[str]
    ]:
        """Handle expert scale parameters that need broadcasting.

        ModelOpt checkpoints use a single value tensor scalar for BMM style
        experts, while vLLM expects the scale broadcast across all experts.
        If a routed-expert scale is a 1-element tensor and its target
        parameter has more than one element, the scalar is written into every
        element of the parameter in place.

        Args:
            weights: ``(name, tensor)`` pairs, already renamed to vLLM names.
            params_dict: Parameter name -> parameter, e.g. from
                ``named_parameters()``.

        Returns:
            A 3-tuple ``(regular_weights, expert_scale_weights,
            updated_params)``:
            - ``regular_weights``: everything that is not a routed-expert
              scale.
            - ``expert_scale_weights``: routed-expert scales that still need
              the normal loader.
            - ``updated_params``: names of parameters filled in place here.
            (The previous annotation declared only two elements.)
        """
        regular_weights: list[tuple[str, torch.Tensor]] = []
        expert_scale_weights: list[tuple[str, torch.Tensor]] = []
        updated_params: set[str] = set()

        for name, weight in weights:
            # Only routed-expert scales are candidates. Shared-expert scales
            # and every other tensor take the regular path.
            is_expert_scale = (
                "feed_forward.experts." in name
                and "scale" in name
                and ".shared_expert" not in name
            )
            if not is_expert_scale:
                regular_weights.append((name, weight))
                continue

            param = params_dict.get(name)
            if (
                param is not None
                and hasattr(param, "data")
                and param.data.numel() > 1
                and weight.numel() == 1
            ):
                # Broadcast the checkpoint's single value to all experts.
                param.data.fill_(weight.item())
                updated_params.add(name)
                continue

            expert_scale_weights.append((name, weight))

        return regular_weights, expert_scale_weights, updated_params

    def _load_other_weights(
        self,
        other_weights: Iterable[tuple[str, torch.Tensor]],
        params_dict: dict,
        stacked_params_mapping: list,
    ) -> set[str]:
        """Load non-language-model weights with stacking support.

        In data-parallel encoder mode, q/k/v tensors are pre-fused by
        ``_consolidate_qkv_weights`` and the stacked mapping is not used.
        Otherwise, the first ``(param_name, shard_name, shard_id)`` entry whose
        ``shard_name`` occurs in a weight's name routes that weight to the
        stacked parameter's loader. Unmatched weights use the parameter's own
        ``weight_loader`` or ``default_weight_loader``.

        Returns:
            Names of the parameters that were loaded.
        """
        updated_params = set()

        if self.use_data_parallel:
            other_weights = self._consolidate_qkv_weights(other_weights)
            stacking = []
        else:
            stacking = stacked_params_mapping

        for name, loaded_weight in other_weights:
            for param_name, shard_name, shard_id in stacking:
                if shard_name in name:
                    target_name = name.replace(shard_name, param_name)
                    param = params_dict[target_name]
                    updated_params.add(target_name)
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                # No stacked mapping matched: load the parameter directly.
                param = params_dict[name]
                loader_fn = getattr(param, "weight_loader", default_weight_loader)
                loader_fn(param, loaded_weight)
                updated_params.add(name)

        return updated_params

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the FusedMoE checkpoint-to-parameter mapping for experts.

        Entries are ``(param_name, weight_name, expert_id, shard_id)`` and
        cover weights, fp8 weight scales and fp8 activation scales.
        """
        text_config = self.config.text_config
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=text_config.num_local_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a full checkpoint and return the names of loaded parameters.

        Steps:
        1. Rename the weights and split them into language-model and other
           weights.
        2. Broadcast scalar routed-expert scales into their parameters.
        3. Load the remaining language-model weights with
           ``AutoWeightsLoader`` (expert scales in a second pass).
        4. Load vision/projector weights with stacked-projection support.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            # Shared expert gate_up_proj stacking
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            # Feed forward gate_up_proj stacking (for non-MoE layers if any)
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        lm_weights, other_weights = self._separate_and_rename_weights(weights)
        regular, expert_scales, broadcast_params = (
            self._handle_expert_scale_broadcasting(lm_weights, params_dict)
        )
        updated_params: set[str] = set(broadcast_params)

        loader = AutoWeightsLoader(self)
        loaded_lm_params = loader.load_weights(regular)
        assert loaded_lm_params is not None
        updated_params.update(loaded_lm_params)

        if expert_scales:
            loaded_scale_params = loader.load_weights(expert_scales)
            if loaded_scale_params:
                updated_params.update(loaded_scale_params)

        updated_params.update(
            self._load_other_weights(other_weights, params_dict, stacked_params_mapping)
        )
        return updated_params

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the text model, connector and tower."""
        return MultiModelKeys.from_string_field(
            tower_model="vision_model.",
            connector="multi_modal_projector.",
            language_model="language_model",
        )