gpt_oss.py 47.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import typing
from collections.abc import Callable, Iterable
5
6
7
8
9
10
11
12

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (
14
    get_dp_group,
15
    get_ep_group,
16
    get_pcp_group,
17
18
19
20
21
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
22
from vllm.model_executor.layers.attention import Attention
23
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
24
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
25
from vllm.model_executor.layers.layernorm import RMSNorm
26
27
28
29
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    RowParallelLinear,
)
30
31
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
33
from vllm.model_executor.layers.rotary_embedding import get_rope
34
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
35
from vllm.model_executor.layers.vocab_parallel_embedding import (
36
37
38
    ParallelLMHead,
    VocabParallelEmbedding,
)
39
40
41
42
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
43
from vllm.model_executor.models.utils import sequence_parallel_chunk
44
from vllm.platforms import current_platform
45
from vllm.sequence import IntermediateTensors
46
from vllm.utils.math_utils import cdiv
47
from vllm.v1.attention.backend import AttentionType
48

49
50
51
52
53
54
55
from .interfaces import (
    EagleModelMixin,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
56
57
58
59
60
61
62
63
64
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
65
66
67
68
69
70


class OAIAttention(nn.Module):
    """Self-attention layer of a GPT-OSS transformer block.

    Combines a fused QKV projection, a YaRN-configured rotary embedding,
    an ``Attention`` module that receives per-head ``sinks``, and a
    row-parallel output projection. Head counts are divided by the tensor
    parallel world size to obtain the per-rank sizes.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention module.

        Args:
            config: HF model config; supplies head sizes, hidden size,
                rope parameters and the sliding window.
            quant_config: Quantization config forwarded to the linear and
                attention layers.
            cache_config: KV-cache config forwarded to ``Attention``.
            prefix: Parameter-name prefix of this layer; the layer index
                is extracted from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # The rope type is forced to "yarn"; the remaining values are read
        # from the checkpoint config ("truncate" defaults to True if absent).
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
            rope_parameters={
                "rope_theta": config.rope_parameters["rope_theta"],
                "rope_type": "yarn",
                "factor": config.rope_parameters["factor"],
                "original_max_position_embeddings": config.rope_parameters[
                    "original_max_position_embeddings"
                ],
                "beta_fast": config.rope_parameters["beta_fast"],
                "beta_slow": config.rope_parameters["beta_slow"],
                "truncate": config.rope_parameters.get("truncate", True),
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One sink value per attention head held by this rank.
        # NOTE(review): requires_grad=False is passed to torch.empty, not to
        # nn.Parameter, whose own default is requires_grad=True - confirm
        # whether the parameter itself was meant to be frozen.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size, requires_grad=False)
        )

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        # (even layer indices get it, odd ones get None).
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection.

        Args:
            hidden_states: Input activations fed to the fused QKV projection.
            positions: Token positions passed to the rotary embedding.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
155
156
157
158
159


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block: a router plus fused experts."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        """Create the router and the fused expert layer.

        Args:
            vllm_config: Global vLLM config; the HF config, quantization
                config and parallel config are read from it.
            layer_idx: Index of the enclosing transformer block.
            prefix: Parameter-name prefix for this block.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.experts_per_token = config.num_experts_per_tok
        # Fall back to 1 when torch.distributed has not been initialised.
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = GateLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=True,
            prefix=f"{prefix}.router",
        )
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        Args:
            x: Input activations; dim 0 is treated as the token dimension.

        Returns:
            Expert output restricted to the first ``hidden_size`` columns.
            With sequence-parallel MoE the result is all-gathered along
            dim 0 and trimmed back to the original token count.
        """
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        if current_platform.is_rocm():
            # ROCm computes router logits with a dedicated GEMM over the
            # first hidden_size columns of x.
            # NOTE(review): presumably x can be wider than hidden_size here
            # (padding) - confirm against the caller.
            g = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            g, _ = self.router(x)
        x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]

        if self.is_sequence_parallel:
            # Reassemble along dim 0 and drop rows beyond the original count.
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
        return x
216
217
218
219
220


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: RMSNorm, attention, RMSNorm, MoE MLP."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        """Build the attention, MLP and the two RMSNorm layers.

        Args:
            vllm_config: Global vLLM config; the HF config and cache config
                are read from it.
            quant_config: Quantization config forwarded to the attention layer.
            prefix: Parameter-name prefix of this layer; the layer index is
                extracted from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config,
            prefix=f"{prefix}.attn",
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the MLP, threading the residual stream through.

        Args:
            hidden_states: Input activations of this layer.
            positions: Token positions forwarded to the attention layer.
            residual: Residual carried over from the previous layer, or
                ``None`` for the first layer processed.

        Returns:
            ``(output, residual)``: the MLP output and the residual to pass
            to the next layer (or to the final norm).
        """
        # Self Attention
        if residual is None:
            # No incoming residual: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns a (states, residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual
259
260
261


@support_torch_compile
262
class GptOssModel(nn.Module, EagleModelMixin):
263
264
265
266
267
268
269
270
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Global vLLM config; the HF config, quantization
                config and parallel config are stored on the instance.
            prefix: Parameter-name prefix of the model; layers are created
                under ``f"{prefix}.layers"``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        # make_layers returns the [start, end) layer range alongside the
        # layer container; forward() iterates only that range.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
                quant_config=self.quant_config,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        # Factory for the "hidden_states" / "residual" tensors exchanged
        # between pipeline stages.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
290

291
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this pipeline stage's decoder layers.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` received
                from the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Precomputed embeddings that take precedence over
                ``input_ids`` on the first rank.

        Returns:
            On a non-last pipeline rank, an ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual``. On the last rank, the
            normalised hidden states, or ``(hidden_states,
            aux_hidden_states)`` when auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.embed_input_ids(input_ids)

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # _maybe_add_hidden_state is defined outside this view (presumably on
        # EagleModelMixin); it is given the collection, a layer index and the
        # current hidden state / residual.
        aux_hidden_states = self._maybe_add_hidden_state(
            [], self.start_layer, x, residual
        )
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            x, residual = layer(x, positions, residual)
            self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})
        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x

328
329
330
331
332
333
334
335
336
337
338
339
340
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping built by ``FusedMoE``.

        Entries have the form ``(param_name, weight_name, expert_id,
        shard_id)`` and cover weights, weight scales and activation scales.

        NOTE: this is only used for quark.
        """
        mapping_kwargs = dict(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )
        return FusedMoE.make_expert_params_mapping(self, **mapping_kwargs)

