llama4.py 30.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
20
21
from collections.abc import Iterable
from typing import Any, Optional
22
23
24
25
26
27
28
29

import torch
from torch import nn
from transformers import Llama4TextConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
30
from vllm.distributed import get_tensor_model_parallel_world_size
31
32
33
34
35
36
37
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
38
39
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
40
41

from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
42
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
43
44
45
46
47
48
49
50
51
52
53
                    is_pp_missing_parameter)


class Llama4MoE(nn.Module):
    """Llama4 mixture-of-experts feed-forward block.

    The output is the sum of a routed ``FusedMoE`` (top-k experts chosen by a
    replicated router) and an always-active shared expert MLP. The
    tensor-parallel reduction of the two is done once, on their sum.
    """

    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts per token and compute their weights.

        The top-k router logits are selected first and only then passed
        through a sigmoid, element-wise. The selected experts' weights are
        therefore not normalized to sum to one.

        Args:
            hidden_states: Unused here; required by the routing interface.
            gating_output: Router logits; top-k is taken over the last dim.
            topk: Number of experts to select per token.
            renormalize: Unused here; required by the routing interface.

        Returns:
            ``(router_scores, router_indices)``: float32 sigmoid scores of
            the selected experts and their int32 expert indices.
        """
        router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
        # pseudo-standard is that the router scores are floats
        router_scores = torch.sigmoid(router_scores.float())
        return (router_scores, router_indices.to(torch.int32))

    def __init__(self,
                 config: Llama4TextConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the router, the routed experts and the shared expert.

        Args:
            config: Llama4 text config (expert counts, sizes).
            quant_config: Quantization config for experts and shared expert.
                The router is always created unquantized.
            prefix: Module name prefix used for weight names.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok

        intermediate_size_moe = config.intermediate_size
        # The router is replicated on every rank and never quantized.
        self.router = ReplicatedLinear(config.hidden_size,
                                       config.num_local_experts,
                                       bias=False,
                                       quant_config=None,
                                       prefix=f"{prefix}.router")

        # reduce_results=False: the cross-rank reduction is deferred to
        # forward(), where it is applied once to routed + shared output.
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            custom_routing_function=Llama4MoE.custom_routing_function,
            intermediate_size=intermediate_size_moe,
            apply_router_weight_on_input=True,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            prefix=f"{prefix}.experts")

        # Whether the shared expert reduces its own output is decided by the
        # FusedMoE instance (must_reduce_shared_expert_outputs()).
        self.shared_expert = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size_moe,
            hidden_act="silu",
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.shared_expert",
            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
        )

    def forward(self, hidden_states):
        """Return routed-expert output plus shared-expert output.

        Args:
            hidden_states: Token hidden states fed to router and experts.

        Returns:
            The summed expert output. When TP size > 1 it is passed through
            ``maybe_all_reduce_tensor_model_parallel``.
        """
        router_logits, _ = self.router(hidden_states)
        shared_out = self.shared_expert(hidden_states)
        routed_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        experts_out = routed_out + shared_out

        if self.tp_size > 1:
            # One reduction for the sum instead of one per branch.
            experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                experts_out)

        return experts_out


class Llama4Attention(nn.Module):
    """Self-attention for a Llama4 text decoder layer.

    ``config.no_rope_layers[layer_idx] == 0`` marks a NoPE layer. Such a
    layer applies no rotary embedding and no QK norm, but may apply
    attention temperature tuning. All other layers apply rotary embedding
    and, when ``config.use_qk_norm`` is set, a weightless float32 RMSNorm on
    Q and K.
    """

    def __init__(self,
                 config: Llama4TextConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        """Build projections, optional QK norm, optional rotary and attn.

        Args:
            config: Llama4 text config (head_dim, norm eps, NoPE layout,
                temperature-tuning settings).
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (split across TP ranks).
            num_kv_heads: Total number of KV heads (split or replicated
                across TP ranks).
            rope_theta: Rotary base; passed to ``get_rope`` as an int.
            rope_scaling: Rotary scaling config passed to ``get_rope``.
            max_position_embeddings: Max position for the rotary cache.
            quant_config: Quantization config for linear layers and attn.
            bias: Whether the fused QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: KV cache config for the attention layer.
            prefix: Module name prefix; the layer index is parsed from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.no_rope_layers = config.no_rope_layers
        # A 0 entry in no_rope_layers marks this as a NoPE layer.
        self.nope = self.no_rope_layers[self.layer_idx] == 0
        self.use_qk_norm = config.use_qk_norm and not self.nope
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Temperature tuning is only ever enabled on NoPE layers.
        self.attn_temperature_tuning = (self.nope
                                        and config.attn_temperature_tuning)

        self.floor_scale = getattr(config, "floor_scale", 8192.0)
        self.attn_scale = getattr(config, "attn_scale", 0.1)
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.n_rep = self.num_heads // self.num_kv_heads
        # Weightless RMSNorm over each head, computed in float32; only
        # created for rotary layers with use_qk_norm enabled.
        self.qk_norm = RMSNorm(
            hidden_size=self.head_dim,
            eps=config.rms_norm_eps,
            has_weight=False,
            dtype=torch.float32,
        ) if self.use_qk_norm else None
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        # NoPE layers have no rotary embedding at all.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=int(rope_theta),
            rope_scaling=rope_scaling if rope_scaling != "default" else None,
            is_neox_style=is_neox_style,
        ) if not self.nope else None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=None,
            use_irope=not self.nope,
            prefix=f"{prefix}.attn",
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-position query temperature scale.

        Computes ``log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1``.
        The result is exactly 1 while ``pos + 1 < floor_scale``, so short
        contexts are unaffected. A trailing dim is added for broadcasting
        against the query.
        """
        floor = torch.floor((positions + 1.0) / self.floor_scale)
        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0

        return attn_scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary/QK-norm or temperature, attn, o_proj.

        Args:
            positions: Token positions, used for rotary embedding and for
                temperature tuning.
            hidden_states: Input hidden states.

        Returns:
            Output of the ``o_proj`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        if self.qk_norm is not None:
            # Normalize per head in float32, then restore the input dtype.
            q = q.reshape(-1, self.num_heads, self.head_dim)
            q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
            k = k.reshape(-1, self.num_kv_heads, self.head_dim)
            k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)

        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
        # to NoPE layers, where the inference-time temperature tuning function
        # is customized to not affect short context
        # while working at very long context
        # https://arxiv.org/abs/2501.19399
        #
        # We should apply temperature tuning between (after) rotary / QK norm
        # and (before) attention.
        if self.attn_temperature_tuning and self.nope:
            attn_scale = self._get_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Llama4DecoderLayer(nn.Module):
    """One Llama4 decoder layer: pre-norm attention then pre-norm FFN.

    The FFN is a ``Llama4MoE`` on every ``interleave_moe_layer_step``-th
    layer (1-based) and a dense ``LlamaMLP`` otherwise.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, feed-forward and the two RMSNorms.

        Args:
            config: Llama4 text config.
            cache_config: KV cache config forwarded to attention.
            quant_config: Quantization config for attention and FFN.
            prefix: Module name prefix; the layer index is parsed from it.
        """
        super().__init__()

        self.layer_idx = extract_layer_index(prefix)
        # True for NoPE layers (no_rope_layers entry of 0).
        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
        self.hidden_size = config.hidden_size
        rope_theta = config.rope_theta
        rope_scaling = config.rope_scaling
        max_position_embeddings = config.max_position_embeddings

        self.self_attn = Llama4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        # A step of 0 (or less) disables MoE layers entirely.
        is_moe_layer = config.interleave_moe_layer_step > 0 and (
            self.layer_idx + 1) % config.interleave_moe_layer_step == 0
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                bias=False,
                prefix=f"{prefix}.feed_forward",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        The residual addition is folded into the RMSNorm calls, which take
        ``(hidden_states, residual)`` and return the normalized states
        together with the updated residual. ``residual`` is None only for
        the first layer; there the input itself becomes the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Llama4Model(LlamaModel):
    """Llama4 text decoder stack.

    Reuses ``LlamaModel`` for construction and the forward pass. Weight
    loading is overridden to handle Llama4 MoE expert weights, which a
    checkpoint may store either fused across all experts or as separate
    per-expert tensors.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
        # Number of routed experts; used to build the per-expert (unfused)
        # parameter mapping in load_weights().
        self.num_experts = vllm_config.model_config.hf_config.num_local_experts
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=layer_type)

    def load_moe_expert_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
        expert_params_mapping: list[tuple[str, str, int, str]],
        fused: bool = True,
    ) -> bool:
        """
        Load MoE expert weights.

        Args:
            name: The name of the weight to load.
            loaded_weight: The weight to load.
            params_dict: The dictionary of module parameters.
            loaded_params: The set of already loaded parameters.
            expert_params_mapping: The mapping of expert parameters. Must be
                generated by FusedMoE.make_expert_params_mapping().
            fused: Whether the expert weights are fused into a single weight
                tensor or are separate weight tensors for each expert.
                When fused is True, loaded_weight should have shape of:
                [num_experts, hidden_in, hidden_out] for gate/up/down proj and
                [hidden_out, hidden_in] for the others like router.
                When fused is False, loaded_weight should have shape of:
                [hidden_out, hidden_in].

        Returns:
            True if loaded_weight is one of MoE weights and the MoE expert
            weights are loaded successfully, False otherwise.
        """

        # Whether the MoE expert weights are loaded successfully.
        expert_param_loaded = False

        # If fused is True, the loaded weight is in the layout of:
        # [num_experts, hidden_in, hidden_out], so we must transpose the last
        # two dimensions to match the expected layout of the parameters.
        if fused and loaded_weight.ndim == 3:
            loaded_weight = loaded_weight.transpose(-1, -2)

            # If the gate_proj and up_proj weights are fused into a single
            # weight tensor, we need to split the weight tensor into a tuple
            # of two weight tensors along the hidden_out dimension.
            if "experts.gate_up_proj" in name:
                loaded_weight = loaded_weight.chunk(2, dim=-2)

        # Iterate over all the expert parameters and load the weights if we find
        # a match in weight name.
        for (param_name, weight_name, expert_id,
             shard_id) in expert_params_mapping:

            # Get a view of the loaded_weight to avoid modifying the original
            # one across iterations.
            new_loaded_weight = loaded_weight

            # If expert weights are fused into a single weight tensor, remove
            # the expert index from the expected weight name.
            if fused:
                # The string between e_str and proj_str is the expert index.
                e_str, _, proj_str, _ = weight_name.split('.')
                weight_name = f"{e_str}.{proj_str}"
                param_name = f"{param_name}weight"

            # Skip if the current weight is not one of the MoE weights.
            if weight_name not in name:
                continue

            # Replace the weight name with the parameter name.
            full_param_name = name.replace(weight_name, param_name)

            # Skip if the current weight corresponds to a parameter that
            # does not exist on the current PP (pipeline parallel) rank.
            if is_pp_missing_parameter(name, self):
                continue

            # Skip if the current weight is for the bias.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue

            param = params_dict[full_param_name]
            weight_loader = param.weight_loader

            if fused:
                # If the parameter is for w13 together, the corresponding weight
                # will be a tuple, so we must select the correct weight
                # depending on the shard id, which is either "w1" or "w3".
                if "w13" in full_param_name:
                    assert shard_id in ["w1", "w3"]
                    shard_idx = 0 if shard_id == "w1" else 1
                    new_loaded_weight = new_loaded_weight[shard_idx]

                # If EP (expert parallel) is enabled, update expert_id to the
                # starting expert index for the current EP rank and extract the
                # corresponding expert weights.
                layer_idx = extract_layer_index(name)
                expert_map = self.layers[
                    layer_idx].feed_forward.experts.expert_map
                if expert_map is not None:
                    local_expert_indices = (expert_map != -1) \
                                            .nonzero() \
                                            .flatten() \
                                            .to(new_loaded_weight.device)
                    new_loaded_weight = new_loaded_weight[local_expert_indices]
                    expert_id = local_expert_indices[0].item()
            else:
                # TODO: add EP support for non fused weights
                pass

            # Load the weight into the module parameter with corresponding
            # shard id and expert id.
            weight_loader(param,
                          new_loaded_weight,
                          full_param_name,
                          shard_id=shard_id,
                          expert_id=expert_id)

            loaded_params.add(full_param_name)
            expert_param_loaded = True

        return expert_param_loaded

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model's parameters.

        The loader handles four kinds of weights:

        * KV cache quantization scales, via ``quant_config.get_cache_scale``.
        * Stacked projections: q/k/v go into ``qkv_proj`` and gate/up go into
          ``gate_up_proj``, except for MoE expert weights.
        * MoE expert weights, fused or per-expert, via
          ``load_moe_expert_weights``.
        * Everything else: flat expert scales and plain parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names that were loaded.
        """
        # Name mapping from the parameter name to the shard name and
        # corresponding shard id.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        # Indicate whether the expert weights are fused into a single weight
        # tensor.
        fused_experts_params = False

        # Expert parameter mapping for the case where the expert weights are
        # not fused into a single weight tensor.
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.num_experts)

        # Expert parameter mapping for the case where the expert weights are
        # fused into a single weight tensor.
        expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="gate_up_proj",
            num_experts=1)

        # All the module parameters.
        params_dict = dict(self.named_parameters())
        # The module parameters that have been loaded.
        loaded_params: set[str] = set()

        # Iterate over all the weights and load them into module parameters.
        for name, loaded_weight in weights:

            # If the name contains "experts.gate_up_proj" or "experts.down_proj"
            # without the expert indices, it means the expert weights are fused
            # into a single weight tensor across all experts.
            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                fused_experts_params = True
                expert_params_mapping = expert_params_mapping_fused

            # If kv cache quantization scales exist and the weight name
            # corresponds to one of the kv cache quantization scales, load
            # them.
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Iterate over stacked_params_mapping to check if the current weight
            # is one of the stacked parameters. If so, load the weight with the
            # corresponding shard id. Note that MoE weights are handled
            # separately in the else block.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip if the current weight is not one of the stacked
                # parameters or if the current weight is a MoE weight.
                if weight_name not in name or "experts" in name:
                    continue

                # For ModelOpt checkpoints, we need to rename the self_attn
                # weight/weight_scale names except for kv cache scales.
                if not (name.endswith(
                    (".k_scale", ".v_scale")) and "self_attn" in name):
                    name = name.replace(weight_name, param_name)

                # Skip if the current weight corresponds to a parameter that
                # does not exist on the current PP (pipeline parallel) rank.
                if is_pp_missing_parameter(name, self):
                    continue

                # Remap kv cache scale names for ModelOpt checkpoints.
                # TODO: ModelOpt should implement get_cache_scale() such that
                #       kv cache scale name remapping can be done there.
                if name.endswith("scale"):
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        # The scale could not be mapped to a parameter of
                        # this model, so skip this weight entirely. A
                        # `continue` here would test the next mapping with
                        # `weight_name not in None` and raise TypeError.
                        break

                # Load the weight into the module parameter with corresponding
                # shard id and exit the for loop and the else block.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)

                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)

                loaded_params.add(name)
                break

            # Handle normal (non-stacked) weights and MoE weights.
            else:
                # First, try to load MoE weights using load_moe_expert_weights.
                # If successful, move on to next loaded weight.
                if self.load_moe_expert_weights(name,
                                                loaded_weight,
                                                params_dict,
                                                loaded_params,
                                                expert_params_mapping,
                                                fused=fused_experts_params):
                    continue

                # Skip if the current weight corresponds to a parameter that
                # does not exist on the current PP (pipeline parallel) rank.
                if is_pp_missing_parameter(name, self):
                    continue

                # Handle flat expert scale parameters that don't match
                # per-expert patterns, i.e. one weight scale tensor for all
                # experts.
                scale_names = [
                    "w13_input_scale", "w13_weight_scale", "w2_input_scale",
                    "w2_weight_scale"
                ]
                if ("experts." in name and any(scale_name in name
                                               for scale_name in scale_names)):

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)

                    # If weight loader supports special moe loading, use it to
                    # avoid expensive runtime reflection
                    if getattr(weight_loader, 'supports_moe_loading', False):
                        # Map the weight name to the corresponding shard id.
                        shard_id = "w2" if "w2_" in name else "w1"

                        # Transpose if weight scales are FP8 block scales with
                        # three dimensions:
                        # [num_experts, hidden_in, hidden_out].
                        if name.endswith("weight_scale") \
                            and loaded_weight.dtype == torch.float8_e4m3fn \
                            and loaded_weight.ndim == 3:
                            loaded_weight = loaded_weight.transpose(-1, -2)

                        # Load the weight into the module parameter with
                        # corresponding shard id and expert id.
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=0)

                    else:
                        # Regular weight loader (handles both
                        # param.weight_loader and default_weight_loader)
                        weight_loader(param, loaded_weight)

                    loaded_params.add(name)
                    continue

                # Handle normal (non-stacked, non-MoE) weights.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Finally, return the set of loaded parameters.
        return loaded_params


class Llama4ForCausalLM(LlamaForCausalLM):
    """Llama4 text causal LM built from ``Llama4DecoderLayer`` layers.

    Before the model is built, this class resolves the attention temperature
    tuning setting. While loading weights, it permutes the Q/K projection
    weights (and NVFP4 weight scales) for the rotary embedding.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Resolve temperature tuning into hf_config, then build the model.

        ``attn_temperature_tuning`` is taken from the generation config, with
        ``override_generation_config`` applied on top. When it is absent
        there, it defaults to True iff ``max_model_len > 32768``.
        """
        # Update the temperature tuning config from the generation config.
        gen_config = vllm_config.model_config.try_get_generation_config()
        gen_config.update(vllm_config.model_config.override_generation_config)

        # Enable temperature tuning by default when max_model_len > 32K.
        default_attn_temperature_tuning = (
            vllm_config.model_config.max_model_len > 32768)
        vllm_config.model_config.hf_config.attn_temperature_tuning = (
            gen_config.get("attn_temperature_tuning",
                           default_attn_temperature_tuning))

        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=Llama4DecoderLayer)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
        """Build the inner ``Llama4Model`` decoder stack."""
        return Llama4Model(vllm_config=vllm_config,
                           prefix=prefix,
                           layer_type=layer_type)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Permute Q/K weights for rotary, then load via AutoWeightsLoader.

        ``lm_head.`` weights are skipped when word embeddings are tied.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        # A generator permutes each tensor as the loader consumes it. A list
        # would pull every checkpoint tensor out of `weights` up front and
        # keep all of them alive at once.
        permuted_weights = (
            self.permute_qk_weight_for_rotary(name, loaded_weight)
            for name, loaded_weight in weights)
        return loader.load_weights(permuted_weights)

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Reorder rows of Q/K projection weights (and NVFP4 scales).

        For each head, the rows are regrouped so that even-indexed rows come
        first, followed by odd-indexed rows. Presumably this converts
        interleaved rotary pairs into the half-split layout expected by the
        neox-style rotary embedding; TODO confirm against the checkpoint
        format.

        Args:
            name: Checkpoint weight name. Only entries whose module path
                contains ``wq``/``q_proj``/``wk``/``k_proj`` are changed.
            loaded_weight: The checkpoint tensor.

        Returns:
            ``(name, weight)``: the name unchanged, and the weight permuted
            if applicable.
        """

        # Helper function to permute the weight's channels
        def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):

            # Calculate the expected shape of the weight.
            # Do not rely on w's shape, as it may be in another layout.
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            # If the weight is FP4 packed as uint8, we need to divide attn_out
            # by 2.
            if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
                attn_out = attn_out // 2

            # If the weight is a weight scale, we need to divide attn_out by
            # block size, which is currently 16.
            elif w.dtype == torch.float8_e4m3fn and is_weight_scale \
                and w.shape[1] * 16 == attn_out:
                attn_out = attn_out // 16

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        modules = name.split(".")

        # Permute Q/K weights and weight block scales for rotary embedding
        is_weight = modules[-1] == "weight"
        is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
                                 loaded_weight.dtype == torch.float8_e4m3fn)

        if is_weight or is_nvfp4_weight_scale:
            if ("wk" in modules or "k_proj" in modules):
                loaded_weight = permute(loaded_weight,
                                        self.config.num_key_value_heads,
                                        is_nvfp4_weight_scale)
            elif ("wq" in modules or "q_proj" in modules):
                loaded_weight = permute(loaded_weight,
                                        self.config.num_attention_heads,
                                        is_nvfp4_weight_scale)

        return name, loaded_weight