gpt_oss.py 48.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import typing
from collections.abc import Callable, Iterable
5
6
7
8
9
10
11
12

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (
14
    get_dp_group,
15
    get_ep_group,
16
    get_pcp_group,
17
18
19
20
21
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
22
from vllm.model_executor.layers.attention import Attention
23
from vllm.model_executor.layers.fused_moe import FusedMoE
24
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
25
from vllm.model_executor.layers.layernorm import RMSNorm
26
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
27
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
29
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
30
from vllm.model_executor.layers.rotary_embedding import get_rope
31
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
32
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
34
35
    ParallelLMHead,
    VocabParallelEmbedding,
)
36
37
38
39
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
40
from vllm.model_executor.models.utils import sequence_parallel_chunk
41
from vllm.platforms import current_platform
42
from vllm.sequence import IntermediateTensors
43
from vllm.utils.math_utils import cdiv
44
from vllm.v1.attention.backend import AttentionType
45

46
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
47
48
49
50
51
52
53
54
55
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
56
57
58
59
60
61


class OAIAttention(nn.Module):
    """GPT-OSS self-attention with YaRN rotary embeddings and attention sinks.

    Q, K and V come from one fused tensor-parallel projection (``qkv_proj``)
    and the attention output is projected back to ``hidden_size`` by a
    row-parallel ``o_proj``. Even-indexed layers use
    ``config.sliding_window``; odd-indexed layers attend without a window.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # The rotary embedding is always built as YaRN in float32; its
        # parameters come from the HF config ("truncate" is optional and
        # defaults to True).
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
            rope_parameters={
                "rope_theta": config.rope_parameters["rope_theta"],
                "rope_type": "yarn",
                "factor": config.rope_parameters["factor"],
                "original_max_position_embeddings": config.rope_parameters[
                    "original_max_position_embeddings"
                ],
                "beta_fast": config.rope_parameters["beta_fast"],
                "beta_slow": config.rope_parameters["beta_slow"],
                "truncate": config.rope_parameters.get("truncate", True),
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One sink value per attention head held by this TP rank; the values
        # are populated at weight-load time. requires_grad must be given to
        # nn.Parameter itself: Parameter defaults to requires_grad=True and
        # ignores the flag of the tensor it wraps.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False,
        )

        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer (even layer indices).
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Compute attention for this layer.

        Args:
            hidden_states: normalized input activations.
            positions: token positions consumed by the rotary embedding.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # split() yields views into qkv; V is explicitly made contiguous
        # before it is handed to the attention layer.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MLPBlock(torch.nn.Module):
    """GPT-OSS mixture-of-experts feed-forward block.

    A linear ``router`` produces per-expert logits that drive a ``FusedMoE``
    layer configured for top-``num_experts_per_tok`` routing with
    renormalization, biased experts and the ``"swigluoai"`` activation. When
    sequence-parallel MoE is enabled, each rank routes only its chunk of the
    tokens and the per-rank outputs are all-gathered afterwards.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        The expert output is truncated to ``hidden_size`` columns. In
        sequence-parallel mode the per-rank results are gathered back along
        the token dimension and cut to the original token count.
        """
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        if current_platform.is_rocm():
            # NOTE(review): only the first hidden_size columns reach the
            # router on this path, which suggests x may be wider than
            # hidden_size here (e.g. padded); confirm.
            g = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            g = self.router(x)
        x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]

        if self.is_sequence_parallel:
            # Reassemble the chunks along dim 0, then drop any rows beyond
            # the original token count (presumably chunking padding).
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
        return x


class TransformerBlock(torch.nn.Module):
    """A single GPT-OSS decoder layer.

    Runs RMSNorm -> attention, then RMSNorm -> MoE MLP. The residual stream
    is threaded through the residual-taking RMSNorm calls and returned next
    to the MLP output instead of being added here.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config,
            prefix=f"{prefix}.attn",
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the MoE MLP to ``hidden_states``.

        Args:
            hidden_states: output of the previous layer (or the embeddings).
            positions: token positions forwarded to attention.
            residual: residual stream carried from the previous layer, or
                ``None`` when there is none yet; in that case
                ``hidden_states`` itself seeds the residual.

        Returns:
            ``(mlp_output, residual)``. The two are combined by the next
            residual-taking RMSNorm call (next layer or the final norm).
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual


