llama4.py 35.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
20

21
from collections.abc import Iterable
22
23
24
25
26
27
28

import torch
from torch import nn
from transformers import Llama4TextConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
29
from vllm.distributed import (
30
    get_ep_group,
31
32
33
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
34
from vllm.logger import init_logger
35
36
from vllm.model_executor.layers.attention import (
    Attention,
37
38
    ChunkedLocalAttention,
)
39
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
46
47
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
48
from vllm.model_executor.model_loader.weight_utils import (
49
50
51
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
52
from vllm.model_executor.models.interfaces import MixtureOfExperts
53
from vllm.model_executor.models.utils import sequence_parallel_chunk
54
55
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
56
57

from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
58
59
from .utils import (
    AutoWeightsLoader,
60
    PPMissingLayer,
61
62
63
64
    extract_layer_index,
    fast_topk,
    is_pp_missing_parameter,
)
65

66
67
# Module-level logger for this model file.
logger = init_logger(__name__)

68
69
70
71
72
73
74
75

class Llama4MoE(nn.Module):
    """Llama 4 mixture-of-experts feed-forward block.

    A replicated (unquantized) linear router scores each token against all
    local experts. The top-k routed experts, whose sigmoid weights are
    applied on the input, are combined with a shared expert through
    ``SharedFusedMoE``.
    """

    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick the ``topk`` experts per token and compute their weights.

        ``hidden_states`` and ``renormalize`` are part of the routing-hook
        signature but are not used here.

        Returns:
            ``(weights, expert_ids)``: the sigmoid of the selected gating
            logits computed in float32, and the selected expert indices
            cast to int32.
        """
        top_logits, top_ids = fast_topk(gating_output, topk, dim=-1)
        # By (pseudo-)convention the routing weights are float32.
        weights = top_logits.float().sigmoid()
        return weights, top_ids.to(torch.int32)

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = hf_config.num_experts_per_tok
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        ep = get_ep_group()
        self.ep_group = ep.device_group
        self.ep_rank = ep.rank_in_group
        self.ep_size = self.ep_group.size()

        hidden = hf_config.hidden_size
        expert_ffn_dim = hf_config.intermediate_size
        num_experts = hf_config.num_local_experts

        # The router is deliberately built without quantization.
        self.router = ReplicatedLinear(
            hidden,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.router",
        )

        # Shared expert: its results are not reduced here (the combined
        # expert output is reduced in forward), and TP sharding is disabled
        # when the MoE runs sequence-parallel.
        self.shared_expert = LlamaMLP(
            hidden_size=hidden,
            intermediate_size=expert_ffn_dim,
            hidden_act="silu",
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.shared_expert",
            reduce_results=False,
            disable_tp=self.is_sequence_parallel,
        )

        # Expert-parallel load balancing (EPLB) settings.
        if parallel_config:
            eplb_config = parallel_config.eplb_config
            self.enable_eplb = parallel_config.enable_eplb
        else:
            eplb_config = None
            self.enable_eplb = False
        self.n_redundant_experts = (
            eplb_config.num_redundant_experts if eplb_config else 0
        )

        self.n_routed_experts: int = num_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_shared_experts: int = 1
        self.n_local_experts: int = num_experts
        self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            num_experts=num_experts,
            top_k=self.top_k,
            hidden_size=hidden,
            custom_routing_function=Llama4MoE.custom_routing_function,
            intermediate_size=expert_ffn_dim,
            apply_router_weight_on_input=True,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            is_sequence_parallel=self.is_sequence_parallel,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states):
        """Route tokens through shared + routed experts and return the sum.

        In sequence-parallel mode each rank processes only its chunk of the
        tokens, and the outputs are all-gathered and trimmed back to the
        original token count. Otherwise, with TP > 1, the combined output is
        all-reduced across tensor-parallel ranks.
        """
        total_tokens = hidden_states.shape[0]
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        router_logits, _ = self.router(hidden_states)
        shared_out, routed_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        combined = routed_out + shared_out

        if self.is_sequence_parallel:
            gathered = tensor_model_parallel_all_gather(combined, 0)
            # Drop any padding introduced by chunking.
            return gathered[:total_tokens]
        if self.tp_size > 1:
            return self.experts.maybe_all_reduce_tensor_model_parallel(combined)
        return combined


class Llama4Attention(nn.Module):
    """Llama 4 self-attention layer.

    Layers whose entry in ``config.no_rope_layers`` is 0 are "NoPE" layers.
    They skip rotary embeddings and QK-norm, use the regular ``Attention``
    backend, and may scale their queries with inference-time temperature
    tuning. All other layers apply RoPE and optional QK-norm, and use
    ``ChunkedLocalAttention`` when ``config.attention_chunk_size`` is set.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.no_rope_layers = config.no_rope_layers
        self.nope = self.no_rope_layers[self.layer_idx] == 0
        self.use_qk_norm = config.use_qk_norm and not self.nope

        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # KV heads are partitioned when there are at least as many as TP
        # ranks and replicated otherwise; either way the counts must divide.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Temperature tuning is only ever enabled on NoPE layers.
        self.attn_temperature_tuning = self.nope and config.attn_temperature_tuning
        self.floor_scale = getattr(config, "floor_scale", 8192.0)
        self.attn_scale = getattr(config, "attn_scale", 0.1)
        self.max_position_embeddings = max_position_embeddings
        self.n_rep = self.num_heads // self.num_kv_heads

        # Weightless float32 RMSNorm over head_dim, applied to q and k.
        if self.use_qk_norm:
            self.qk_norm = RMSNorm(
                hidden_size=self.head_dim,
                eps=config.rms_norm_eps,
                has_weight=False,
                dtype=torch.float32,
            )
        else:
            self.qk_norm = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NeoX-style rotary is used unless this is a GGUF checkpoint with
        # model_type "llama".
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        is_neox_style = not (is_gguf and config.model_type == "llama")

        if self.nope:
            self.rotary_emb = None
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=is_neox_style,
            )

        use_chunked_local_attn = not self.nope and config.attention_chunk_size
        if use_chunked_local_attn:
            attn_cls = ChunkedLocalAttention
            extra_attn_kwargs = {"attention_chunk_size": config.attention_chunk_size}
        else:
            attn_cls = Attention
            extra_attn_kwargs = {}
        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the per-token query scale with a trailing unit dimension.

        scale = log(floor((positions + 1) / floor_scale) + 1) * attn_scale + 1
        """
        bucket = torch.floor((positions + 1.0) / self.floor_scale)
        scale = torch.log(bucket + 1.0) * self.attn_scale + 1.0
        return scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE / QK-norm / temperature tuning, attend."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        if self.qk_norm is not None:
            # Normalize over head_dim. All other dimensions are flattened
            # into one so the custom rms_norm kernel sees a 2-D input. The
            # norm runs in float32, then the packed layout and the original
            # dtype are restored.
            q_dtype, k_dtype = q.dtype, k.dtype
            q = self.qk_norm(q.reshape(-1, self.head_dim).float())
            q = q.reshape(-1, self.q_size).to(q_dtype)
            k = self.qk_norm(k.reshape(-1, self.head_dim).float())
            k = k.reshape(-1, self.kv_size).to(k_dtype)

        # Temperature tuning (https://arxiv.org/abs/2501.19399) on NoPE
        # layers. It is customized to leave short contexts unaffected while
        # acting at very long context. It must run after rotary / QK-norm
        # and before attention.
        if self.attn_temperature_tuning and self.nope:
            q = (q * self._get_attn_scale(positions)).to(q.dtype)

        return self.o_proj(self.attn(q, k, v))[0]


class Llama4DecoderLayer(nn.Module):
    """One Llama 4 transformer block: normed self-attention, then normed FFN.

    The feed-forward is a ``Llama4MoE`` when ``interleave_moe_layer_step`` is
    positive and ``layer_idx + 1`` is a multiple of it. Otherwise it is a
    dense ``LlamaMLP`` of width ``intermediate_size_mlp``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: Llama4TextConfig | None = None,
    ) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_idx = extract_layer_index(prefix)
        # A 0 entry in no_rope_layers marks a NoPE ("global") layer.
        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
        self.hidden_size = config.hidden_size

        self.self_attn = Llama4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        moe_step = config.interleave_moe_layer_step
        if moe_step > 0 and (self.layer_idx + 1) % moe_step == 0:
            self.feed_forward = Llama4MoE(
                vllm_config=vllm_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                bias=False,
                prefix=f"{prefix}.feed_forward",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and FFN; return ``(ffn_output, residual)``.

        When no residual is supplied (presumably the first layer), the raw
        input becomes the residual. Otherwise the norm is called with the
        residual and returns both the normed states and the updated residual.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        ffn_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.feed_forward(ffn_in), residual


@support_torch_compile
class Llama4Model(LlamaModel):
    """Llama 4 decoder stack with Llama 4-specific checkpoint loading.

    Extends ``LlamaModel`` with loading of MoE expert weights. These may be
    stored either per expert or fused into a single tensor that spans all
    experts.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
    ):
        # Expert counts used by load_weights() to build the per-expert
        # parameter mapping.
        self.num_experts = vllm_config.model_config.hf_config.num_local_experts
        self.n_redundant_experts = (
            vllm_config.parallel_config.eplb_config.num_redundant_experts
        )
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def load_moe_expert_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
        expert_params_mapping: list[tuple[str, str, int, str]],
        fused: bool = True,
    ) -> bool:
        """
        Load MoE expert weights.

        Args:
            name: The name of the weight to load.
            loaded_weight: The weight to load.
            params_dict: The dictionary of module parameters.
            loaded_params: The set of already loaded parameters.
            expert_params_mapping: The mapping of expert parameters. Must be
                generated by SharedFusedMoE.make_expert_params_mapping().
            fused: Whether the expert weights are fused into a single weight
                tensor or are separate weight tensors for each expert.
                When fused is True, loaded_weight should have shape of:
                [num_experts, hidden_in, hidden_out] for gate/up/down proj and
                [hidden_out, hidden_in] for the others like router.
                When fused is False, loaded_weight should have shape of:
                [hidden_out, hidden_in].

        Returns:
            True if loaded_weight is one of MoE weights and the MoE expert
            weights are loaded successfully, False otherwise.
        """
        # Whether any MoE expert weight was loaded from this tensor.
        expert_param_loaded = False

        # A fused weight is laid out as [num_experts, hidden_in, hidden_out],
        # so swap the last two dimensions to match the parameter layout.
        if fused and loaded_weight.ndim == 3:
            loaded_weight = loaded_weight.transpose(-1, -2)

            # Fused gate_proj/up_proj: split along hidden_out into a
            # (gate, up) tuple of two tensors.
            if "experts.gate_up_proj" in name:
                loaded_weight = loaded_weight.chunk(2, dim=-2)

        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            # Rebind per iteration so one mapping entry's selection does not
            # leak into the next.
            new_loaded_weight = loaded_weight

            # Fused checkpoints carry no expert index in the name, so remove
            # it from the expected weight name.
            if fused:
                # The token between e_str and proj_str is the expert index.
                e_str, _, proj_str, _ = weight_name.split(".")
                weight_name = f"{e_str}.{proj_str}"
                param_name = f"{param_name}weight"

            # Not one of the MoE weights described by this mapping entry.
            if weight_name not in name:
                continue

            full_param_name = name.replace(weight_name, param_name)

            # The parameter lives on another pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue

            # Bias tensors without a matching parameter are skipped.
            if name.endswith((".bias", "_bias")) and name not in params_dict:
                continue

            param = params_dict[full_param_name]
            weight_loader = param.weight_loader

            if fused:
                # For the joint w13 parameter the weight is a (gate, up)
                # tuple; pick the half matching the shard id.
                if "w13" in full_param_name:
                    assert shard_id in ["w1", "w3"]
                    shard_idx = 0 if shard_id == "w1" else 1
                    new_loaded_weight = new_loaded_weight[shard_idx]

                # With expert parallelism, keep only the experts owned by
                # this rank. expert_id becomes the first local expert index.
                layer_idx = extract_layer_index(name)
                expert_map = self.layers[layer_idx].feed_forward.experts.expert_map
                if expert_map is not None:
                    local_expert_indices = (
                        (expert_map != -1)
                        .nonzero()
                        .flatten()
                        .to(new_loaded_weight.device)
                    )
                    # Workaround for FP8 CPU indexing on older PyTorch:
                    # https://github.com/vllm-project/vllm/issues/32862
                    is_fp8_dtype = new_loaded_weight.dtype == (
                        current_platform.fp8_dtype()
                    ) or (
                        new_loaded_weight.dtype.is_floating_point
                        and new_loaded_weight.element_size() == 1
                    )
                    if (
                        new_loaded_weight.device.type == "cpu"
                        and is_fp8_dtype
                        and not is_torch_equal_or_newer("2.11.0")
                    ):
                        # PyTorch < 2.11 doesn't support CPU float8 indexing.
                        new_loaded_weight = new_loaded_weight.to(torch.float16)[
                            local_expert_indices
                        ].to(new_loaded_weight.dtype)
                    else:
                        new_loaded_weight = new_loaded_weight[local_expert_indices]
                    expert_id = local_expert_indices[0].item()
            # TODO: add EP support for non fused weights

            weight_loader(
                param,
                new_loaded_weight,
                full_param_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            loaded_params.add(full_param_name)
            expert_param_loaded = True

        return expert_param_loaded

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each tensor is handled by the first matching rule:
        kv-cache quantization scales, stacked (qkv / gate_up) projections,
        MoE expert weights (per-expert or fused), flat expert scale tensors,
        and finally plain parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        # Whether the checkpoint stores expert weights fused into a single
        # tensor across all experts (detected from the names seen).
        fused_experts_params = False

        # Mapping for per-expert (non-fused) checkpoint tensors.
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.num_experts,
            num_redundant_experts=self.n_redundant_experts,
        )
        # Mapping for fused checkpoint tensors.
        expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="gate_up_proj",
            num_experts=1,
        )

        # Flat expert scale parameters: one scale tensor for all experts.
        flat_expert_scale_names = (
            "w13_input_scale",
            "w13_weight_scale",
            "w2_input_scale",
            "w2_weight_scale",
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Expert names without an expert index mean the experts are
            # fused into a single tensor.
            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                fused_experts_params = True
                expert_params_mapping = expert_params_mapping_fused

            # kv-cache quantization scales recognized by the quant config.
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Stacked (qkv / gate_up) parameters. MoE weights are handled in
            # the for-else branch. A `break` means "done with this tensor",
            # whether it was loaded or skipped, and bypasses that branch.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name or "experts" in name:
                    continue

                # For ModelOpt checkpoints, rename the self_attn weight /
                # weight_scale names, except for kv cache scales.
                if not (
                    name.endswith((".k_scale", ".v_scale")) and "self_attn" in name
                ):
                    name = name.replace(weight_name, param_name)

                # The parameter lives on another pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    break

                # Remap kv cache scale names for ModelOpt checkpoints.
                # TODO: ModelOpt should implement get_cache_scale() such that
                #       kv cache scale name remapping can be done there.
                if name.endswith("scale"):
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None:
                        # No matching parameter: skip this tensor. A plain
                        # `continue` would keep scanning with name=None and
                        # crash on the next `in` test.
                        break
                    name = remapped_name

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break

            # No stacked mapping matched: MoE weights and plain weights.
            else:
                if self.load_moe_expert_weights(
                    name,
                    loaded_weight,
                    params_dict,
                    loaded_params,
                    expert_params_mapping,
                    fused=fused_experts_params,
                ):
                    continue

                # The parameter lives on another pipeline-parallel rank.
                if is_pp_missing_parameter(name, self):
                    continue

                # Flat expert scales that don't match the per-expert
                # patterns.
                if "experts." in name and any(
                    scale_name in name for scale_name in flat_expert_scale_names
                ):
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )

                    # Loaders advertising MoE support get shard/expert ids
                    # directly, which avoids expensive runtime reflection.
                    if getattr(weight_loader, "supports_moe_loading", False):
                        shard_id = "w2" if "w2_" in name else "w1"

                        # 3-D FP8 block weight scales are stored as
                        # [num_experts, hidden_in, hidden_out]; transpose.
                        if (
                            name.endswith("weight_scale")
                            and loaded_weight.dtype == torch.float8_e4m3fn
                            and loaded_weight.ndim == 3
                        ):
                            loaded_weight = loaded_weight.transpose(-1, -2)

                        weight_loader(
                            param, loaded_weight, name, shard_id=shard_id, expert_id=0
                        )
                    else:
                        # Regular weight loader (param.weight_loader or
                        # default_weight_loader).
                        weight_loader(param, loaded_weight)

                    loaded_params.add(name)
                    continue

                # Plain (non-stacked, non-MoE) weights.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params


728
class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
729
730
731
732
733
734
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Llama4 causal LM and initialise its MoE bookkeeping.

        Before delegating to the parent constructor, the
        ``attn_temperature_tuning`` flag on the HF config is resolved from the
        model's generation config (with ``override_generation_config`` merged
        on top). If the generation config does not set the flag, it defaults to
        ``True`` only when ``max_model_len`` exceeds 32768 tokens.
        """
        model_config = vllm_config.model_config

        # Merge the user-supplied overrides on top of the generation config.
        generation_config = model_config.try_get_generation_config()
        generation_config.update(model_config.override_generation_config)

        # Long-context deployments (> 32K tokens) turn temperature tuning on
        # unless the generation config explicitly decides otherwise.
        tuning_default = model_config.max_model_len > 32768
        model_config.hf_config.attn_temperature_tuning = generation_config.get(
            "attn_temperature_tuning", tuning_default
        )

        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=Llama4DecoderLayer,
        )

        # Record the MoE layer list and expert counts.
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect the MoE layers and record expert-count metadata.

        Walks ``self.model.layers`` (skipping ``PPMissingLayer`` placeholders),
        appends the ``experts`` module of every ``Llama4MoE`` feed-forward to
        ``self.moe_layers``, and copies the expert counts from the last MoE
        layer found. If there is no MoE layer, every count is set to zero and
        a warning is logged.
        """
        self.expert_weights = []
        self.moe_layers = []

        last_moe = None
        for decoder_layer in self.model.layers:
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            assert isinstance(decoder_layer, Llama4DecoderLayer)
            ffn = decoder_layer.feed_forward
            if not isinstance(ffn, Llama4MoE):
                continue
            # Keep overwriting so the *last* MoE layer wins; the first layers
            # may be dense.
            last_moe = ffn
            self.moe_layers.append(ffn.experts)

        if last_moe is None:
            self.num_moe_layers = self.num_expert_groups = 0
            self.num_logical_experts = self.num_physical_experts = 0
            self.num_local_physical_experts = self.num_routed_experts = 0
            self.num_shared_experts = self.num_redundant_experts = 0
            logger.warning("No Llama4MoE layer found in model.layers.")
            return

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_logical_experts = last_moe.n_logical_experts
        self.num_physical_experts = last_moe.n_physical_experts
        self.num_local_physical_experts = last_moe.n_local_physical_experts
        self.num_routed_experts = last_moe.n_routed_experts
        self.num_shared_experts = last_moe.n_shared_experts
        self.num_redundant_experts = last_moe.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate a new physical-expert count to the model and its MoE layers.

        The per-rank local expert count must not change (asserted). The
        redundant-expert count is recomputed as physical minus logical experts.
        Every ``Llama4MoE`` feed-forward gets the new counts, and its expert map
        is rebuilt via ``experts.update_expert_map()``.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = redundant

        for decoder_layer in self.model.layers:
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            moe = decoder_layer.feed_forward
            if not isinstance(moe, Llama4MoE):
                continue
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = redundant
            moe.experts.update_expert_map()

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
    ):
        """Return a ``Llama4Model`` backbone built from ``layer_type`` layers."""
        backbone = Llama4Model(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=layer_type,
        )
        return backbone

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, permuting Q/K projections for rotary.

        Each ``(name, tensor)`` pair is passed through
        ``permute_qk_weight_for_rotary`` before reaching ``AutoWeightsLoader``.
        When ``tie_word_embeddings`` is set, ``lm_head.*`` entries are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names returned by the weights loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        # Permute lazily. A generator keeps the checkpoint stream streaming,
        # so this method never holds references to every tensor at once.
        # The previous list comprehension materialised the whole checkpoint
        # before loading began, inflating peak host memory.
        permuted = (
            self.permute_qk_weight_for_rotary(name, loaded_weight)
            for name, loaded_weight in weights
        )
        return loader.load_weights(permuted)

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Reorder Q/K projection rows (and NVFP4 block scales) for rotary.

        A tensor is only touched when both conditions hold:

        - its last dotted name component is ``weight``, or is ``weight_scale``
          with dtype ``float8_e4m3fn``;
        - the dotted name contains ``wk``/``k_proj`` or ``wq``/``q_proj``.

        Everything else is returned unchanged. Within each head, rows are
        reordered so that the even-indexed rows come first, followed by the
        odd-indexed rows. (Presumably this maps an interleaved-pair rotary
        layout onto the rotate-half layout used at runtime; TODO confirm.)

        Returns:
            ``(name, tensor)`` with ``name`` unchanged.
        """

        def _deinterleave_heads(
            tensor: torch.Tensor, num_heads: int, is_weight_scale: bool
        ) -> torch.Tensor:
            # Expected 2-D extent, taken from the config rather than from
            # ``tensor.shape`` because the tensor may be in a packed layout.
            rows = self.config.head_dim * num_heads
            cols = self.config.hidden_size

            if tensor.dtype == torch.uint8 and tensor.shape[1] * 2 == cols:
                # FP4 packed as uint8: two values per byte along dim 1.
                cols //= 2
            elif (
                is_weight_scale
                and tensor.dtype == torch.float8_e4m3fn
                and tensor.shape[1] * 16 == cols
            ):
                # Block scales: one entry per block of 16 along dim 1.
                cols //= 16

            pairs_per_head = rows // num_heads // 2
            split = tensor.view(num_heads, pairs_per_head, 2, cols).transpose(1, 2)
            return split.reshape(rows, cols)

        parts = name.split(".")
        suffix = parts[-1]
        is_nvfp4_weight_scale = (
            suffix == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
        )
        if suffix != "weight" and not is_nvfp4_weight_scale:
            return name, loaded_weight

        if "wk" in parts or "k_proj" in parts:
            num_heads = self.config.num_key_value_heads
        elif "wq" in parts or "q_proj" in parts:
            num_heads = self.config.num_attention_heads
        else:
            return name, loaded_weight

        permuted = _deinterleave_heads(loaded_weight, num_heads, is_nvfp4_weight_scale)
        return name, permuted