# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from itertools import tee
from typing import Annotated, Literal

import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
    find_supported_resolutions,
    get_best_fit,
)

from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsEagle3,
    SupportsMultiModal,
    SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix
from .vision import run_dp_sharded_vision_model


class Llama4ImagePatchInputs(TensorSchema):
    """
    Pixel inputs for Llama 4, where every image is split into square chunks.

    Dimensions:
        - batch_size: Number of images in the batch
        - total_num_chunks: Chunks of all images in the batch, flattened
          together (the sum of `patches_per_image`)
        - num_channels: Number of channels
        - image_size: Height and width (in pixels) of each square chunk
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Chunk pixels of every image, concatenated along the first dimension.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
    ]

    # Total number of chunks belonging to each image in the batch. Used to
    # split embeddings whose leading (image, chunk) dimensions have been
    # flattened the same way as `pixel_values`.
    patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]

    # For each image, the number of tiles along each dimension, stored as
    # the pair (ratio_h, ratio_w).
    aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]


class Llama4VisionMLP(nn.Module):
    """Two-layer GELU MLP used inside the Llama 4 vision tower.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``.
    Both receive ``disable_tp=use_data_parallel``. When ``output_activation``
    is true, GELU is also applied to the output of ``fc2``.
    """

    def __init__(
        self,
        input_size: int,
        intermediate_size: int,
        output_size: int,
        bias: bool,
        output_activation: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # Keyword arguments common to both projections.
        shared_kwargs = dict(
            bias=bias,
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        self.fc1 = ColumnParallelLinear(
            input_size=input_size,
            output_size=intermediate_size,
            prefix=f"{prefix}.fc1",
            **shared_kwargs,
        )
        self.fc2 = RowParallelLinear(
            input_size=intermediate_size,
            output_size=output_size,
            prefix=f"{prefix}.fc2",
            **shared_kwargs,
        )
        self.activation_fn = nn.GELU()
        self.output_activation = output_activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linears return (output, bias); only the output is used.
        expanded, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(expanded))
        if not self.output_activation:
            return projected
        return self.activation_fn(projected)


class Llama4MultiModalProjector(nn.Module):
    """Maps vision-tower features to the text model's hidden size.

    This is a single bias-free ``ColumnParallelLinear`` with
    ``gather_output=True``.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        vision_dim = config.vision_config.vision_output_dim
        text_dim = config.text_config.hidden_size
        self.linear_1 = ColumnParallelLinear(
            input_size=vision_dim,
            output_size=text_dim,
            bias=False,
            quant_config=quant_config,
            gather_output=True,
            prefix=f"{prefix}.linear_1",
        )

    def forward(self, image_features):
        # The linear returns (output, bias); discard the bias slot.
        projected, _unused_bias = self.linear_1(image_features)
        return projected


def pixel_shuffle(input_tensor, shuffle_ratio):
    """Trade spatial resolution for channel depth on a square patch grid.

    Args:
        input_tensor: Tensor of shape ``(batch_size, num_patches, channels)``
            whose ``num_patches`` tokens form a ``side x side`` grid.
        shuffle_ratio: Spatial scale factor. For example, ``0.5`` merges each
            2x2 neighbourhood of patches into one token with 4x the channels.

    Returns:
        Tensor of shape ``(batch_size, num_patches * shuffle_ratio**2,
        channels / shuffle_ratio**2)``.

    Raises:
        ValueError: If ``num_patches`` is not a perfect square.
    """
    batch_size, num_patches, _ = input_tensor.shape
    # math.isqrt is exact; int(math.sqrt(n)) can round wrongly for large n.
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(
            f"pixel_shuffle expects a square patch grid, got {num_patches} patches"
        )

    input_tensor = input_tensor.view(batch_size, side, side, -1)
    batch_size, height, width, channels = input_tensor.size()

    # Fold neighbouring columns into the channel dimension...
    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    # ...then, with rows and columns transposed, fold neighbouring rows too.
    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    # Restore (row, column) order before flattening the grid back to tokens.
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()

    return reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])