@support_torch_compile
class GptOssModel(nn.Module):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build embeddings, the decoder layers for this rank and the final norm.

        Args:
            vllm_config: global vLLM config; the HF model config, quant config
                and parallel config are read from it.
            prefix: parameter-name prefix for the decoder layers.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        # forward() only iterates layers in [start_layer, end_layer).
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
                quant_config=self.quant_config,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
        # Layer indices whose inputs forward() additionally returns as
        # auxiliary hidden states; empty unless set externally.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the vocab-parallel embeddings of ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> (
        torch.Tensor
        | IntermediateTensors
        | tuple[torch.Tensor, list[torch.Tensor]]
    ):
        """Run this pipeline stage's decoder layers.

        On the first PP rank the input is ``inputs_embeds`` if given, else the
        embedding of ``input_ids``; later ranks resume from
        ``intermediate_tensors``.

        Returns:
            - ``IntermediateTensors`` with ``hidden_states``/``residual`` on
              any rank that is not the last PP rank;
            - otherwise the final-normed hidden states, paired with the list
              of captured auxiliary hidden states when any were captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states: list[torch.Tensor] = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                # Capture this layer's input with the residual folded in.
                aux_hidden_states.append(x if residual is None else x + residual)
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})
        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the per-expert checkpoint-to-parameter mapping.

        Entries are ``(param_name, weight_name, expert_id, shard_id)`` and
        cover weights, weight scales and activation scales. Checkpoint names
        ``w1``/``w3``/``w2`` are the gate/up/down projections; no redundant
        experts are configured.

        NOTE: this is only used for quark.
        """
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4 (OpenAI-format) checkpoint into this model.

        MoE expert tensors (``w13``/``w2`` weights, scales and biases) hold all
        experts along dim 0. With expert parallelism they are sliced to
        ``[ep_rank_start, ep_rank_end)`` on that dim. Otherwise each rank takes
        its slice of the intermediate dimension, using the TP size flattened
        across DP and PCP. Attention sinks are narrowed to ``heads_per_rank``
        heads starting at ``head_start``. Names matching
        ``stacked_params_mapping`` are renamed and loaded with their shard id.
        Any other name is loaded directly if this model has such a parameter
        and skipped otherwise.

        Args:
            ep_rank_end: exclusive end of this rank's expert range (EP only).
            ep_rank_start: start of this rank's expert range (EP only).
            heads_per_rank: number of attention heads held by this rank.
            head_start: index of this rank's first attention head.
            weights: ``(name, tensor)`` pairs read from the checkpoint.
            stacked_params_mapping: ``(param_name, weight_name, shard_id)``
                tuples for checkpoint tensors fused into one parameter.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

        # Each rank's share of the intermediate dimension is rounded up to a
        # whole number of OCP_MX_BLOCK_SIZE blocks; the range of the last
        # rank is clipped to intermediate_size below.
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = (
            per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
        )

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        def load_expert_tensor(param_name: str, tensor: torch.Tensor) -> None:
            # Expert tensors are already sliced for this rank, so the loader
            # receives them with shard_id/expert_id of None.
            param = params_dict[param_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                tensor,
                weight_name=param_name,
                shard_id=None,
                expert_id=None,
            )
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # NOTE: ".w13_weight" / ".w2_weight" are substrings of the
            # corresponding "_scale" names, so the scale checks must come
            # first.
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale; the last dim is
                # indexed in units of OCP_MX_BLOCK_SIZE columns.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[
                        ...,
                        tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
                        // OCP_MX_BLOCK_SIZE,
                    ]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    bias = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 loads the real bias, to avoid duplicating it
                    # across ranks. Build fresh zeros rather than zero_()-ing
                    # the checkpoint tensor, which belongs to the caller.
                    bias = torch.zeros_like(weight)
                else:
                    bias = weight
                load_expert_tensor(name, bias)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

    def _load_weights_quark(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        if use_ep:
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
        else:
            tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
                tp_size=get_tensor_model_parallel_world_size(),
                dp_size=get_dp_group().world_size,
                dp_rank=get_dp_group().rank_in_group,
                pcp_size=get_pcp_group().world_size,
                pcp_rank=get_pcp_group().rank_in_group,
            )

        def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
            """Helper function to get MoE quantization weight dtype.

            Args:
                layer_id: Layer index to check (default 0, as all layers should
                        have the same quantization method)

            Returns:
                Weight dtype string (e.g., "mxfp4", "fp8") or None if not available
            """
            if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"):
                return self.layers[layer_id].mlp.experts.quant_method.weight_dtype
            return None

        intermediate_size = self.config.intermediate_size

        moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)

        if moe_weight_dtype == "mxfp4":
            # MXFP4 requires OCP_MX_BLOCK_SIZE alignment
            intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
            per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
            per_rank_intermediate_size = (
                per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
            )
        else:
            # FP8 and other formats don't need alignment
            per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue

            layer_id, expert_id, fused_name = None, None, None
            moe_quant_method = None
            if "experts" in name:
                parts = name.split(".")
                ids = [s for s in parts if s.isdigit()]

                # for amd-quark format that each expert is seperated
                # need to extract the parameter name with experts fused.
                # example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
                if len(ids) == 2:
                    layer_id, expert_id = int(ids[0]), int(ids[-1])
                    parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id)))
                    fused_name = ".".join(parts)

                # for openai mxfp4 format that all experts are combined
                # no need to extract the parameter name with experts fused.
                # models: openai/gpt-oss-20b, openai/gpt-oss-120b
                elif len(ids) == 1:
                    layer_id, expert_id = int(ids[0]), None
                    fused_name = name

                else:
                    raise NameError(
                        f"Layer {name} contains more than 2 numeric indices. This is "
                        "an unexpected condition. Please open an issue if encountered."
                    )

                moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)

            def kv_cache_scale_loader(
                quant_config: QuantizationConfig,
                name: str,
                params_dict: dict[str, typing.Any],
                weight: torch.Tensor,
                default_weight_loader: Callable[..., None],
                loaded_params: set[str],
            ) -> tuple[bool, set[str]]:
                """
                Load KV cache output scales.
                Returns:
                    Tuple of (bool, set):
                    - bool: True if KV-cache scale was loaded into loaded_params
                    - set: Updated set of loaded_params if True else the original set
                """
                # load explicit cached KV output scale from quant_config
                if quant_config is not None and (
                    scale_name := quant_config.get_cache_scale(name)
                ):
                    param = params_dict[scale_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    if weight.numel() != 1:
                        raise ValueError(
                            f"KV cache scale '{scale_name}' is expected to be a "
                            f"scalar, but got a tensor of shape {weight.shape}."
                        )
                    # Ensure weight is a scalar before passing to loader.
                    weight_loader(param, weight.flatten()[0])
                    loaded_params.add(scale_name)
                    return True, loaded_params

                return False, loaded_params

            load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
                self.quant_config,
                name,
                params_dict,
                loaded_weight,
                default_weight_loader,
                loaded_params,
            )
            if load_kv_cache_scale_completed:
                continue

            if (
                all(key in name for key in ["input_scale", "mlp.experts"])
                and expert_id is not None
            ):
                assert loaded_weight.numel() == 1
                expert_data = params_dict[fused_name].data[expert_id]
                expert_data.copy_(loaded_weight)
                loaded_params.add(fused_name)
                continue

            # Unified handler for mxfp4 weights and scales
            elif moe_quant_method == "mxfp4" and any(
                name.endswith(suffix)
                for suffix in [
                    ".w13_weight_scale",
                    ".w2_weight_scale",
                    ".w13_weight",
                    ".w2_weight",
                ]
            ):
                is_w13 = ".w13_" in name
                is_scale = "_scale" in name

                # Reshape weight for mxfp4 if needed (not for scales)
                if not is_scale and expert_id is None:
                    if is_w13:
                        if loaded_weight.dim() < 3:
                            raise ValueError(
                                f"Expected w13_weight to have at least 3 "
                                f"dimensions, got shape "
                                f"{loaded_weight.shape}"
                            )
                        if loaded_weight.shape[0] != num_experts:
                            raise ValueError(
                                f"Expected w13_weight first dimension to be "
                                f"{num_experts}, got "
                                f"{loaded_weight.shape[0]}"
                            )
                        loaded_weight = loaded_weight.view(
                            num_experts, 2 * intermediate_size, -1
                        ).contiguous()
                    else:
                        if loaded_weight.dim() < 3:
                            raise ValueError(
                                f"Expected w2_weight to have at least 3 "
                                f"dimensions, got shape "
                                f"{loaded_weight.shape}"
                            )
                        if loaded_weight.shape[0] != num_experts:
                            raise ValueError(
                                f"Expected w2_weight first dimension to be "
                                f"{num_experts}, got "
                                f"{loaded_weight.shape[0]}"
                            )
                        loaded_weight = loaded_weight.view(
                            num_experts, -1, intermediate_size // 2
                        ).contiguous()

                if use_ep:
                    sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if is_w13:
                        if expert_id is None:
                            sliced_weight = loaded_weight[
                                :, 2 * tp_rank_start : 2 * tp_rank_end, ...
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                2 * tp_rank_start : 2 * tp_rank_end, ...
                            ]
                    else:
                        if is_scale:
                            sliced_weight = loaded_weight[
                                ...,
                                tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
                                // OCP_MX_BLOCK_SIZE,
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                ..., tp_rank_start // 2 : tp_rank_end // 2
                            ]

                # NOTE(rob): because gpt-oss ckpt has "unique" structure with
                # fused gate_up_proj fused on disk, we cannot use the existing
                # weight loaders without added complexity, so just do the
                # direct load here.
                param = params_dict[fused_name]
                expert_data = param.data[expert_id]
                dim1 = sliced_weight.shape[0]
                dim2 = sliced_weight.shape[1]
                expert_data.data[:dim1, :dim2].copy_(sliced_weight)
                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_weight") and moe_quant_method == "fp8":
                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if expert_id is None:
                        narrow_weight = loaded_weight[
                            :, 2 * tp_rank_start : 2 * tp_rank_end, :
                        ]
                    else:
                        narrow_weight = loaded_weight[
                            2 * tp_rank_start : 2 * tp_rank_end, :
                        ]

                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                # Check if this is per-channel or per-tensor scale
                if loaded_weight.numel() > 1 and loaded_weight.dim() == 1:
                    if use_ep:
                        narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                    else:
                        narrow_weight = loaded_weight[
                            2 * tp_rank_start : 2 * tp_rank_end
                        ]
                else:
                    narrow_weight = loaded_weight

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(loaded_weight)
                else:
                    param.data[expert_id].copy_(loaded_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w2_weight") and moe_quant_method == "fp8":
                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if expert_id is None:
                        narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
                    else:
                        narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]

                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = loaded_weight

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            # Unified handler for bias loading (w13_bias and w2_bias)
            elif name.endswith(".w13_bias") or name.endswith(".w2_bias"):
                is_w13_bias = name.endswith(".w13_bias")

                if use_ep:
                    sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if is_w13_bias:
                        if expert_id is None:
                            sliced_weight = loaded_weight[
                                :, 2 * tp_rank_start : 2 * tp_rank_end
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                2 * tp_rank_start : 2 * tp_rank_end
                            ]
                    else:
                        sliced_weight = loaded_weight
                        if tp_rank != 0:
                            sliced_weight = sliced_weight.zero_()

                # NOTE(rob): because gpt-oss ckpt has "unique" structure with
                # fused gate_up_proj fused on disk, we cannot use the existing
                # weight loaders without added complexity, so just do the
                # direct load here.
                assert fused_name is not None
                param = params_dict[fused_name]
                expert_data = param.data[expert_id]
                dim1 = sliced_weight.shape[0]
                expert_data.data[:dim1].copy_(sliced_weight)
                loaded_params.add(fused_name)
                continue

            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)

                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                param = params_dict[name]
                weight_loader = param.weight_loader

                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    param_name, weight_name, mapping_expert_id, shard_id = mapping
                    weight_name = (
                        weight_name[:-1] if weight_name.endswith(".") else weight_name
                    )

                    if weight_name not in name:
                        continue

                    param = params_dict[fused_name]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    # Use checkpoint's expert_id for quark format (when expert_id
                    # is extracted from weight name), otherwise use mapping's expert_id
                    actual_expert_id = (
                        expert_id if expert_id is not None else mapping_expert_id
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        fused_name,
                        shard_id=shard_id,
                        expert_id=actual_expert_id,
                        return_success=True,
                    )
                    if success:
                        name = fused_name
                        loaded_params.add(name)
                        break
                else:
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

                loaded_params.add(name)
        return loaded_params

