# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import typing
from collections.abc import Callable, Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
    get_dp_group,
    get_ep_group,
    get_pcp_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionType

from .interfaces import (
    EagleModelMixin,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class OAIAttention(nn.Module):
    """GPT-OSS self-attention layer with per-head attention sinks.

    Projects hidden states to Q/K/V with a tensor-parallel fused projection,
    applies YaRN rotary embeddings to Q and K, runs the ``Attention`` backend
    (with a sliding window on even-indexed layers only, and the ``sinks``
    parameter), and maps the result back to ``hidden_size`` with a
    row-parallel output projection.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # Rotary embedding: the checkpoint's rope parameters are forwarded,
        # with the rope type pinned to "yarn" ("truncate" defaults to True).
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
            rope_parameters={
                "rope_theta": config.rope_parameters["rope_theta"],
                "rope_type": "yarn",
                "factor": config.rope_parameters["factor"],
                "original_max_position_embeddings": config.rope_parameters[
                    "original_max_position_embeddings"
                ],
                "beta_fast": config.rope_parameters["beta_fast"],
                "beta_slow": config.rope_parameters["beta_slow"],
                "truncate": config.rope_parameters.get("truncate", True),
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One sink value per local (TP-sharded) attention head, populated
        # from the checkpoint at load time. nn.Parameter defaults to
        # requires_grad=True and ignores the wrapped tensor's flag, so the
        # flag has to be given to Parameter itself.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False,
        )

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer: even-indexed
        # layers use config.sliding_window, odd-indexed layers get none.
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Normalized layer input.
            positions: Token positions fed to the rotary embedding.

        Returns:
            The attention output after the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MLPBlock(torch.nn.Module):
    """GPT-OSS mixture-of-experts feed-forward block.

    A router linear layer produces one logit per expert, and ``FusedMoE``
    evaluates the top-``num_experts_per_tok`` experts per token (biased
    experts, "swigluoai" activation, renormalized routing weights).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.experts_per_token = config.num_experts_per_tok
        # Global torch.distributed world size (1 when not initialized); it
        # feeds the intermediate_size divisibility check below.
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        # The router is a ReplicatedLinear that is never quantized
        # (quant_config=None); it returns logits only (return_bias=False).
        self.router = ReplicatedLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=True,
            quant_config=None,
            prefix=f"{prefix}.router",
            return_bias=False,
        )
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route the tokens in ``x`` through the experts.

        With sequence-parallel MoE enabled, this rank only processes its
        chunk of the tokens; the expert outputs are then all-gathered along
        dim 0 and trimmed back to the original token count.

        Args:
            x: Token activations with tokens on dim 0.

        Returns:
            The MoE output, restricted to the first ``hidden_size`` columns.
        """
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        if current_platform.is_rocm():
            # ROCm: compute router logits with the ROCm GEMM helper, using
            # only the first hidden_size columns of x.
            g = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            g = self.router(x)
        # Keep only the first hidden_size output columns
        # (NOTE(review): presumably strips padding added by the MoE path).
        x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]

        if self.is_sequence_parallel:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            # The gathered chunks may exceed the original token count.
            x = x[:num_tokens]
        return x


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: self-attention followed by the MoE MLP.

    Both sub-layers are preceded by an RMSNorm (eps=1e-5). The residual
    stream is carried separately: ``forward`` returns the MLP output and the
    running residual as a pair instead of adding them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config,
            prefix=f"{prefix}.attn",
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the decoder layer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            positions: Token positions used by the rotary embedding.
            residual: Running residual from the previous layer, or ``None``
                when there is none yet, in which case ``hidden_states``
                itself becomes the residual.

        Returns:
            ``(mlp_output, residual)``. The caller hands both to the next
            two-argument RMSNorm call.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual


@support_torch_compile
class GptOssModel(nn.Module, EagleModelMixin):
    """GPT-OSS decoder backbone: token embedding, decoder layers, final norm."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the embedding, decoder layers and final RMSNorm.

        Args:
            vllm_config: Engine config; its HF config, quant config and
                parallel config are stored on the module.
            prefix: Parameter-name prefix; layers are named
                ``f"{prefix}.layers..."``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        # make_layers returns the [start_layer, end_layer) range this rank
        # executes along with the layer container.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
                quant_config=self.quant_config,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        # Factory for the placeholder tensors exchanged between pipeline
        # stages (same keys as the IntermediateTensors built in forward).
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the vocabulary-parallel token embeddings of ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; only read on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Token positions forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` produced by
                the previous pipeline rank (required on non-first ranks).
            inputs_embeds: Precomputed embeddings that take precedence over
                ``input_ids`` on the first pipeline rank.

        Returns:
            On non-last pipeline ranks, an ``IntermediateTensors`` with
            ``hidden_states`` and ``residual``. On the last rank, the final
            normalized hidden states, or ``(hidden_states, aux_hidden_states)``
            when auxiliary hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.embed_input_ids(input_ids)
            residual = None
        else:
            # Later pipeline ranks resume from the previous rank's output.
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # EagleModelMixin hook, invoked with the layer index before this
        # rank's first layer and after every layer (NOTE(review): presumably
        # it records only the layers requested for speculative decoding).
        aux_hidden_states = self._maybe_add_hidden_state(
            [], self.start_layer, x, residual
        )
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            x, residual = layer(x, positions, residual)
            self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})
        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the FusedMoE checkpoint-to-parameter mapping for experts.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``,
        covering weights, weight scales and activation scales. Checkpoint
        expert tensors are named ``w1`` (gate), ``w2`` (down) and ``w3``
        (up); no redundant experts are configured.

        NOTE: this is only used for quark.
        """
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )
    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4 checkpoint into this model's parameters.

        MoE expert tensors (``w13``/``w2`` weights, weight scales and biases)
        are sliced for this rank before being handed to the parameter's
        weight loader:

        * with expert parallelism, along dim 0 (experts) to
          ``[ep_rank_start, ep_rank_end)``;
        * otherwise along the intermediate dimension of this rank's
          flattened TP shard, in whole ``OCP_MX_BLOCK_SIZE`` blocks.

        Attention sinks are narrowed to this rank's heads, names matched by
        ``stacked_params_mapping`` are loaded into their fused parameter, and
        everything else is loaded by name (unknown names are skipped).

        Args:
            ep_rank_end: Exclusive end of this rank's expert range (EP only).
            ep_rank_start: Start of this rank's expert range (EP only).
            heads_per_rank: Number of attention-sink entries this rank keeps.
            head_start: Index of this rank's first attention-sink entry.
            weights: ``(checkpoint_name, tensor)`` pairs.
            stacked_params_mapping: ``(param_name, weight_name, shard_id)``
                tuples for checkpoint tensors fused into one parameter.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

        # Split the intermediate dimension in whole OCP MX blocks, rounding
        # the per-rank block count up, so each rank's slice is block-aligned.
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = (
            per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
        )

        # Calculate common slicing bounds for current rank; the last rank's
        # range is clipped to intermediate_size.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        def _load_moe_param(param_name: str, tensor: torch.Tensor) -> None:
            # Hand an already rank-sliced MoE tensor to the parameter's
            # loader (no shard/expert id) and record it as loaded.
            param = params_dict[param_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                tensor,
                weight_name=param_name,
                shard_id=None,
                expert_id=None,
            )
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # NOTE: the *_weight_scale branches must be tested before the
            # plain *_weight branches, because ".w13_weight" is a substring of
            # ".w13_weight_scale" (and likewise for w2).
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                _load_moe_param(name, narrow_weight)
                continue
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale: one scale per
                # OCP_MX_BLOCK_SIZE block of the intermediate (last) dim.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[
                        ...,
                        tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
                        // OCP_MX_BLOCK_SIZE,
                    ]
                _load_moe_param(name, narrow_weight)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                _load_moe_param(name, narrow_weight)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # Same flattening as w13, but two MXFP4 values are packed
                # into one uint8, so the intermediate dim is halved.
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
                _load_moe_param(name, narrow_weight)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                _load_moe_param(name, narrow_weight)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 loads the real bias, to avoid duplicating
                    # it across TP ranks; the others load zeros. Build a fresh
                    # zero tensor instead of calling zero_() so the tensor
                    # yielded by the caller's `weights` iterator is not
                    # mutated in place.
                    weight = torch.zeros_like(weight)
                _load_moe_param(name, weight)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params
    def _load_weights_quark(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        if use_ep:
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
        else:
            tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
                tp_size=get_tensor_model_parallel_world_size(),
                dp_size=get_dp_group().world_size,
                dp_rank=get_dp_group().rank_in_group,
                pcp_size=get_pcp_group().world_size,
                pcp_rank=get_pcp_group().rank_in_group,
            )

562
563
564
565
566
567
568
569
        def _is_mxfp4(weight_dtype: str | None) -> bool:
            """Return True for any MXFP4 weight-dtype variant.

            Covers "gpt_oss_mxfp4" (GptOssMxfp4MoEMethod) and "mxfp4"
            (QuarkMoEMethod with fp4 weights) and any future variants.
            """
            return weight_dtype is not None and "mxfp4" in weight_dtype

570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
        def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
            """Helper function to get MoE quantization weight dtype.

            Args:
                layer_id: Layer index to check (default 0, as all layers should
                        have the same quantization method)

            Returns:
                Weight dtype string (e.g., "mxfp4", "fp8") or None if not available
            """
            if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"):
                return self.layers[layer_id].mlp.experts.quant_method.weight_dtype
            return None

        intermediate_size = self.config.intermediate_size

        moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)

588
        if _is_mxfp4(moe_weight_dtype):
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
            # MXFP4 requires OCP_MX_BLOCK_SIZE alignment
            intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
            per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
            per_rank_intermediate_size = (
                per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
            )
        else:
            # FP8 and other formats don't need alignment
            per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue

            layer_id, expert_id, fused_name = None, None, None
            moe_quant_method = None
            if "experts" in name:
                parts = name.split(".")
                ids = [s for s in parts if s.isdigit()]

612
                # for amd-quark format that each expert is separated
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
                # need to extract the parameter name with experts fused.
                # example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
                if len(ids) == 2:
                    layer_id, expert_id = int(ids[0]), int(ids[-1])
                    parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id)))
                    fused_name = ".".join(parts)

                # for openai mxfp4 format that all experts are combined
                # no need to extract the parameter name with experts fused.
                # models: openai/gpt-oss-20b, openai/gpt-oss-120b
                elif len(ids) == 1:
                    layer_id, expert_id = int(ids[0]), None
                    fused_name = name

                else:
                    raise NameError(
                        f"Layer {name} contains more than 2 numeric indices. This is "
                        "an unexpected condition. Please open an issue if encountered."
                    )

                moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)

            def kv_cache_scale_loader(
                quant_config: QuantizationConfig,
                name: str,
                params_dict: dict[str, typing.Any],
                weight: torch.Tensor,
                default_weight_loader: Callable[..., None],
                loaded_params: set[str],
            ) -> tuple[bool, set[str]]:
                """
                Load KV cache output scales.
                Returns:
                    Tuple of (bool, set):
                    - bool: True if KV-cache scale was loaded into loaded_params
                    - set: Updated set of loaded_params if True else the original set
                """
                # load explicit cached KV output scale from quant_config
                if quant_config is not None and (
                    scale_name := quant_config.get_cache_scale(name)
                ):
                    param = params_dict[scale_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    if weight.numel() != 1:
                        raise ValueError(
                            f"KV cache scale '{scale_name}' is expected to be a "
                            f"scalar, but got a tensor of shape {weight.shape}."
                        )
                    # Ensure weight is a scalar before passing to loader.
                    weight_loader(param, weight.flatten()[0])
                    loaded_params.add(scale_name)
                    return True, loaded_params

                return False, loaded_params

            load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
                self.quant_config,
                name,
                params_dict,
                loaded_weight,
                default_weight_loader,
                loaded_params,
            )
            if load_kv_cache_scale_completed:
                continue

            if (
                all(key in name for key in ["input_scale", "mlp.experts"])
                and expert_id is not None
            ):
                assert loaded_weight.numel() == 1
                expert_data = params_dict[fused_name].data[expert_id]
                expert_data.copy_(loaded_weight)
                loaded_params.add(fused_name)
                continue

            # Unified handler for mxfp4 weights and scales
692
            elif _is_mxfp4(moe_quant_method) and any(
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
                name.endswith(suffix)
                for suffix in [
                    ".w13_weight_scale",
                    ".w2_weight_scale",
                    ".w13_weight",
                    ".w2_weight",
                ]
            ):
                is_w13 = ".w13_" in name
                is_scale = "_scale" in name

                # Reshape weight for mxfp4 if needed (not for scales)
                if not is_scale and expert_id is None:
                    if is_w13:
                        if loaded_weight.dim() < 3:
                            raise ValueError(
                                f"Expected w13_weight to have at least 3 "
                                f"dimensions, got shape "
                                f"{loaded_weight.shape}"
                            )
                        if loaded_weight.shape[0] != num_experts:
                            raise ValueError(
                                f"Expected w13_weight first dimension to be "
                                f"{num_experts}, got "
                                f"{loaded_weight.shape[0]}"
                            )
                        loaded_weight = loaded_weight.view(
                            num_experts, 2 * intermediate_size, -1
                        ).contiguous()
                    else:
                        if loaded_weight.dim() < 3:
                            raise ValueError(
                                f"Expected w2_weight to have at least 3 "
                                f"dimensions, got shape "
                                f"{loaded_weight.shape}"
                            )
                        if loaded_weight.shape[0] != num_experts:
                            raise ValueError(
                                f"Expected w2_weight first dimension to be "
                                f"{num_experts}, got "
                                f"{loaded_weight.shape[0]}"
                            )
                        loaded_weight = loaded_weight.view(
                            num_experts, -1, intermediate_size // 2
                        ).contiguous()

                if use_ep:
                    sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if is_w13:
                        if expert_id is None:
                            sliced_weight = loaded_weight[
                                :, 2 * tp_rank_start : 2 * tp_rank_end, ...
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                2 * tp_rank_start : 2 * tp_rank_end, ...
                            ]
                    else:
                        if is_scale:
                            sliced_weight = loaded_weight[
                                ...,
                                tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
                                // OCP_MX_BLOCK_SIZE,
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                ..., tp_rank_start // 2 : tp_rank_end // 2
                            ]

                # NOTE(rob): because gpt-oss ckpt has "unique" structure with
                # fused gate_up_proj fused on disk, we cannot use the existing
                # weight loaders without added complexity, so just do the
                # direct load here.
                param = params_dict[fused_name]
                expert_data = param.data[expert_id]
                dim1 = sliced_weight.shape[0]
                dim2 = sliced_weight.shape[1]
                expert_data.data[:dim1, :dim2].copy_(sliced_weight)
                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_weight") and moe_quant_method == "fp8":
                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if expert_id is None:
                        narrow_weight = loaded_weight[
                            :, 2 * tp_rank_start : 2 * tp_rank_end, :
                        ]
                    else:
                        narrow_weight = loaded_weight[
                            2 * tp_rank_start : 2 * tp_rank_end, :
                        ]

                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                # Check if this is per-channel or per-tensor scale
                if loaded_weight.numel() > 1 and loaded_weight.dim() == 1:
                    if use_ep:
                        narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                    else:
                        narrow_weight = loaded_weight[
                            2 * tp_rank_start : 2 * tp_rank_end
                        ]
                else:
                    narrow_weight = loaded_weight

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(loaded_weight)
                else:
                    param.data[expert_id].copy_(loaded_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w2_weight") and moe_quant_method == "fp8":
                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if expert_id is None:
                        narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
                    else:
                        narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]

                assert fused_name is not None
                param = params_dict[fused_name]

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8":
                assert fused_name is not None
                param = params_dict[fused_name]

                if use_ep:
                    narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = loaded_weight

                if expert_id is None:
                    param.data.copy_(narrow_weight)
                else:
                    param.data[expert_id].copy_(narrow_weight)

                loaded_params.add(fused_name)
                continue

            # Unified handler for bias loading (w13_bias and w2_bias)
            elif name.endswith(".w13_bias") or name.endswith(".w2_bias"):
                is_w13_bias = name.endswith(".w13_bias")

                if use_ep:
                    sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
                else:
                    if is_w13_bias:
                        if expert_id is None:
                            sliced_weight = loaded_weight[
                                :, 2 * tp_rank_start : 2 * tp_rank_end
                            ]
                        else:
                            sliced_weight = loaded_weight[
                                2 * tp_rank_start : 2 * tp_rank_end
                            ]
                    else:
                        sliced_weight = loaded_weight
                        if tp_rank != 0:
                            sliced_weight = sliced_weight.zero_()

                # NOTE(rob): because gpt-oss ckpt has "unique" structure with
                # fused gate_up_proj fused on disk, we cannot use the existing
                # weight loaders without added complexity, so just do the
                # direct load here.
                assert fused_name is not None
                param = params_dict[fused_name]
                expert_data = param.data[expert_id]
                dim1 = sliced_weight.shape[0]
                expert_data.data[:dim1].copy_(sliced_weight)
                loaded_params.add(fused_name)
                continue

            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)

                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                param = params_dict[name]
                weight_loader = param.weight_loader

                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    param_name, weight_name, mapping_expert_id, shard_id = mapping
                    weight_name = (
                        weight_name[:-1] if weight_name.endswith(".") else weight_name
                    )

                    if weight_name not in name:
                        continue

                    param = params_dict[fused_name]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    # Use checkpoint's expert_id for quark format (when expert_id
                    # is extracted from weight name), otherwise use mapping's expert_id
                    actual_expert_id = (
                        expert_id if expert_id is not None else mapping_expert_id
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        fused_name,
                        shard_id=shard_id,
                        expert_id=actual_expert_id,
                        return_success=True,
                    )
                    if success:
                        name = fused_name
                        loaded_params.add(name)
                        break
                else:
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

                loaded_params.add(name)
        return loaded_params

986
    def _load_weights_other(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load checkpoint weights that are neither MXFP4- nor Quark-quantized.

        MoE expert tensors (``w13``/``w2`` weights and biases) and attention
        sinks are sliced for the current rank and copied directly into their
        parameters. Names matching ``stacked_params_mapping`` are renamed to
        the fused parameter and loaded through its ``weight_loader``. Any
        other name present in the model is loaded with its ``weight_loader``
        (or ``default_weight_loader``). Names unknown to the model are skipped.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range. Only used
                when expert parallelism is enabled.
            ep_rank_start: Start of this rank's expert range. Only used when
                expert parallelism is enabled.
            heads_per_rank: Number of attention-sink entries owned by this rank.
            head_start: Index of the first attention-sink entry for this rank.
            weights: ``(name, tensor)`` pairs read from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                entries describing fused parameters such as ``qkv_proj``.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Fused gate/up projection. The TP slice is taken on the last
                # checkpoint dim with doubled bounds because gate and up share
                # that dim. The permute then swaps dims 1 and 2 to match the
                # parameter layout.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                params_dict[name].copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Down projection: TP slice on dim 1, then swap dims 1 and 2.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                params_dict[name].copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Fused gate/up bias, sliced with the same doubled TP bounds
                # as the w13 weight.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                params_dict[name].copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only TP rank 0 loads the down-projection bias, to avoid
                    # duplication; every other rank loads zeros. Build a fresh
                    # tensor instead of calling ``weight.zero_()`` so the
                    # caller's checkpoint tensor is not mutated in place.
                    weight = torch.zeros_like(weight)
                params_dict[name].copy_(weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Attention sinks are split across ranks along dim 0.
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                params_dict[name].data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # The default loader takes no shard id.
                if weight_loader is default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights; names absent from the model are
                # skipped.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

1099
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, dispatching on the checkpoint's quant method.

        Computes this rank's attention-head range and expert range, then
        delegates to ``_load_weights_mxfp4``, ``_load_weights_quark`` or
        ``_load_weights_other`` according to
        ``config.quantization_config["quant_method"]``.

        Args:
            weights: ``(name, tensor)`` pairs read from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_group = get_ep_group()
        ep_size = ep_group.world_size
        # Use the rank *within* the EP group. The global rank only matches it
        # when the group starts at global rank 0 (not the case e.g. for later
        # pipeline stages). Otherwise the expert slice below would start past
        # num_experts and come back empty.
        ep_rank = ep_group.rank_in_group
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        # A config may carry ``quantization_config = None``; treat that as
        # "not quantized" instead of failing on the subscript.
        quant_config = getattr(self.config, "quantization_config", None)
        quant_method = (
            quant_config["quant_method"] if quant_config is not None else None
        )

        # Normalize the checkpoint's quant_method to the internal name.
        # Note: there are three places where "mxfp4" -> "gpt_oss_mxfp4"
        # normalization occurs, each serving a different data path:
        #   1. GptOssMxfp4Config.override_quantization_method() — sets
        #      ModelConfig.quantization (used to select the QuantizationConfig
        #      class at model init time), reading from model_arch_config which
        #      is a snapshot taken before verify_and_update_model_config runs.
        #   2. GptOssForCausalLMConfig.verify_and_update_model_config() —
        #      patches hf_config.quantization_config in-place (a separate copy
        #      of the dict from model_arch_config) for later hf_config lookups.
        #   3. Here — reads directly from self.config (the raw HF config) which
        #      may still carry the original "mxfp4" string from the checkpoint.
        if quant_method == "mxfp4":
            quant_method = "gpt_oss_mxfp4"

        if quant_method == "gpt_oss_mxfp4":
            load_fn = self._load_weights_mxfp4
        elif quant_method == "quark":
            load_fn = self._load_weights_quark
        else:
            load_fn = self._load_weights_other

        return load_fn(
            ep_rank_end,
            ep_rank_start,
            heads_per_rank,
            head_start,
            weights,
            stacked_params_mapping,
        )
1168
1169


1170
1171
1172
class GptOssForCausalLM(
    nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
):
    """GPT-OSS causal language model: a ``GptOssModel`` backbone plus an LM head.

    ``hf_to_vllm_mapper`` renames Hugging Face checkpoint keys to the
    parameter names used by the vLLM modules. This covers MXFP4 blocks and
    scales, unquantized MoE tensors, MoE biases and Quark-format tensors.
    """

    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
            # For quark format
            ".gate_up_proj.weight": ".w13_weight",
            ".gate_up_proj.weight_scale": ".w13_weight_scale",
            ".gate_up_proj.bias": ".w13_bias",
            ".gate_up_proj.input_scale": ".w13_input_scale",
            ".down_proj.weight": ".w2_weight",
            ".down_proj.weight_scale": ".w2_weight_scale",
            ".down_proj.bias": ".w2_bias",
            ".down_proj.input_scale": ".w2_input_scale",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the backbone, the LM head and the logits processor."""
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # load_weights() skips every "lm_head." checkpoint entry when the
            # embeddings are tied, so the head must share the input embedding
            # weight; otherwise it would keep its freshly-initialized values.
            # NOTE(review): assumes GptOssModel exposes its input embedding as
            # `.embedding`, as the ".embedding.weight" mapping above implies.
            self.lm_head = self.lm_head.tie_weights(self.model.embedding)
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the ``GptOssModel`` backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a Hugging Face checkpoint into this model.

        Keys are renamed with ``hf_to_vllm_mapper``. ``lm_head.`` entries are
        skipped when the word embeddings are tied.

        Returns:
            The names of the parameters that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)