class Llama4VisionPixelShuffleMLP(nn.Module):
    """Vision adapter: pixel-shuffle the patch grid, then apply a GELU MLP.

    The MLP maps ``config.intermediate_size`` to ``config.projector_input_dim``
    and then to ``config.projector_output_dim``, with GELU also applied to
    its output.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP(
            input_size=config.intermediate_size,
            intermediate_size=config.projector_input_dim,
            output_size=config.projector_output_dim,
            bias=config.multi_modal_projector_bias,
            output_activation=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)


class Llama4VisionAttention(nn.Module):
    """Multi-head self-attention over image patches with rotary embeddings.

    Q, K and V come from one fused projection. With ``use_data_parallel``,
    every rank keeps all heads (``ReplicatedLinear``). Otherwise the heads
    are split across the tensor-parallel world
    (``QKVParallelLinear`` / ``RowParallelLinear``).
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        if use_data_parallel:
            self.tp_size = 1
        else:
            self.tp_size = get_tensor_model_parallel_world_size()

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        assert self.num_heads % self.tp_size == 0
        self.num_local_heads = self.num_heads // self.tp_size
        # Q, K and V each occupy num_local_heads * head_dim columns locally.
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim**-0.5

        self.attn = MultiHeadAttention(
            self.num_local_heads, self.head_dim, self.scaling
        )

        qkv_prefix = f"{prefix}.qkv_proj"
        o_prefix = f"{prefix}.o_proj"
        full_width = self.num_heads * self.head_dim
        if not use_data_parallel:
            self.qkv_proj = QKVParallelLinear(
                self.embed_dim,
                self.head_dim,
                self.num_heads,
                bias=True,
                quant_config=quant_config,
                prefix=qkv_prefix,
            )
            self.o_proj = RowParallelLinear(
                full_width,
                self.embed_dim,
                bias=True,
                input_is_parallel=True,
                quant_config=quant_config,
                prefix=o_prefix,
            )
        else:
            self.qkv_proj = ReplicatedLinear(
                self.embed_dim,
                self.q_size + 2 * self.kv_size,
                bias=True,
                quant_config=quant_config,
                prefix=qkv_prefix,
            )
            self.o_proj = ReplicatedLinear(
                full_width,
                self.embed_dim,
                bias=True,
                quant_config=quant_config,
                prefix=o_prefix,
            )

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            rotary_dim=config.hidden_size // config.num_attention_heads // 2,
            # max_position = number of image patches per chunk
            max_position=(config.image_size // config.patch_size) ** 2,
            base=config.rope_theta,
            rope_scaling={"rope_type": "mllama4"},
            is_neox_style=False,
            dtype=torch.complex64,  # important
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        leading_dims = hidden_states.shape[:-1]

        fused, _ = self.qkv_proj(hidden_states)
        query, key, value = fused.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        # Rotary embedding is applied per head, so expose a head axis:
        # (dim0, dim1, num_local_heads, head_dim). Assumes 3-D input.
        query = query.view(
            query.shape[0], query.shape[1], self.num_local_heads, self.head_dim
        )
        key = key.view(key.shape[0], key.shape[1], self.num_local_heads, self.head_dim)
        query, key = self.rotary_emb(query, key)

        # Merge the heads back before attention.
        query = query.view(query.shape[0], query.shape[1], -1)
        key = key.view(key.shape[0], key.shape[1], -1)

        context = self.attn(query, key, value)
        context = context.reshape(*leading_dims, -1).contiguous()
        projected, _ = self.o_proj(context)
        return projected


class Llama4VisionEncoderLayer(nn.Module):
    """Pre-norm transformer block for the vision tower.

    The block computes ``x + attn(LN(x))`` followed by ``x + mlp(LN(x))``.
    ``forward`` returns the result wrapped in a 1-tuple.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.intermediate_size = config.intermediate_size

        self.self_attn = Llama4VisionAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
        )
        self.mlp = Llama4VisionMLP(
            input_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=True,
            output_activation=False,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_state: torch.Tensor,
    ):
        # Attention sub-block with residual connection.
        attn_out = self.self_attn(self.input_layernorm(hidden_state))
        hidden_state = hidden_state + attn_out

        # Feed-forward sub-block with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_state))
        hidden_state = hidden_state + mlp_out

        return (hidden_state,)


class Llama4VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` vision encoder layers."""

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        r"""Run the hidden states through every encoder layer in order.

        Args:
            hidden_states: Input tensor of shape
                (batch_size, sequence_length, hidden_size), for example the
                patch embeddings produced by the vision model.

        Returns:
            The hidden states produced by the last encoder layer.
        """
        for encoder_layer in self.layers:
            # Each layer returns a tuple whose first element is the new state.
            layer_outputs = encoder_layer(hidden_states)
            hidden_states = layer_outputs[0]

        return hidden_states


class Llama4UnfoldConvolution(nn.Module):
    """Patch embedding implemented as ``nn.Unfold`` plus a linear projection.

    Windows of ``config.patch_size`` are extracted with stride
    ``config.patch_size``. Each flattened window is then projected to
    ``config.hidden_size`` by a bias-free linear layer.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        patch = config.patch_size
        kernel_size = (patch, patch) if isinstance(patch, int) else patch
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        flat_window = config.num_channels * kernel_size[0] * kernel_size[1]
        self.linear = ColumnParallelLinear(
            input_size=flat_window,
            output_size=config.hidden_size,
            bias=False,
            gather_output=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Unfold yields (N, C*kh*kw, L); move L forward so the projection
        # is applied to each window independently.
        windows = self.unfold(hidden_states).permute(0, 2, 1)
        embedded, _ = self.linear(windows)
        return embedded


class Llama4VisionModel(nn.Module):
    """Llama 4 vision tower.

    Pipeline:
        1. Unfold patch embedding.
        2. Append the class token.
        3. Add the learned positional embedding.
        4. Pre-LayerNorm.
        5. Transformer encoder.
        6. Post-LayerNorm.
        7. Drop the class token.
        8. Pixel-shuffle adapter.
    """

    def __init__(
        self,
        config: Llama4VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels

        # Patches per chunk plus one extra position for the class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = Llama4UnfoldConvolution(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.patch_embedding",
            use_data_parallel=use_data_parallel,
        )

        # NOTE: keep this creation order; it fixes the order of the
        # torch.randn draws used for the initial parameter values.
        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.positional_embedding_vlm = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)

        # encoders
        self.model = Llama4VisionEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.model",
            use_data_parallel=use_data_parallel,
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            config,
            quant_config,
            prefix=f"{prefix}.vision_adapter",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        images_flattened: torch.Tensor,
    ) -> torch.Tensor:
        # images_flattened: presumably (num_tiles, C, H, W) -- TODO confirm.
        patch_embeds = self.patch_embedding(images_flattened)
        num_tiles, num_patches, hidden_dim = patch_embeds.shape

        # The class token is appended at the END of each tile's sequence.
        cls_token = self.class_embedding.expand(num_tiles, 1, hidden_dim)
        tokens = torch.cat([patch_embeds, cls_token], dim=1)
        seq_len = num_patches + 1

        pos_embed = self.positional_embedding_vlm.to(
            dtype=tokens.dtype, device=tokens.device
        )
        tokens = tokens.reshape(num_tiles, 1, seq_len, hidden_dim) + pos_embed
        tokens = self.layernorm_pre(tokens)
        tokens = tokens.view(num_tiles, -1, hidden_dim)

        encoded = self.layernorm_post(self.model(tokens))

        # Strip the trailing class token, then pixel-shuffle and project.
        return self.vision_adapter(encoded[:, :-1, :])


class Mllama4ProcessingInfo(BaseProcessingInfo):
    """Config and processor accessors used to process Llama 4 image inputs."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__(ctx)

    def get_hf_config(self) -> Llama4Config:
        return self.ctx.get_hf_config(Llama4Config)

    def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
        # Use the fast image processor unless the caller says otherwise.
        return self.ctx.get_hf_processor(
            Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Although vLLM can support more images from an infra capability
        # perspective, we do not recommend using >10 images in practice.
        return {"image": None}

    @staticmethod
    def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
        """Return the number of patch embeddings per chunk after pixel shuffle.

        Raises:
            ValueError: If ``image_size`` is not a multiple of ``patch_size``.
        """
        image_size = vision_config.image_size
        patch_size = vision_config.patch_size

        if image_size % patch_size != 0:
            raise ValueError(
                f"chunk size {image_size} should be multiple of "
                f"patch_size {patch_size}"
            )

        # Pixel shuffle with ratio r shrinks the token count by 1 / r**2.
        ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
        return (image_size // patch_size) ** 2 // ds_ratio

    def get_max_num_tiles(self) -> int:
        image_processor = self.get_hf_processor().image_processor
        return image_processor.max_patches

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_config = self.get_hf_config().vision_config
        image_size = vision_config.image_size
        # Tallest canvas: max_num_tiles chunks stacked vertically, one chunk
        # wide, which yields the maximum possible feature size.
        return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size)


class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]):
    """Multimodal processor that expands each image token into tile tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach per-image tiling metadata.

        When pixel values are present, two entries are added:
            - ``aspect_ratios``: the (tiles along height, tiles along width)
              pair for each image.
            - ``patches_per_image``: the chunk count for each image. It is
              ``1`` for a single-tile image, otherwise ``1 + ratio_h * ratio_w``.
        """
        tokenizer = self.info.get_tokenizer()

        if mm_data is None:
            return tokenizer(prompt, add_special_tokens=False)  # exclude bos
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if processed_outputs.get("pixel_values") is None:
            return processed_outputs

        assert "images" in mm_data, (
            "images expected to be in mm_data when pixel_values is present"
        )

        # Only needed when there are images to tile.
        image_processor = self.info.get_hf_processor(**mm_kwargs).image_processor
        vision_config = self.info.get_hf_config().vision_config

        images = mm_data["images"]
        parsed_images = (
            self._get_data_parser()
            .parse_mm_data({"image": images})
            .get_items("image", ImageProcessorItems)
        )

        tile_size = vision_config.image_size
        possible_resolutions = find_supported_resolutions(
            max_num_chunks=self.info.get_max_num_tiles(),
            patch_size=SizeDict(height=tile_size, width=tile_size),
        )
        # Loop-invariant: build the candidate tensor once rather than per image.
        resolutions_tensor = torch.tensor(possible_resolutions)
        best_fit_sizes = [
            get_best_fit(
                (image.size[1], image.size[0]),
                resolutions_tensor,
                resize_to_max_canvas=image_processor.resize_to_max_canvas,
            )
            for image in parsed_images
        ]
        # TODO tile height/width do not necessarily need to match
        aspect_ratios = [
            (image_size[0] // tile_size, image_size[1] // tile_size)
            for image_size in best_fit_sizes
        ]
        # Multi-tile images get one extra chunk on top of their tiles.
        patches_per_image = [
            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
        ]

        processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
        processed_outputs["patches_per_image"] = torch.tensor(patches_per_image)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # pixel_values is flat over all chunks and is split back per image
        # using patches_per_image. The other fields have one row per image.
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            patches_per_image=MultiModalFieldConfig.batched("image"),
            aspect_ratios=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        vision_config = self.info.get_hf_config().vision_config
        num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token
        img_patch_token = hf_processor.img_patch_token

        def get_replacement(item_idx: int):
            # Expand one image token according to that image's tile layout.
            out_item = out_mm_kwargs["image"][item_idx]
            aspect_ratio = out_item["aspect_ratios"].data

            # NOTE(review): relies on a private HF processor helper.
            repl = hf_processor._prompt_split_image(
                aspect_ratio=aspect_ratio,
                num_patches_per_chunk=num_patches_per_chunk,
            )

            # Presumably marks the img_patch_token positions as the ones
            # that receive image embeddings -- confirm against select_text.
            return PromptUpdateDetails.select_text(repl, img_patch_token)

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement,
            )
        ]