964
    def _load_weights_other(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load checkpoint weights that are neither mxfp4 nor quark quantized.

        MoE expert weights/biases are sliced for this rank and copied directly
        into the fused parameters:

        * with expert parallelism, along the leading (expert) axis using
          ``ep_rank_start:ep_rank_end``;
        * otherwise, along the intermediate axis using a TP size/rank that is
          flattened across the DP and PCP groups.

        Attention sinks are narrowed to this rank's heads, q/k/v shards are
        routed through ``stacked_params_mapping``, and every remaining weight
        goes through its parameter's ``weight_loader`` (or
        ``default_weight_loader``).

        Args:
            ep_rank_end: Exclusive end of this rank's expert range (EP only).
            ep_rank_start: Start of this rank's expert range (EP only).
            heads_per_rank: Number of attention heads owned by this rank.
            head_start: Index of this rank's first attention head.
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                tuples used to route q/k/v shards into fused parameters.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Common slicing bounds (in units of intermediate_size) for this rank.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Fused gate/up projection weights. Without EP the TP slice is
                # taken on the last axis with doubled bounds (gate and up are
                # fused there); the last two axes are then swapped to match the
                # parameter. NOTE(review): implies an on-disk layout of
                # (experts, hidden, 2 * intermediate) -- confirm vs checkpoint.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Down projection weights: TP slice on the middle axis, then
                # swap the last two axes to match the parameter.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Fused gate/up projection biases, sliced like w13_weight but
                # on axis 1 (no permute needed).
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # Down projection bias. Without EP only TP rank 0 loads the
                # real values (to avoid duplicating the bias); other ranks load
                # zeros. ``zeros_like`` is used instead of ``weight.zero_()``
                # so the tensor yielded by ``weights`` is not mutated in place.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    narrow_weight = torch.zeros_like(weight)
                else:
                    narrow_weight = weight
                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Attention sinks are split across ranks by head (axis 0).
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # The fallback loader is called without a shard id.
                if weight_loader is default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights; names without a matching parameter
                # are ignored.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