341
    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4 checkpoint into this model's parameters.

        Fused MoE tensors (``w13``/``w2`` weights, scales and biases) are
        sliced for the current rank before being handed to the parameter's
        ``weight_loader``: along the expert dimension when expert parallelism
        is enabled, otherwise along the intermediate dimension. Attention
        sinks are narrowed to this rank's heads. Every other tensor goes
        through ``stacked_params_mapping`` or is loaded under its own name.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range (dim 0).
            ep_rank_start: Inclusive start of this rank's expert range.
            heads_per_rank: Number of sink entries to keep for this rank.
            head_start: First sink entry belonging to this rank.
            weights: Iterable of ``(name, tensor)`` checkpoint entries.
            stacked_params_mapping: ``(param_name, weight_name, shard_id)``
                triples describing checkpoint tensors that are loaded into a
                stacked parameter.

        Returns:
            The names of all parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

        # Per-rank intermediate size, rounded up to whole MX blocks.
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = (
            per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
        )

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        def _load_expert_tensor(param_name: str, tensor: torch.Tensor) -> None:
            """Hand an already-sliced fused-expert tensor to its weight loader."""
            param = params_dict[param_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                tensor,
                weight_name=param_name,
                shard_id=None,
                expert_id=None,
            )
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # NOTE: the "*_weight_scale" checks must precede the "*_weight"
            # checks because the latter are substrings of the former.
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                _load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale: the last dim is
                # indexed in units of MX blocks.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[
                        ...,
                        tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
                        // OCP_MX_BLOCK_SIZE,
                    ]
                _load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                _load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
                _load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                _load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # (only load on rank 0 to avoid duplication): other
                    # ranks load an all-zero bias, zeroed in place.
                    weight.zero_()
                _load_expert_tensor(name, weight)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rename the checkpoint tensor to its stacked parameter.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params
532