class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
    """Builds dummy prompts and images for Llama 4 (a BaseDummyInputsBuilder
    from ``vllm.multimodal.profiling``)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One fake image placeholder per requested image.
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().fake_image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        image_count = mm_counts.get("image", 0)

        # Use the image size that yields the most features.
        target_width, target_height = self.info.get_image_size_with_most_features()

        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


@MULTIMODAL_REGISTRY.register_processor(
    Mllama4MultiModalProcessor,
    info=Mllama4ProcessingInfo,
    dummy_inputs=Mllama4DummyInputsBuilder,
)
725
726
727
class Llama4ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
):
728
729
    merge_by_field_config = True

730
731
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
732
        "gate_up_proj": ["gate_proj", "up_proj"],
733
734
    }

735
736
    supports_encoder_tp_data = True

737
    @classmethod
738
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
739
740
741
742
743
        if modality.startswith("image"):
            return "<|image|>"

        raise ValueError("Only image modality is supported")

744
745
746
747
748
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the optional vision tower/projector and the Llama 4 text model.

        Args:
            vllm_config: Global vLLM configuration; its HF config must expose
                ``vision_config`` and ``text_config``.
            prefix: Parameter-name prefix for the submodules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config
        mm_config = model_config.multimodal_config

        # mm_encoder_tp_mode == "data" selects the data-parallel vision path.
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"
        self.config = hf_config
        self.quant_config = quant_config
        self.multimodal_config = mm_config

        # Build the vision tower and projector only when prompts may carry images.
        if not mm_config.get_limit_per_prompt("image"):
            self.vision_model = None
            self.multi_modal_projector = None
        else:
            self.vision_model = Llama4VisionModel(
                hf_config.vision_config,
                None,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_data_parallel=self.use_data_parallel,
            )
            self.multi_modal_projector = Llama4MultiModalProjector(
                self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")
            )

        text_vllm_config = vllm_config.with_hf_config(
            hf_config.text_config, ["LlamaForCausalLM"]
        )
        self.language_model = initialize_model(
            vllm_config=text_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            model_class=Llama4ForCausalLM,
        )

        # Reuse the text model's factory for intermediate tensors.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set which layers should output auxiliary hidden states for EAGLE3."""
        # Delegate to underlying language model (Llama4ForCausalLM)
        assert hasattr(self.language_model, "set_aux_hidden_state_layers")
        self.language_model.set_aux_hidden_state_layers(layers)

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Get the layer indices for auxiliary hidden state outputs.

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        """
        # Delegate to underlying language model (Llama4ForCausalLM)
        assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
        return self.language_model.get_eagle3_aux_hidden_state_layers()

795
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Llama4ImagePatchInputs | None:
        """Extract image tensors from ``kwargs`` into ``Llama4ImagePatchInputs``.

        Returns ``None`` when ``pixel_values`` is absent or ``None``. When
        ``pixel_values`` is present, ``patches_per_image`` and
        ``aspect_ratios`` are required (a missing one raises ``KeyError``).
        """
        # NOTE(review): an earlier comment described pixel_values as
        # (num_images, 1, num_chunks, C, H, W); _process_image_input splits the
        # encoder output along dim 0 by patches_per_image, which suggests chunks
        # are already flattened into dim 0 here — confirm.
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        return Llama4ImagePatchInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            patches_per_image=kwargs.pop("patches_per_image"),
            aspect_ratios=kwargs.pop("aspect_ratios"),
        )

    def _process_image_input(
        self, image_input: Llama4ImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Encode all image chunks in one batch and regroup them per image.

        The chunks go through the vision model (via
        ``run_dp_sharded_vision_model`` when ``use_data_parallel`` is set) and
        then ``multi_modal_projector``. The result is split along dim 0 using
        ``patches_per_image``. Each per-image group has its first two
        dimensions merged before being returned.
        """
        assert self.vision_model and self.multi_modal_projector
        pixel_values = image_input["pixel_values"]
        chunks_per_image = image_input["patches_per_image"].tolist()

        if self.use_data_parallel:
            features = run_dp_sharded_vision_model(pixel_values, self.vision_model)
        else:
            features = self.vision_model(pixel_values)

        projected = self.multi_modal_projector(features)
        per_image_groups = projected.split(chunks_per_image, dim=0)
        return [group.flatten(0, 1) for group in per_image_groups]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

