# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
20

21
from collections.abc import Iterable

import torch
from torch import nn
from transformers import Llama4TextConfig

from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
    get_ep_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import sequence_parallel_chunk

from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    fast_topk,
    is_pp_missing_parameter,
)
61

62
63
logger = init_logger(__name__)

64
65
66
67
68
69
70
71

class Llama4MoE(nn.Module):
    """Llama4 mixture-of-experts feed-forward block.

    A replicated router produces per-expert logits. A shared expert
    (``LlamaMLP``) and the routed experts are both executed through
    ``SharedFusedMoE``. Their outputs are summed and then either all-gathered
    (sequence-parallel MoE) or passed through the experts' tensor-parallel
    all-reduce helper (TP > 1).
    """

    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick the top-k experts per token and score them with a sigmoid.

        Args:
            hidden_states: Unused; required by the routing-function signature.
            gating_output: Router logits.
            topk: Number of experts selected per token.
            renormalize: Unused; the scores are never renormalized here.

        Returns:
            ``(router_scores, router_indices)``. The scores are the sigmoid of
            the top-k values returned by ``fast_topk``, computed in float32.
            The indices are cast to int32.
        """
        router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
        # pseudo-standard is that the router scores are floats
        router_scores = torch.sigmoid(router_scores.float())
        return (router_scores, router_indices.to(torch.int32))

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the router, shared expert and fused routed experts.

        Args:
            vllm_config: Global vLLM config; the HF model config, parallel
                config and quantization config are read from it.
            prefix: Module name prefix used for the sub-layers' weight names.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()

        intermediate_size_moe = config.intermediate_size

        # The router is never quantized (quant_config=None).
        self.router = ReplicatedLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.router",
        )

        # The shared expert does not reduce its own output; the reduction is
        # done once in forward() on the combined shared + routed output.
        self.shared_expert = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size_moe,
            hidden_act="silu",
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.shared_expert",
            reduce_results=False,
            disable_tp=self.is_sequence_parallel,
        )

        # Load balancing (EPLB) settings. `parallel_config` is already
        # dereferenced unconditionally above, so it cannot be None here.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb
        self.n_redundant_experts = (
            eplb_config.num_redundant_experts if eplb_config else 0
        )

        self.n_routed_experts: int = config.num_local_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_shared_experts: int = 1
        self.n_local_experts: int = config.num_local_experts
        self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            num_experts=config.num_local_experts,
            top_k=self.top_k,
            hidden_size=config.hidden_size,
            custom_routing_function=Llama4MoE.custom_routing_function,
            intermediate_size=intermediate_size_moe,
            apply_router_weight_on_input=True,
            reduce_results=False,
            renormalize=False,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            is_sequence_parallel=self.is_sequence_parallel,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states):
        """Run shared + routed experts on ``hidden_states``.

        With sequence-parallel MoE, the input is first split with
        ``sequence_parallel_chunk``. The summed output is then all-gathered
        along dim 0 and trimmed back to the original number of rows.
        Otherwise, when TP > 1, the output goes through
        ``experts.maybe_all_reduce_tensor_model_parallel``.
        """
        # Dim 0 is treated as the token dimension (it is what gets trimmed
        # after the all-gather below).
        num_tokens = hidden_states.shape[0]
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        router_logits, _ = self.router(hidden_states)

        shared_out, routed_out = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        experts_out = routed_out + shared_out

        if self.is_sequence_parallel:
            experts_out = tensor_model_parallel_all_gather(experts_out, 0)
            # Drop any padding rows introduced by the chunking.
            experts_out = experts_out[:num_tokens]
        elif self.tp_size > 1:
            experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                experts_out
            )

        return experts_out


class Llama4Attention(nn.Module):
    """Llama4 self-attention with RoPE / NoPE layers.

    Layers whose ``config.no_rope_layers[layer_idx] == 0`` are "NoPE" layers:
    - they apply no rotary embedding and no QK norm;
    - they use full ``Attention``;
    - they optionally apply attention temperature tuning.

    All other layers apply RoPE and optional weightless QK RMSNorm. They use
    ``ChunkedLocalAttention`` when ``config.attention_chunk_size`` is set.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build projections, rotary embedding, QK norm and attention.

        Args:
            config: Llama4 text config.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (across all TP ranks).
            num_kv_heads: Total number of key/value heads.
            max_position_embeddings: Max position passed to ``get_rope``.
            quant_config: Optional quantization config for the linear layers.
            bias: Whether the fused QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: KV-cache config forwarded to the attention layer.
            prefix: Module name prefix; the layer index is extracted from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.no_rope_layers = config.no_rope_layers
        self.nope = self.no_rope_layers[self.layer_idx] == 0
        self.use_qk_norm = config.use_qk_norm and not self.nope

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # Temperature tuning is only ever enabled on NoPE layers.
        self.attn_temperature_tuning = self.nope and config.attn_temperature_tuning
        self.floor_scale = getattr(config, "floor_scale", 8192.0)
        self.attn_scale = getattr(config, "attn_scale", 0.1)
        self.max_position_embeddings = max_position_embeddings
        self.n_rep = self.num_heads // self.num_kv_heads

        # Weightless RMSNorm over head_dim, computed in float32.
        self.qk_norm = (
            RMSNorm(
                hidden_size=self.head_dim,
                eps=config.rms_norm_eps,
                has_weight=False,
                dtype=torch.float32,
            )
            if self.use_qk_norm
            else None
        )

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF checkpoints of model_type "llama" use non-neox rotary layout.
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = (
            get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=is_neox_style,
            )
            if not self.nope
            else None
        )

        use_chunked_local_attn = not self.nope and config.attention_chunk_size
        attn_cls = ChunkedLocalAttention if use_chunked_local_attn else Attention
        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # Only ChunkedLocalAttention takes an attention_chunk_size.
            **(
                {"attention_chunk_size": config.attention_chunk_size}
                if use_chunked_local_attn
                else {}
            ),
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return per-position query scales for temperature tuning.

        scale = log(floor((positions + 1) / floor_scale) + 1) * attn_scale + 1,
        with a trailing singleton dim added so it broadcasts against ``q``.
        """
        floor = torch.floor((positions + 1.0) / self.floor_scale)
        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0

        return attn_scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states`` at ``positions``.

        Order: QKV projection -> RoPE (non-NoPE layers) -> QK norm (if
        enabled) -> temperature tuning (NoPE layers, if enabled) -> attention
        -> output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        if self.qk_norm is not None:
            # Normalization is applied on the head_dim dimension. The rest of
            # the dimensions are collapsed into a single dimension to support
            # custom rms_norm cuda kernel.
            q = q.reshape(-1, self.head_dim)
            q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
            k = k.reshape(-1, self.head_dim)
            k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)

        # Temperature tuning (https://arxiv.org/abs/2501.19399) on NoPE layers.
        # The inference-time temperature function is customized to leave short
        # context unaffected while helping at very long context. It must be
        # applied after rotary / QK norm and before attention.
        if self.attn_temperature_tuning and self.nope:
            attn_scale = self._get_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Llama4DecoderLayer(nn.Module):
    """One Llama4 transformer layer: pre-norm attention + pre-norm FFN.

    The FFN is a ``Llama4MoE`` on every ``interleave_moe_layer_step``-th layer
    (counting from 1) when that step is positive. Otherwise it is a dense
    ``LlamaMLP`` of size ``intermediate_size_mlp``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: Llama4TextConfig | None = None,
    ) -> None:
        """Build attention, feed-forward and the two RMSNorms.

        Args:
            vllm_config: Global vLLM config (cache and quantization configs
                are read from it).
            prefix: Module name prefix; the layer index is extracted from it.
            config: Optional text config; defaults to
                ``vllm_config.model_config.hf_config``.
        """
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_idx = extract_layer_index(prefix)
        # True for layers with no_rope_layers[idx] == 0 (the same condition
        # Llama4Attention uses to mark NoPE layers).
        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
        self.hidden_size = config.hidden_size
        max_position_embeddings = config.max_position_embeddings

        self.self_attn = Llama4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        is_moe_layer = (
            config.interleave_moe_layer_step > 0
            and (self.layer_idx + 1) % config.interleave_moe_layer_step == 0
        )
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                vllm_config=vllm_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                bias=False,
                prefix=f"{prefix}.feed_forward",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        On the first layer ``residual`` is None and is initialized from the
        input. Afterwards, the two-argument RMSNorm call returns both the
        normalized tensor and the updated residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Llama4Model(LlamaModel):
    """Llama4 text backbone.

    This is a ``LlamaModel`` whose layers are ``Llama4DecoderLayer`` by
    default. It adds weight loading for MoE experts stored either per expert
    or fused into one tensor across all experts.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
    ):
        # Set before LlamaModel.__init__ runs; read by load_weights() to
        # build the per-expert parameter mapping.
        self.num_experts = vllm_config.model_config.hf_config.num_local_experts
        self.n_redundant_experts = (
            vllm_config.parallel_config.eplb_config.num_redundant_experts
        )
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def load_moe_expert_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
        expert_params_mapping: list[tuple[str, str, int, str]],
        fused: bool = True,
    ) -> bool:
        """
        Load MoE expert weights.

        Args:
            name: The name of the weight to load.
            loaded_weight: The weight to load.
            params_dict: The dictionary of module parameters.
            loaded_params: The set of already loaded parameters.
            expert_params_mapping: The mapping of expert parameters. Must be
                generated by SharedFusedMoE.make_expert_params_mapping().
            fused: Whether the expert weights are fused into a single weight
                tensor or are separate weight tensors for each expert.
                When fused is True, loaded_weight should have shape of:
                [num_experts, hidden_in, hidden_out] for gate/up/down proj and
                [hidden_out, hidden_in] for the others like router.
                When fused is False, loaded_weight should have shape of:
                [hidden_out, hidden_in].

        Returns:
            True if loaded_weight is one of MoE weights and the MoE expert
            weights are loaded successfully, False otherwise.
        """
        # Whether the MoE expert weights are loaded successfully.
        expert_param_loaded = False

        # If fused is True, the loaded weight is in the layout of:
        # [num_experts, hidden_in, hidden_out], so we must transpose the last
        # two dimensions to match the expected layout of the parameters.
        if fused and loaded_weight.ndim == 3:
            loaded_weight = loaded_weight.transpose(-1, -2)

            # If the gate_proj and up_proj weights are fused into a single
            # weight tensor, split it into a tuple of two tensors along the
            # hidden_out dimension.
            if "experts.gate_up_proj" in name:
                loaded_weight = loaded_weight.chunk(2, dim=-2)

        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            # Rebind per iteration so selecting a shard / EP slice below never
            # affects the other mapping entries.
            new_loaded_weight = loaded_weight

            if fused:
                # Fused checkpoints have no expert index in the name. This
                # assumes weight_name has four "."-separated parts, the
                # second being the expert index; drop it.
                e_str, _, proj_str, _ = weight_name.split(".")
                weight_name = f"{e_str}.{proj_str}"
                param_name = f"{param_name}weight"

            # Skip if the current weight is not one of the MoE weights.
            if weight_name not in name:
                continue

            full_param_name = name.replace(weight_name, param_name)

            # Skip weights of layers that are not on this PP rank.
            if is_pp_missing_parameter(name, self):
                continue

            # Skip bias weights that have no matching parameter.
            if name.endswith((".bias", "_bias")) and name not in params_dict:
                continue

            param = params_dict[full_param_name]
            weight_loader = param.weight_loader

            if fused:
                # For the fused w13 parameter the weight was split into a
                # (gate, up) tuple above; pick the half for this shard.
                if "w13" in full_param_name:
                    assert shard_id in ["w1", "w3"]
                    shard_idx = 0 if shard_id == "w1" else 1
                    new_loaded_weight = new_loaded_weight[shard_idx]

                # With EP enabled, keep only the experts whose expert_map
                # entry is not -1 (local to this rank). Set expert_id to the
                # first of them.
                layer_idx = extract_layer_index(name)
                expert_map = self.layers[layer_idx].feed_forward.experts.expert_map
                if expert_map is not None:
                    local_expert_indices = (
                        (expert_map != -1)
                        .nonzero()
                        .flatten()
                        .to(new_loaded_weight.device)
                    )
                    new_loaded_weight = new_loaded_weight[local_expert_indices]
                    expert_id = local_expert_indices[0].item()
            # TODO: add EP support for non fused weights (they are currently
            # loaded with the mapping's expert_id unchanged).

            weight_loader(
                param,
                new_loaded_weight,
                full_param_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )

            loaded_params.add(full_param_name)
            expert_param_loaded = True

        return expert_param_loaded

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Each weight is handled by the first matching case:
        1. KV-cache quantization scales (via ``quant_config``).
        2. Stacked projections (q/k/v -> qkv_proj, gate/up -> gate_up_proj).
        3. MoE expert weights, per-expert or fused across experts.
        4. Flat expert scale tensors shared by all experts of a layer.
        5. Everything else, via the parameter's own or the default loader.

        Returns:
            The set of parameter names that were loaded.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # Flat expert scale parameters that don't match the per-expert
        # patterns, i.e. one scale tensor for all experts. Loop-invariant.
        flat_expert_scale_names = (
            "w13_input_scale",
            "w13_weight_scale",
            "w2_input_scale",
            "w2_weight_scale",
        )

        # Whether the expert weights are fused into a single weight tensor.
        fused_experts_params = False
        # Expert mapping for per-expert (non-fused) checkpoints.
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.num_experts,
            num_redundant_experts=self.n_redundant_experts,
        )
        # Expert mapping for checkpoints with all experts fused in one tensor.
        expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="gate_up_proj",
            num_experts=1,
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # "experts.gate_up_proj" / "experts.down_proj" without expert
            # indices means the experts are fused into a single tensor; use
            # the fused mapping from this weight on.
            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                fused_experts_params = True
                expert_params_mapping = expert_params_mapping_fused

            # KV cache quantization scales.
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Stacked parameters. MoE weights are handled in the else block.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name or "experts" in name:
                    continue

                # For ModelOpt checkpoints, rename the self_attn weight /
                # weight_scale names, but not the kv cache scales.
                if not (
                    name.endswith((".k_scale", ".v_scale")) and "self_attn" in name
                ):
                    name = name.replace(weight_name, param_name)

                # Skip weights of layers that are not on this PP rank.
                if is_pp_missing_parameter(name, self):
                    continue

                # Remap kv cache scale names for ModelOpt checkpoints.
                # TODO: ModelOpt should implement get_cache_scale() such that
                #       kv cache scale name remapping can be done there.
                if name.endswith("scale"):
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        # Skip this weight. `break` (not `continue`): with
                        # name set to None, testing the remaining mapping
                        # entries would raise TypeError. Breaking also
                        # bypasses the for-else fallback below.
                        break

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break

            # Normal (non-stacked) weights and MoE weights.
            else:
                if self.load_moe_expert_weights(
                    name,
                    loaded_weight,
                    params_dict,
                    loaded_params,
                    expert_params_mapping,
                    fused=fused_experts_params,
                ):
                    continue

                # Skip weights of layers that are not on this PP rank.
                if is_pp_missing_parameter(name, self):
                    continue

                if "experts." in name and any(
                    flat_name in name for flat_name in flat_expert_scale_names
                ):
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )

                    # If the weight loader supports special moe loading, use
                    # it to avoid expensive runtime reflection.
                    if getattr(weight_loader, "supports_moe_loading", False):
                        shard_id = "w2" if "w2_" in name else "w1"

                        # FP8 block scales are stored as
                        # [num_experts, hidden_in, hidden_out]; transpose the
                        # last two dims.
                        if (
                            name.endswith("weight_scale")
                            and loaded_weight.dtype == torch.float8_e4m3fn
                            and loaded_weight.ndim == 3
                        ):
                            loaded_weight = loaded_weight.transpose(-1, -2)

                        weight_loader(
                            param, loaded_weight, name, shard_id=shard_id, expert_id=0
                        )
                    else:
                        # Regular weight loader (handles both
                        # param.weight_loader and default_weight_loader)
                        weight_loader(param, loaded_weight)

                    loaded_params.add(name)
                    continue

                # Normal (non-stacked, non-MoE) weights.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params


705
class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
    """Llama 4 causal language model with mixture-of-experts bookkeeping.

    Extends ``LlamaForCausalLM`` with ``Llama4DecoderLayer`` layers and a
    ``Llama4Model`` backbone. Also records the expert-count attributes of the
    ``MixtureOfExperts`` interface (logical / physical / local / redundant
    experts of the MoE layers present on this rank).
    """

    # Fused projection name -> the separate projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Resolve attention temperature tuning, then build the model.

        Args:
            vllm_config: Global vLLM configuration. This constructor sets
                ``model_config.hf_config.attn_temperature_tuning`` on it
                before calling the parent constructor.
            prefix: Name prefix forwarded to the parent constructor.
        """
        model_config = vllm_config.model_config

        # Temperature tuning comes from the generation config, with any
        # user-provided generation-config overrides applied on top.
        gen_config = model_config.try_get_generation_config()
        gen_config.update(model_config.override_generation_config)

        # Enable temperature tuning by default when max_model_len > 32K.
        default_attn_temperature_tuning = model_config.max_model_len > 32768
        model_config.hf_config.attn_temperature_tuning = gen_config.get(
            "attn_temperature_tuning", default_attn_temperature_tuning
        )

        super().__init__(
            vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
        )

        # Layers now exist, so the MoE layers and expert counts can be read.
        self.set_moe_parameters()

    def set_moe_parameters(self) -> None:
        """Record this model's MoE layers and expert counts.

        Appends the ``experts`` module of every local ``Llama4MoE``
        feed-forward to ``self.moe_layers``. Expert counts are copied from the
        last such layer. When no ``Llama4MoE`` layer is found, every count is
        set to 0 and a warning is logged.
        """
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            # Skip PPMissingLayer placeholders (presumably layers that live
            # on other pipeline stages).
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Llama4DecoderLayer)
            if isinstance(layer.feed_forward, Llama4MoE):
                # Keep the last MoE layer as the reference, since the first
                # layers may be dense.
                example_moe = layer.feed_forward
                self.moe_layers.append(layer.feed_forward.experts)

        if example_moe is None:
            self.num_moe_layers = 0
            self.num_expert_groups = 0
            self.num_logical_experts = 0
            self.num_physical_experts = 0
            self.num_local_physical_experts = 0
            self.num_routed_experts = 0
            self.num_shared_experts = 0
            self.num_redundant_experts = 0
            logger.warning("No Llama4MoE layer found in model.layers.")
        else:
            self.num_moe_layers = len(self.moe_layers)
            self.num_expert_groups = 1
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every local MoE layer.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of local physical experts.
                It must equal the current value; this is asserted.

        The redundant-expert count is recomputed as physical minus logical
        experts. Each local ``Llama4MoE`` layer receives the new counts and
        has ``experts.update_expert_map()`` called on it.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            if isinstance(layer.feed_forward, Llama4MoE):
                moe = layer.feed_forward
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
    ):
        """Build the ``Llama4Model`` backbone with the given decoder layer type."""
        return Llama4Model(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, permuting Q/K projections for rotary.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of names returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(
            self,
            # Skip "lm_head." entries when word embeddings are tied.
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        # A generator (not a list) so permuted tensors are produced one at a
        # time as the loader consumes them, instead of the whole checkpoint
        # being kept alive in memory at once.
        permuted_weights = (
            self.permute_qk_weight_for_rotary(name, loaded_weight)
            for name, loaded_weight in weights
        )
        return loader.load_weights(permuted_weights)

    def permute_qk_weight_for_rotary(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Reorder Q/K projection rows (or NVFP4 block scales) for rotary.

        Within each head, the rows are viewed as ``(head_dim // 2, 2)`` pairs
        and transposed to ``(2, head_dim // 2)``. Even rows of a head thus
        become its first half and odd rows its second half. Entries that are
        not Q/K weights or NVFP4 weight scales are returned unchanged.

        Args:
            name: Checkpoint parameter name. Its dot-separated components are
                matched against ``wq``/``q_proj`` and ``wk``/``k_proj``.
            loaded_weight: The checkpoint tensor.

        Returns:
            ``(name, tensor)``. The name is unchanged; the tensor may be
            permuted.
        """
        # Number of hidden-size columns covered by one NVFP4 weight scale.
        nvfp4_block_size = 16

        def permute(
            w: torch.Tensor, n_heads: int, is_weight_scale: bool
        ) -> torch.Tensor:
            # Derive the expected 2-D shape from the config, not from
            # w.shape, as w may be in another layout.
            head_dim = self.config.head_dim
            rows = head_dim * n_heads
            cols = self.config.hidden_size

            if w.dtype == torch.uint8 and w.shape[1] * 2 == cols:
                # FP4 weights are packed two per uint8, halving the columns.
                cols //= 2
            elif (
                w.dtype == torch.float8_e4m3fn
                and is_weight_scale
                and w.shape[1] * nvfp4_block_size == cols
            ):
                # One FP8 scale per block of nvfp4_block_size columns.
                cols //= nvfp4_block_size

            return (
                w.view(n_heads, head_dim // 2, 2, cols)
                .transpose(1, 2)
                .reshape(rows, cols)
            )

        modules = name.split(".")

        # Only plain weights and NVFP4 block scales (stored as FP8) are
        # permuted.
        is_weight = modules[-1] == "weight"
        is_nvfp4_weight_scale = (
            modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
        )
        if not (is_weight or is_nvfp4_weight_scale):
            return name, loaded_weight

        if "wk" in modules or "k_proj" in modules:
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, is_nvfp4_weight_scale
            )
        elif "wq" in modules or "q_proj" in modules:
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, is_nvfp4_weight_scale
            )

        return name, loaded_weight