533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
    def _load_weights_quark(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        if use_ep:
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
        else:
            tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
                tp_size=get_tensor_model_parallel_world_size(),
                dp_size=get_dp_group().world_size,
                dp_rank=get_dp_group().rank_in_group,
                pcp_size=get_pcp_group().world_size,
                pcp_rank=get_pcp_group().rank_in_group,
            )

        def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
            """Helper function to get MoE quantization weight dtype.

            Args:
                layer_id: Layer index to check (default 0, as all layers should
                        have the same quantization method)

            Returns:
                Weight dtype string (e.g., "mxfp4", "fp8") or None if not available
            """
            if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"):
                return self.layers[layer_id].mlp.experts.quant_method.weight_dtype
            return None

        intermediate_size = self.config.intermediate_size

        moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)

        if moe_weight_dtype == "mxfp4":
            # MXFP4 requires OCP_MX_BLOCK_SIZE alignment
            intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
            per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
            per_rank_intermediate_size = (
                per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
            )
        else:
            # FP8 and other formats don't need alignment
            per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue

            layer_id, expert_id, fused_name = None, None, None
            moe_quant_method = None
            if "experts" in name:
                parts = name.split(".")
                ids = [s for s in parts if s.isdigit()]

602
                # for amd-quark format that each expert is separated
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
                # need to extract the parameter name with experts fused.
                # example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
                if len(ids) == 2:
                    layer_id, expert_id = int(ids[0]), int(ids[-1])
                    parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id)))
                    fused_name = ".".join(parts)

                # for openai mxfp4 format that all experts are combined
                # no need to extract the parameter name with experts fused.
                # models: openai/gpt-oss-20b, openai/gpt-oss-120b
                elif len(ids) == 1:
                    layer_id, expert_id = int(ids[0]), None
                    fused_name = name

                else:
                    raise NameError(
                        f"Layer {name} contains more than 2 numeric indices. This is "
                        "an unexpected condition. Please open an issue if encountered."
                    )

                moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)

            def kv_cache_scale_loader(
                quant_config: QuantizationConfig,
                name: str,
                params_dict: dict[str, typing.Any],
                weight: torch.Tensor,
                default_weight_loader: Callable[..., None],
                loaded_params: set[str],
            ) -> tuple[bool, set[str]]:
                """
                Load KV cache output scales.
                Returns:
                    Tuple of (bool, set):
                    - bool: True if KV-cache scale was loaded into loaded_params
                    - set: Updated set of loaded_params if True else the original set
                """
                # load explicit cached KV output scale from quant_config
                if quant_config is not None and (
                    scale_name := quant_config.get_cache_scale(name)
                ):
                    param = params_dict[scale_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    if weight.numel() != 1:
                        raise ValueError(
                            f"KV cache scale '{scale_name}' is expected to be a "
                            f"scalar, but got a tensor of shape {weight.shape}."
                        )
                    # Ensure weight is a scalar before passing to loader.
                    weight_loader(param, weight.flatten()[0])
                    loaded_params.add(scale_name)
                    return True, loaded_params

                return False, loaded_params

            load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
                self.quant_config,
                name,
                params_dict,
                loaded_weight,
                default_weight_loader,
                loaded_params,
            )
            if load_kv_cache_scale_completed:
                continue

            if (
                all(key in name for key in ["input_scale", "mlp.experts"])
                and expert_id is not None
            ):
                assert loaded_weight.numel() == 1
                expert_data = params_dict[fused_name].data[expert_id]
                expert_data.copy_(loaded_weight)
                loaded_params.add(fused_name)
                continue

            # Unified handler for mxfp4 weights and scales
            elif moe_quant_method == "mxfp4" and any(
                name.endswith(suffix)
                for suffix in [
                    ".w13_weight_scale",
                    ".w2_weight_scale",
                    ".w13_weight",
                    ".w2_weight",
                ]
            ):
                is_w13 = ".w13_" in name
                is_scale = "_scale" in name

                # Reshape weight for mxfp4 if needed (not for scales)
                if not is_scale and expert_id is None:
                    if is_w13:
                        if loaded_weight.dim() < 3:
                            raise ValueError(
                                f"Expected w13_weight to have at least 3 "
                                f"dimensions, got shape "
                                f"{loaded_weight.shape}"
                            )
                        if loaded_weight.shape[0] != num_experts:
                            raise ValueError(
                                f"Expected w13_weight first dimension to be "
                                f"{num_experts}, got "
                                f"{loaded_weight.shape[0]}"
                            )
                        loaded_weight = loaded_weight.view(
                            num_experts, 2 * intermediate_size, -1
                        ).contiguous()
                    else:
                        if loaded_weight.dim() < 3:
                            raise ValueError(
                                f"Expected w2_weight to have at least 3 "
                                f"dimensions, got shape "
                                f"{loaded_weight.shape}"
                            )
                        if loaded_weight.shape[0] != num_experts:
                            raise ValueError(
                                f"Expected w2_weight first dimension to be "
                                f"{num_experts}, got "
                                f"{loaded_weight.shape[0]}"
                            )
                        loaded_weight = loaded_weight.view(
                            num_experts, -1, intermediate_size // 2
                        ).contiguous()

                if use_ep:
                    sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if is_w13:
                        if expert_id is None:
                            sliced_weight = loaded_weight[
                                :, 2 * tp_rank_start : 2 * tp_rank_end, ...
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                2 * tp_rank_start : 2 * tp_rank_end, ...
                            ]
                    else:
                        if is_scale:
                            sliced_weight = loaded_weight[
                                ...,
                                tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
                                // OCP_MX_BLOCK_SIZE,
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                ..., tp_rank_start // 2 : tp_rank_end // 2
                            ]

                # NOTE(rob): because the gpt-oss checkpoint has a "unique"
                # structure with gate_up_proj stored fused on disk, we cannot
                # use the existing weight loaders without added complexity,
                # so just do the direct load here.
                param = params_dict[fused_name]
                expert_data = param.data[expert_id]
                dim1 = sliced_weight.shape[0]
                dim2 = sliced_weight.shape[1]
                expert_data.data[:dim1, :dim2].copy_(sliced_weight)
                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_weight") and moe_quant_method == "fp8":
                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if expert_id is None:
                        narrow_weight = loaded_weight[
                            :, 2 * tp_rank_start : 2 * tp_rank_end, :
                        ]
                    else:
                        narrow_weight = loaded_weight[
                            2 * tp_rank_start : 2 * tp_rank_end, :
                        ]

                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                # Check if this is per-channel or per-tensor scale
                if loaded_weight.numel() > 1 and loaded_weight.dim() == 1:
                    if use_ep:
                        narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                    else:
                        narrow_weight = loaded_weight[
                            2 * tp_rank_start : 2 * tp_rank_end
                        ]
                else:
                    narrow_weight = loaded_weight

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(loaded_weight)
                else:
                    param.data[expert_id].copy_(loaded_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w2_weight") and moe_quant_method == "fp8":
                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if expert_id is None:
                        narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
                    else:
                        narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]

                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = loaded_weight

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            # Unified handler for bias loading (w13_bias and w2_bias)
            elif name.endswith(".w13_bias") or name.endswith(".w2_bias"):
                is_w13_bias = name.endswith(".w13_bias")

                if use_ep:
                    sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if is_w13_bias:
                        if expert_id is None:
                            sliced_weight = loaded_weight[
                                :, 2 * tp_rank_start : 2 * tp_rank_end
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                2 * tp_rank_start : 2 * tp_rank_end
                            ]
                    else:
                        sliced_weight = loaded_weight
                        if tp_rank != 0:
                            sliced_weight = sliced_weight.zero_()

                # NOTE(rob): because the gpt-oss checkpoint has a "unique"
                # structure with gate_up_proj stored fused on disk, we cannot
                # use the existing weight loaders without added complexity,
                # so just do the direct load here.
                assert fused_name is not None
                param = params_dict[fused_name]
                expert_data = param.data[expert_id]
                dim1 = sliced_weight.shape[0]
                expert_data.data[:dim1].copy_(sliced_weight)
                loaded_params.add(fused_name)
                continue

            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)

                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                param = params_dict[name]
                weight_loader = param.weight_loader

                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    param_name, weight_name, mapping_expert_id, shard_id = mapping
                    weight_name = (
                        weight_name[:-1] if weight_name.endswith(".") else weight_name
                    )

                    if weight_name not in name:
                        continue

                    param = params_dict[fused_name]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    # Use checkpoint's expert_id for quark format (when expert_id
                    # is extracted from weight name), otherwise use mapping's expert_id
                    actual_expert_id = (
                        expert_id if expert_id is not None else mapping_expert_id
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        fused_name,
                        shard_id=shard_id,
                        expert_id=actual_expert_id,
                        return_success=True,
                    )
                    if success:
                        name = fused_name
                        loaded_params.add(name)
                        break
                else:
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

                loaded_params.add(name)
        return loaded_params