838
    def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when no image input is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input)
        return []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text model; extra keyword arguments are ignored.

        When ``intermediate_tensors`` is provided, ``inputs_embeds`` is
        discarded and ``None`` is passed to the text model instead.
        """
        embeds = inputs_embeds if intermediate_tensors is None else None
        return self.language_model(input_ids, positions, intermediate_tensors, embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
863
    ) -> torch.Tensor | None:
864
        return self.language_model.compute_logits(hidden_states)
865
866
867

    def separate_weights(
        self,
868
        weights: Iterable[tuple[str, torch.Tensor]],
869
        prefix: str,
870
    ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
871
872
        weights1, weights2 = tee(weights, 2)

873
        def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
874
875
876
877
            for name, data in weights1:
                if name.startswith(prefix):
                    yield (name, data)

878
        def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]:
879
880
881
882
883
884
            for name, data in weights2:
                if not name.startswith(prefix):
                    yield (name, data)

        return get_prefix_weights(), get_other_weights()

885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
    def _consolidate_qkv_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        qkv_idx_mappings = {
            ".self_attn.q_proj": 0,
            ".self_attn.k_proj": 1,
            ".self_attn.v_proj": 2,
        }
        qkv_weights = {}
        for name, loaded_weight in weights:
            for weight_name, idx in qkv_idx_mappings.items():
                if weight_name not in name:
                    continue
                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
                if new_name not in qkv_weights:
                    qkv_weights[new_name] = [None] * 3
                qkv_weights[new_name][idx] = loaded_weight
                break
            else:
                yield name, loaded_weight
        for key, weight in qkv_weights.items():
            qkv_weight = torch.cat(weight, dim=0)
            yield key, qkv_weight

909
910
911
    def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
        """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
        format."""
912
913
914
915
916
917
        if name.startswith("model.") or name.startswith("language_model.model."):
            renamed = (
                name.replace("model.", "language_model.model.", 1)
                if name.startswith("model.")
                else name
            )
918
            # Handle expert scale parameters with flat naming
919
920
921
            if "feed_forward.experts." in name and (
                "_input_scale" in name or "_weight_scale" in name
            ):
922
923
                # Map checkpoint naming to vLLM's expected naming
                if "down_proj_input_scale" in renamed:
924
                    return renamed.replace("down_proj_input_scale", "w2_input_scale")
925
                elif "down_proj_weight_scale" in renamed:
926
                    return renamed.replace("down_proj_weight_scale", "w2_weight_scale")
927
                elif "gate_up_proj_input_scale" in renamed:
928
929
930
                    return renamed.replace(
                        "gate_up_proj_input_scale", "w13_input_scale"
                    )
931
                elif "gate_up_proj_weight_scale" in renamed:
932
933
934
                    return renamed.replace(
                        "gate_up_proj_weight_scale", "w13_weight_scale"
                    )
935
936
937
                return renamed

            # Handle attention scale parameters
938
            elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name):
939
940
941
942
943
944
945
                if ".k_proj.k_scale" in renamed:
                    return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
                elif ".v_proj.v_scale" in renamed:
                    return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
                return renamed

            # Standard model.* to language_model.model.* renaming
946
            return renamed
947
948

        elif name.startswith("lm_head.weight"):
949
            return name.replace("lm_head.weight", "language_model.lm_head.weight")
950
951
952
953
954
955
956
957
958
959

        return name

    def _separate_and_rename_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
        """Rename weights and separate them into language_model and other
        weights."""
        language_model_weights = []
        other_weights = []
960

961
962
        for name, weight in weights:
            renamed = self._rename_weight_for_modelopt_checkpoint(name)
963

964
965
966
967
968
969
970
971
            if renamed.startswith("language_model."):
                language_model_weights.append((renamed, weight))
            else:
                other_weights.append((renamed, weight))

        return language_model_weights, other_weights

    def _handle_expert_scale_broadcasting(
972
        self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
973
974
975
976
977
978
979
980
981
982
983
984
    ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
        """Handle expert scale parameters that need broadcasting.

        ModelOpt checkpoints use a single value tensor scalar for BMM style
        experts, vLLM expects the scale to be broadcasted across all experts.
        """
        regular_weights = []
        expert_scale_weights = []
        updated_params = set()

        for name, weight in weights:
            # Check if this is an expert scale parameter that needs broadcasting
985
986
987
988
989
            if (
                "feed_forward.experts." in name
                and "scale" in name
                and ".shared_expert" not in name
            ):
990
991
                if name in params_dict:
                    param = params_dict[name]
992
993
994
995
996
                    if (
                        hasattr(param, "data")
                        and param.data.numel() > 1
                        and weight.numel() == 1
                    ):
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
                        # Broadcast single value to all experts
                        param.data.fill_(weight.item())
                        updated_params.add(name)
                        continue

                expert_scale_weights.append((name, weight))
            else:
                regular_weights.append((name, weight))

        return regular_weights, expert_scale_weights, updated_params

1008
1009
1010
1011
1012
1013
    def _load_other_weights(
        self,
        other_weights: Iterable[tuple[str, torch.Tensor]],
        params_dict: dict,
        stacked_params_mapping: list,
    ) -> set[str]:
        """Load weights outside ``language_model.*``, with stacked-shard support.

        Each entry of ``stacked_params_mapping`` is
        ``(param_name, shard_name, shard_id)``. A checkpoint name containing
        ``shard_name`` is rewritten to ``param_name`` and loaded via the
        parameter's ``weight_loader`` with ``shard_id``. All other names load
        through ``weight_loader`` (or ``default_weight_loader``) without a
        shard id.

        Returns:
            Names of the parameters that were loaded.
        """
        loaded_names: set[str] = set()

        if self.use_data_parallel:
            # In data-parallel mode q/k/v are concatenated up front and loaded
            # as a single tensor, so the shard-wise mapping is not consulted.
            other_weights = self._consolidate_qkv_weights(other_weights)
            shard_mapping = []
        else:
            shard_mapping = stacked_params_mapping

        for ckpt_name, tensor in other_weights:
            for param_name, shard_name, shard_id in shard_mapping:
                if shard_name not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(shard_name, param_name)
                param = params_dict[ckpt_name]
                loaded_names.add(ckpt_name)
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # No stacked mapping matched: load the tensor as-is.
                param = params_dict[ckpt_name]
                loader_fn = getattr(param, "weight_loader", default_weight_loader)
                loader_fn(param, tensor)
                loaded_names.add(ckpt_name)

        return loaded_names

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a Llama 4 checkpoint, including ModelOpt FP8 naming.

        Steps:
          1. Rename all weights and split them into ``language_model.*`` and
             the rest.
          2. Drop non-language weights when neither the vision model nor the
             projector was built.
          3. Broadcast scalar MoE expert scales in place. Load the remaining
             language weights, then any leftover expert scales, with
             ``AutoWeightsLoader``.
          4. Load the non-language weights with q/k/v and gate/up stacking.

        Returns:
            Names of all parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            # Shared expert gate_up_proj stacking
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            # Feed forward gate_up_proj stacking (for non-MoE layers if any)
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        lm_weights, non_lm_weights = self._separate_and_rename_weights(weights)

        # Vision model and projector were not built, so their weights are skipped.
        if self.vision_model is None and self.multi_modal_projector is None:
            non_lm_weights = []

        regular_lm_weights, leftover_scales, broadcast_names = (
            self._handle_expert_scale_broadcasting(lm_weights, params_dict)
        )
        loaded: set[str] = set()
        loaded.update(broadcast_names)

        auto_loader = AutoWeightsLoader(self)
        lm_loaded = auto_loader.load_weights(regular_lm_weights)
        assert lm_loaded is not None
        loaded.update(lm_loaded)

        if leftover_scales:
            scale_loaded = auto_loader.load_weights(leftover_scales)
            if scale_loaded:
                loaded.update(scale_loaded)

        loaded.update(
            self._load_other_weights(non_lm_weights, params_dict, stacked_params_mapping)
        )
        return loaded