gpt_oss.py 28.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

10
from vllm.attention.layer import Attention
11
12
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (
14
    get_dp_group,
15
    get_ep_group,
16
    get_pcp_group,
17
18
19
20
21
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
22
from vllm.model_executor.layers.fused_moe import FusedMoE
23
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
26
27
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
29
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
30
from vllm.model_executor.layers.vocab_parallel_embedding import (
31
32
33
    ParallelLMHead,
    VocabParallelEmbedding,
)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
from vllm.model_executor.models.utils import sequence_parallel_chunk
36

37
from vllm.platforms import current_platform
38
from vllm.sequence import IntermediateTensors
39
from vllm.utils.math_utils import cdiv
40
from vllm.v1.attention.backend import AttentionType
41

42
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
43
44
45
46
47
48
49
50
51
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
52
53
54
55
56
57


class OAIAttention(nn.Module):
    """Self-attention layer of a GPT-OSS decoder block.

    Fused QKV projection, YaRN rotary embedding, per-head attention sinks,
    a sliding window on even-indexed layers, and an output projection.
    Attention heads are sharded across the tensor-parallel (TP) group.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        """Build the attention sub-layers.

        Args:
            config: HF GPT-OSS config (head counts, ``head_dim``, rope and
                sliding-window settings).
            quant_config: Optional quantization config forwarded to the
                linear layers and the attention layer.
            cache_config: Optional KV-cache config forwarded to ``Attention``.
            prefix: Dotted module path; the layer index is parsed from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        rope_parameters = config.rope_parameters
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
            rope_parameters={
                "rope_theta": rope_parameters["rope_theta"],
                "rope_type": "yarn",
                "factor": rope_parameters["factor"],
                "original_max_position_embeddings": rope_parameters[
                    "original_max_position_embeddings"
                ],
                "beta_fast": rope_parameters["beta_fast"],
                "beta_slow": rope_parameters["beta_slow"],
                "truncate": rope_parameters.get("truncate", True),
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()
        self.num_local_attention_heads = config.num_attention_heads // tp_size
        # NOTE(review): assumes QKVParallelLinear replicates KV heads when the
        # TP group is larger than the KV-head count, leaving one KV head per
        # rank; a plain floor division would yield 0 local KV heads there and
        # a q/k/v split that no longer matches the projection's output.
        self.num_local_key_value_heads = max(1, config.num_key_value_heads // tp_size)

        # One sink value per local attention head, handed to the attention
        # layer below and loaded head-sharded by the model's weight loader.
        self.sinks = torch.nn.Parameter(
            torch.empty(self.num_local_attention_heads, requires_grad=False)
        )

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_local_attention_heads * self.head_dim
        self.kv_size = self.num_local_key_value_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Sliding-window attention on even-indexed layers only; odd layers
        # attend over the full context.
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            hidden_states: Layer input, projected to Q/K/V by ``qkv_proj``.
            positions: Token positions used by the rotary embedding.

        Returns:
            The attention output after ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # v is a view produced by split(); make it contiguous before handing
        # it to the attention layer.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward layer of a GPT-OSS decoder block.

    A linear router produces per-expert logits that drive a ``FusedMoE``
    layer (top-k routing with renormalization, biased experts, "swigluoai"
    activation). With sequence-parallel MoE enabled, each rank processes only
    its chunk of the tokens and the outputs are all-gathered afterwards.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        """Build the router and the fused experts.

        Args:
            vllm_config: Global vLLM config (HF config, quantization, and
                parallelism settings are read from it).
            layer_idx: Index of the decoder layer that owns this block.
            prefix: Dotted module path used to name the experts module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.experts_per_token = config.num_experts_per_tok
        # NOTE(review): this is the size of the default (global) process
        # group, spanning every parallel dimension, not the MoE TP size; the
        # divisibility assert below may therefore be stricter than needed.
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` to its experts and return the combined expert output.

        Args:
            x: Hidden states whose first dimension is the token dimension.

        Returns:
            Expert outputs for all ``x.shape[0]`` input tokens.
        """
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            # Each rank only processes its own chunk of the tokens.
            x = sequence_parallel_chunk(x)

        if current_platform.is_rocm():
            # ROCm router GEMM; only the first hidden_size columns are fed in.
            router_logits = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            router_logits = self.router(x)
        x = self.experts(hidden_states=x, router_logits=router_logits)

        if self.is_sequence_parallel:
            # Reassemble the per-rank chunks, then trim back to the original
            # token count (chunking may have padded the last chunk).
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
        return x


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: pre-norm attention, then a pre-norm MoE MLP."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ):
        """Build the attention, MLP and the two RMS norms of this layer.

        Args:
            vllm_config: Global vLLM config (HF and cache configs are read
                from it and it is forwarded to the MLP).
            quant_config: Quantization config for attention, or ``None``.
            prefix: Dotted module path; the layer index is parsed from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config,
            prefix=f"{prefix}.attn",
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MoE MLP for one layer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            positions: Token positions for the rotary embedding.
            residual: Running residual stream, or ``None`` for the first
                layer on this pipeline stage.

        Returns:
            ``(mlp_output, residual)``. The residual is not yet added to
            ``mlp_output``; the next norm call receives both.
        """
        # Self attention. When called with a residual, RMSNorm returns
        # (normalized, updated residual) -- presumably a fused add + norm.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully connected (MoE).
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual


@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder stack: token embedding, transformer blocks, final norm.

    Supports pipeline parallelism (only layers ``start_layer..end_layer`` run
    on this rank and activations are exchanged as ``IntermediateTensors``),
    optional capture of auxiliary hidden states for EAGLE-3, and
    sharding-aware checkpoint loading for MXFP4 and other MoE weight layouts.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Global vLLM config.
            prefix: Dotted module path of this model.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config

        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
                quant_config=self.quant_config,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
        # Indices of layers whose inputs forward() captures (set through
        # GptOssForCausalLM.set_aux_hidden_state_layers); empty disables it.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> (
        torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]
    ):
        """Run this pipeline stage's slice of the decoder stack.

        Args:
            input_ids: Token ids (used on the first stage when
                ``inputs_embeds`` is not given).
            positions: Token positions for the rotary embedding.
            intermediate_tensors: Activations from the previous stage;
                required on every stage but the first.
            inputs_embeds: Precomputed input embeddings (first stage only).

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual`` on
            non-last stages. On the last stage, the final-normed hidden states,
            paired with the captured auxiliary hidden states when any layer
            in ``aux_hidden_state_layers`` ran here.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                # Capture the layer input with the pending residual folded in.
                aux_hidden_states.append(x if residual is None else x + residual)
            x, residual = layer(x, positions, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})
        x, _ = self.norm(x, residual)

        if aux_hidden_states:
            return x, aux_hidden_states
        return x

    # ------------------------------------------------------------------
    # Weight-loading helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _moe_tp_size_and_rank() -> tuple[int, int]:
        """Return the MoE ``(tp_size, tp_rank)``, flattened across DP and PCP.

        In MoE, the tensor-parallel size is flattened across the data- (and
        prefill-context-) parallel sizes when expert parallelism is disabled.
        """
        return FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

    @staticmethod
    def _load_moe_tensor(
        param: nn.Parameter, tensor: torch.Tensor, name: str
    ) -> None:
        """Hand an already rank-sliced MoE tensor to ``param``'s weight loader."""
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(
            param,
            tensor,
            weight_name=name,
            shard_id=None,
            expert_id=None,
        )

    @staticmethod
    def _load_sinks(
        param: nn.Parameter,
        weight: torch.Tensor,
        head_start: int,
        heads_per_rank: int,
    ) -> None:
        """Copy this TP rank's slice of the attention sinks into ``param``."""
        param.data.copy_(weight.narrow(0, head_start, heads_per_rank))

    @staticmethod
    def _load_dense_weight(
        name: str,
        weight: torch.Tensor,
        params_dict: dict[str, nn.Parameter],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> str | None:
        """Load a non-MoE tensor, handling fused (stacked) projections.

        Args:
            name: Checkpoint tensor name (already mapped to vLLM naming).
            weight: Checkpoint tensor.
            params_dict: This model's parameters by name.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                entries; a matching shard is loaded into the fused parameter.

        Returns:
            The parameter name that was loaded, or ``None`` when the tensor
            has no matching parameter and was skipped.
        """
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            fused_name = name.replace(shard_name, param_name)
            param = params_dict[fused_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            return fused_name

        # Handle all other weights.
        if name not in params_dict:
            return None
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, weight)
        return name

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4 checkpoint, slicing MoE tensors for this rank.

        With expert parallelism, MoE tensors are sliced along the expert dim
        (``ep_rank_start:ep_rank_end``). Otherwise they are sliced along the
        intermediate dimension in whole MXFP4 blocks of 32 values.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range.
            ep_rank_start: Start of this rank's expert range.
            heads_per_rank: Attention heads per TP rank (used for sinks).
            head_start: First attention head owned by this TP rank.
            weights: ``(name, tensor)`` pairs in vLLM naming.
            stacked_params_mapping: Fused-projection mapping entries.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Number of values that share one MXFP4 scale.
        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_size, tp_rank = self._moe_tp_size_and_rank()

        # Shard the intermediate dimension in whole MXFP4 blocks so that each
        # rank's values and scales stay aligned; the last rank may get fewer.
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        for name, weight in weights:
            # Skip layers on other pipeline-parallel ranks.
            if is_pp_missing_parameter(name, self):
                continue

            # NOTE: the *_scale checks must come before the plain *_weight
            # checks, since ".w13_weight" / ".w2_weight" are substrings of them.
            if ".w13_weight_scale" in name:
                # Gate/up projection scales (gate and up fused: 2x width).
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
            elif ".w2_weight_scale" in name:
                # Down projection scales: one per MXFP4 block along the last dim.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[
                        ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                    ]
            elif ".w13_weight" in name:
                # Flatten (E, 2 * N, block_size, entry_per_block) to
                # (E, 2 * N, -1); this does not copy for a contiguous tensor.
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()
                # The weight is pre-shuffled, so gate/up can be sliced directly.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
            elif ".w2_weight" in name:
                # Same flattening; two MXFP4 values are packed into one uint8,
                # hence intermediate_size // 2 in the last dim.
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
            elif ".w13_bias" in name:
                # Gate/up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
            elif ".w2_bias" in name:
                # Down projection bias: only TP rank 0 loads it to avoid
                # duplicating it across ranks (the experts use
                # reduce_results=True); other ranks load zeros. A fresh zero
                # tensor avoids mutating the caller's checkpoint tensor.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank == 0:
                    narrow_weight = weight
                else:
                    narrow_weight = torch.zeros_like(weight)
            elif "sinks" in name:
                # Attention sinks are sharded by head across TP ranks.
                self._load_sinks(params_dict[name], weight, head_start, heads_per_rank)
                loaded_params.add(name)
                continue
            else:
                loaded_name = self._load_dense_weight(
                    name, weight, params_dict, stacked_params_mapping
                )
                if loaded_name is not None:
                    loaded_params.add(loaded_name)
                continue

            self._load_moe_tensor(params_dict[name], narrow_weight, name)
            loaded_params.add(name)

        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load a non-MXFP4 checkpoint, slicing MoE tensors for this rank.

        Arguments and return value are as for ``_load_weights_mxfp4``; here
        the intermediate dimension is split with plain ceil-division and MoE
        weights are transposed into the parameter layout.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        tp_size, tp_rank = self._moe_tp_size_and_rank()

        # Common slicing bounds of the intermediate dimension for this rank.
        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        for name, weight in weights:
            # Skip layers on other pipeline-parallel ranks.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Gate/up weights: the fused gate+up dimension is last in the
                # checkpoint; slice it, then swap dims 1 and 2 for the param.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            elif ".w2_weight" in name:
                # Down projection weights: slice dim 1, then swap dims 1 and 2.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            elif ".w13_bias" in name:
                # Gate/up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
            elif ".w2_bias" in name:
                # Down projection bias: only TP rank 0 loads it; other ranks
                # load zeros (fresh tensor, checkpoint tensor left untouched).
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank == 0:
                    narrow_weight = weight
                else:
                    narrow_weight = torch.zeros_like(weight)
            elif "sinks" in name:
                # Attention sinks are sharded by head across TP ranks.
                self._load_sinks(params_dict[name], weight, head_start, heads_per_rank)
                loaded_params.add(name)
                continue
            else:
                loaded_name = self._load_dense_weight(
                    name, weight, params_dict, stacked_params_mapping
                )
                if loaded_name is not None:
                    loaded_params.add(loaded_name)
                continue

            params_dict[name].copy_(narrow_weight)
            loaded_params.add(name)

        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this rank's parameters.

        Computes this rank's attention-head and expert ranges, then dispatches
        to the MXFP4 loader when the HF config's
        ``quantization_config["quant_method"]`` is ``"mxfp4"`` and to the
        generic loader otherwise.

        Args:
            weights: ``(name, tensor)`` pairs in vLLM naming.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads (and therefore sinks) owned by this TP rank.
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        # Experts owned by this EP rank. Index with the rank *within* the EP
        # group: the global rank overshoots the expert range whenever the EP
        # group does not start at global rank 0 (e.g. pipeline stages > 0).
        ep_group = get_ep_group()
        ep_size = ep_group.world_size
        ep_rank = ep_group.rank_in_group
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quantization_config = getattr(self.config, "quantization_config", None)
        quant_method = (
            quantization_config["quant_method"]
            if quantization_config is not None
            else None
        )
        if quant_method == "mxfp4":
            load = self._load_weights_mxfp4
        else:
            load = self._load_weights_other
        return load(
            ep_rank_end,
            ep_rank_start,
            heads_per_rank,
            head_start,
            weights,
            stacked_params_mapping,
        )


659
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    """GPT-OSS causal language model: ``GptOssModel`` plus an LM head."""

    is_3d_moe_weight: bool = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # Renames HF checkpoint keys to this module's parameter names: substring
    # rewrites first, then name-ending rewrites for embeddings and MoE tensors.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder model, the LM head and the logits processor.

        Args:
            vllm_config: Global vLLM config.
            prefix: Dotted module path of this model.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # load_weights skips "lm_head." checkpoint tensors in this case,
            # so the head must share the input embedding's weight or it would
            # never be initialized.
            self.lm_head.weight = self.model.embedding.weight
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which layers' inputs the model captures for EAGLE-3."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default EAGLE-3 capture layers: early, middle and late."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> (
        torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]
    ):
        """Run the decoder; see ``GptOssModel.forward`` for the return forms."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the FusedMoE parameter mapping for this model's experts.

        Entries are ``(param_name, weight_name, expert_id, shard_id)`` for
        weights, weight scales and activation scales, built for checkpoint
        names ``gate_proj`` / ``up_proj`` / ``down_proj``.
        """
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load an HF checkpoint, renaming keys via ``hf_to_vllm_mapper``.

        ``lm_head.`` tensors are skipped when word embeddings are tied.

        Returns:
            Names of the parameters that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)