1077
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, dispatching on the checkpoint's quant method.

        Computes this rank's attention-head range (from the TP group) and
        expert range (from the EP group). It then delegates to
        ``_load_weights_mxfp4``, ``_load_weights_quark`` or
        ``_load_weights_other`` according to
        ``config.quantization_config["quant_method"]``. A missing or ``None``
        ``quantization_config`` selects ``_load_weights_other``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        # Expert shard for this rank. Use the rank *inside* the EP group:
        # GroupCoordinator.rank is the process's global rank, which differs
        # from rank_in_group whenever the EP group does not start at global
        # rank 0 (e.g. later pipeline stages) and would slice the wrong, or
        # an empty, expert range.
        ep_group = get_ep_group()
        ep_rank = ep_group.rank_in_group
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_group.world_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = ep_rank_start + experts_per_rank

        # Treat a missing *or* None quantization_config as unquantized; the
        # previous hasattr() check crashed when the attribute was None.
        quant_config = getattr(self.config, "quantization_config", None)
        quant_method = (
            quant_config["quant_method"] if quant_config is not None else None
        )

        if quant_method == "mxfp4":
            loader = self._load_weights_mxfp4
        elif quant_method == "quark":
            loader = self._load_weights_quark
        else:
            loader = self._load_weights_other

        return loader(
            ep_rank_end,
            ep_rank_start,
            heads_per_rank,
            head_start,
            weights,
            stacked_params_mapping,
        )