976
    def _load_weights_other(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load checkpoint weights, slicing MoE tensors for the current rank.

        MoE tensors (``.w13_weight``, ``.w2_weight``, ``.w13_bias``,
        ``.w2_bias``) are copied directly into the matching parameter.
        With expert parallelism enabled they are sliced along dim 0 to
        ``[ep_rank_start:ep_rank_end]``; otherwise they are sliced along the
        intermediate dimension using the flattened tensor-parallel rank.
        Attention sinks are narrowed to this rank's heads, stacked
        parameters are routed through ``stacked_params_mapping``, and every
        remaining weight goes through the parameter's ``weight_loader`` (or
        ``default_weight_loader``).

        Args:
            ep_rank_end: Exclusive end index of this rank's expert slice.
            ep_rank_start: Inclusive start index of this rank's expert slice.
            heads_per_rank: Number of attention-sink entries kept on this rank.
            head_start: Offset of this rank's first attention-sink entry.
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples used to fold per-shard weights into stacked params.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank. The end bound is
        # clamped because cdiv() may overshoot on the last rank.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Handle MLP gate and up projection weights.
                # Gate and up are fused, hence the factor of 2 on the bounds.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]

                # Swap the last two dims to match the parameter layout.
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]

                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]

                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]

                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 loads the real bias to avoid duplication;
                    # other ranks load zeros. Use a fresh tensor instead of
                    # zeroing the checkpoint tensor in place, so the caller's
                    # data is not mutated as a side effect.
                    weight = torch.zeros_like(weight)
                param = params_dict[name]
                param.copy_(weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # default_weight_loader does not take a shard id.
                if weight_loader is default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

1089
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Compute per-rank slicing bounds and dispatch to a format loader.

        The loader is chosen from ``config.quantization_config["quant_method"]``:
        ``"mxfp4"`` and ``"quark"`` have dedicated loaders, anything else
        (including no quantization config) uses ``_load_weights_other``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names reported as loaded by the chosen loader.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_group = get_ep_group()
        ep_size = ep_group.world_size
        # Use the rank *within* the EP group (as done for the DP/PCP groups
        # elsewhere in this file). NOTE(review): `.rank` appears to be the
        # global rank, which would index past `num_experts` whenever the EP
        # group does not start at global rank 0 (e.g. with pipeline
        # parallelism) -- confirm against GroupCoordinator.
        ep_rank = ep_group.rank_in_group
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        # Treat a missing *or* None quantization_config as "unquantized".
        quant_config = getattr(self.config, "quantization_config", None)
        quant_method = (
            quant_config["quant_method"] if quant_config is not None else None
        )

        if quant_method == "mxfp4":
            load_fn = self._load_weights_mxfp4
        elif quant_method == "quark":
            load_fn = self._load_weights_quark
        else:
            load_fn = self._load_weights_other

        # All format loaders are called with the same positional arguments.
        return load_fn(
            ep_rank_end,
            ep_rank_start,
            heads_per_rank,
            head_start,
            weights,
            stacked_params_mapping,
        )
1144
1145


1146
1147
1148
class GptOssForCausalLM(
    nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
):
    """GPT-OSS causal language model: ``GptOssModel`` backbone plus LM head.

    Checkpoint weight names are rewritten through ``hf_to_vllm_mapper``
    before being handed to ``AutoWeightsLoader``.
    """

    # NOTE(review): presumably signals that MoE expert weights are stored as
    # 3-D tensors; the consumer of this flag is outside this file chunk.
    is_3d_moe_weight: bool = True
    # The fused qkv_proj module is built from these per-projection modules.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # Renames checkpoint tensor names to the parameter names used here.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
            # For quark format
            ".gate_up_proj.weight": ".w13_weight",
            ".gate_up_proj.weight_scale": ".w13_weight_scale",
            ".gate_up_proj.bias": ".w13_bias",
            ".gate_up_proj.input_scale": ".w13_input_scale",
            ".down_proj.weight": ".w2_weight",
            ".down_proj.weight_scale": ".w2_weight_scale",
            ".down_proj.bias": ".w2_bias",
            ".down_proj.input_scale": ".w2_input_scale",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config is read
                from ``vllm_config.model_config.hf_config``.
            prefix: Name prefix prepended to the submodule prefixes.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        # Re-export the backbone's factory under the same name on this module.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``lm_head.`` weights are skipped when the config ties word
        embeddings. Names are remapped with ``hf_to_vllm_mapper``.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)