1132
1133


1134
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    """GPT-OSS causal language model: a ``GptOssModel`` plus an LM head.

    Checkpoint parameter names are rewritten with ``hf_to_vllm_mapper`` before
    weights are handed to ``AutoWeightsLoader``.
    """

    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # Maps HF / quark checkpoint names onto this model's parameter names.
    # NOTE(review): several suffixes overlap (e.g. ".gate_up_proj" vs
    # ".gate_up_proj_bias"); keep the entry order unchanged unless
    # WeightsMapper's matching semantics have been checked.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
            # For quark format
            ".gate_up_proj.weight": ".w13_weight",
            ".gate_up_proj.weight_scale": ".w13_weight_scale",
            ".gate_up_proj.bias": ".w13_bias",
            ".gate_up_proj.input_scale": ".w13_input_scale",
            ".down_proj.weight": ".w2_weight",
            ".down_proj.weight_scale": ".w2_weight_scale",
            ".down_proj.bias": ".w2_bias",
            ".down_proj.input_scale": ".w2_input_scale",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the inner model, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF config is read from
                ``vllm_config.model_config.hf_config``.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        # Re-exported from the inner model for pipeline-parallel plumbing.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store the EAGLE-3 auxiliary layer indices on the inner model."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return default EAGLE-3 auxiliary layers.

        The defaults are layer 2, the middle layer, and the third-from-last
        layer.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids using the inner model's embedding."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the inner model and return its output hidden states."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights after renaming them with ``hf_to_vllm_mapper``.

        ``lm_head.`` weights are skipped when word embeddings are tied.

        Returns:
            Names of the parameters that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)