"vscode:/vscode.git/clone" did not exist on "997c8811d6aadf92dc299e0c2a8d274117308880"
gpt_oss.py 28.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (
14
    get_dp_group,
15
    get_ep_group,
16
    get_pcp_group,
17
18
19
20
21
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
22
from vllm.model_executor.layers.fused_moe import FusedMoE
23
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
26
27
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
29
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
30
from vllm.model_executor.layers.vocab_parallel_embedding import (
31
32
33
    ParallelLMHead,
    VocabParallelEmbedding,
)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
from vllm.model_executor.models.utils import sequence_parallel_chunk
36
from vllm.platforms import current_platform
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils.math_utils import cdiv
39

40
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
41
42
43
44
45
46
47
48
49
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
50
51
52
53
54
55


class OAIAttention(nn.Module):
    def __init__(
        self,
        config: GptOssConfig,
56
57
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
58
59
60
61
62
63
64
65
66
67
68
69
70
71
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
72
73
            rope_parameters={
                "rope_theta": config.rope_parameters["rope_theta"],
74
                "rope_type": "yarn",
75
76
                "factor": config.rope_parameters["factor"],
                "original_max_position_embeddings": config.rope_parameters[
77
78
                    "original_max_position_embeddings"
                ],
79
80
                "beta_fast": config.rope_parameters["beta_fast"],
                "beta_slow": config.rope_parameters["beta_slow"],
81
                "truncate": config.rope_parameters.get("truncate", True),
82
83
84
85
86
87
88
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        self.sinks = torch.nn.Parameter(
89
90
            torch.empty(config.num_attention_heads // tp_size, requires_grad=False)
        )
91
92
93
94
95

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5

96
        self.qkv_proj = QKVParallelLinear(
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
116
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
117
118
119
120
121
122
123
124
125
126
127
128
129
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

130
131
132
    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
133
        qkv, _ = self.qkv_proj(hidden_states)
134
135
136
137
138
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
139
        return output
140
141
142
143
144


class MLPBlock(torch.nn.Module):
    def __init__(
        self,
145
        vllm_config: VllmConfig,
146
147
148
149
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()
150
151
152
153
154
155
156

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

157
158
        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
159
        self.hidden_size = config.hidden_size
160
161
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
162
        self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
163
        assert config.intermediate_size % self.world_size == 0
164
165
166
167
168
169
170
171
172
173
174
175
176
177
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )
178
179

    def forward(self, x: torch.Tensor) -> torch.Tensor:
180
181
182
183
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

184
185
186
187
188
189
        if current_platform.is_rocm():
            g = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            g = self.router(x)
190
        x = self.experts(hidden_states=x, router_logits=g)
191
192
193
194

        if self.is_sequence_parallel:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
195
        return x
196
197
198
199
200


class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
201
        vllm_config: VllmConfig,
202
        quant_config: QuantizationConfig,
203
204
205
        prefix: str = "",
    ):
        super().__init__()
206
207
208
209

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

210
        self.layer_idx = extract_layer_index(prefix)
211
        self.attn = OAIAttention(
212
213
214
215
            config,
            prefix=f"{prefix}.attn",
            quant_config=quant_config,
            cache_config=cache_config,
216
217
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
218
219
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
220

221
222
223
224
    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
225
        residual: torch.Tensor | None,
226
227
228
229
230
231
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
232
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
233
        hidden_states = self.attn(hidden_states, positions)
234

235
        # Fully Connected
236
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
237
238
        output = self.mlp(hidden_states)
        return output, residual
239
240
241
242
243
244
245
246
247
248
249
250


@support_torch_compile
class GptOssModel(nn.Module):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
251
        self.quant_config = vllm_config.quant_config
252
        self.parallel_config = vllm_config.parallel_config
253
254
255
256
257
        self.config.hidden_size = self.config.hidden_size
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
258
259
260
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
261
                vllm_config,
262
                prefix=prefix,
263
                quant_config=self.quant_config,
264
265
266
            ),
            prefix=f"{prefix}.layers",
        )
267
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
268
269
270
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
271
        self.aux_hidden_state_layers = tuple[int, ...]()
272

273
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
274
275
276
277
278
279
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
280
281
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
282
283
284
285
286
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
287
                x = self.embed_input_ids(input_ids)
288
289
290
291
292
293
294

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

295
        aux_hidden_states = []
296
297
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
298
            if i in self.aux_hidden_state_layers:
299
                aux_hidden_states.append(x if residual is None else x + residual)
300
301
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
302
            return IntermediateTensors({"hidden_states": x, "residual": residual})
303
        x, _ = self.norm(x, residual)
304
305
306

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
307
308
        return x

309
    def _load_weights_mxfp4(
310
311
312
313
314
315
316
317
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
318
319
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
320

321
        mxfp4_block = 32
322
323
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts
324

325
326
        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
327
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
328
329
330
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
331
332
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
333
        )
334
335

        intermediate_size = self.config.intermediate_size
336
        intermediate_size_block = intermediate_size // mxfp4_block
337
338
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
339
340
341

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
342
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
343
344

        for name, weight in weights:
345
346
347
348
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

349
350
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
351
352
353
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
354
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
355

356
                param = params_dict[name]
357
358
359
360
361
362
363
364
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
365
366
367
                loaded_params.add(name)
                continue
            elif ".w2_weight_scale" in name:
368
369
370
371
                # Handle MLP down projection weights
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
372
373
374
                    narrow_weight = weight[
                        ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                    ]
375

376
                param = params_dict[name]
377
378
379
380
381
382
383
384
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
385
386
387
388
389
390
                loaded_params.add(name)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
391
392
393
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()
394

395
396
                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
397
398
399
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
400
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
401

402
                param = params_dict[name]
403
404
405
406
407
408
409
410
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
411
412
413
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
414
                # Handle MLP down projection weights
415
416
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
417
418
419
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
420
421
422
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
423
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
424

425
                param = params_dict[name]
426
427
428
429
430
431
432
433
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
434
435
436
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
437
438
439
440
441
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
442
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
443

444
                param = params_dict[name]
445
446
447
448
449
450
451
452
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    narrow_weight,
                    weight_name=name,
                    shard_id=None,
                    expert_id=None,
                )
453
454
455
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
456
                # Handle MLP down projection bias
457
                param = params_dict[name]
458
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
459
460
461
462
463
464
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
465
466
467
                weight_loader(
                    param, weight, weight_name=name, shard_id=None, expert_id=None
                )
468
469
                loaded_params.add(name)
                continue
470
471
472
473
474
475
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
476
477
478
479
480
481
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
482
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
483
484
485
486
487
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
488
489
            else:
                # Handle all other weights with potential renaming
490
                if name not in params_dict:
491
                    continue
492
                param = params_dict[name]
493
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
494
                weight_loader(param, weight)
495
            loaded_params.add(name)
496
        return loaded_params
497
498

    def _load_weights_other(
499
500
        self,
        ep_rank_end: int,
501
        ep_rank_start: int,
502
503
504
505
506
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
507
508
509
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

510
511
        use_ep = self.parallel_config.enable_expert_parallel

512
513
        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
514
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
515
516
517
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
518
519
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
520
        )
521

522
        intermediate_size = self.config.intermediate_size
523
524
525
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
526
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
527
528

        for name, weight in weights:
529
530
531
532
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

533
            if ".w13_weight" in name:
534
535
536
537
538
                # Handle MLP gate and up projection weights
                # Extract gate and up projection parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
539
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
540
541

                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
542
                param = params_dict[name]
543
544

                param.copy_(narrow_weight)
545
546
547
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
548
549
550
551
552
553
                # Handle MLP down projection weights
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
554
                param = params_dict[name]
555
556

                param.copy_(narrow_weight)
557
558
559
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
560
561
562
563
564
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
565
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
566

567
                param = params_dict[name]
568
                param.copy_(narrow_weight)
569
570
571
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
572
573
574
575
576
577
578
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
579
                param = params_dict[name]
580
                param.copy_(weight)
581
582
                loaded_params.add(name)
                continue
583
584
585
586
587
588
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
589
590
591
592
593
594
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
595
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
596
597
598
599
600
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
601
602
            else:
                # Handle all other weights with potential renaming
603
                if name not in params_dict:
604
                    continue
605
                param = params_dict[name]
606
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
607
                weight_loader(param, weight)
608
            loaded_params.add(name)
609
610
        return loaded_params

611
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
612
613
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
614
615
616
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

633
634
635
636
637
        quant_method = (
            self.config.quantization_config["quant_method"]
            if hasattr(self.config, "quantization_config")
            else None
        )
638
        if quant_method == "mxfp4":
639
640
641
642
643
644
645
646
            return self._load_weights_mxfp4(
                ep_rank_end,
                ep_rank_start,
                heads_per_rank,
                head_start,
                weights,
                stacked_params_mapping,
            )
647
        else:
648
649
            return self._load_weights_other(
                ep_rank_start,
650
                ep_rank_end,
651
652
653
654
655
                heads_per_rank,
                head_start,
                weights,
                stacked_params_mapping,
            )
656
657


658
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
659
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
697
            prefix=maybe_prefix(prefix, "lm_head"),
698
699
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
700
        self.make_empty_intermediate_tensors = (
701
702
            self.model.make_empty_intermediate_tensors
        )
703

704
705
706
707
708
709
710
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

711
712
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)
713

714
715
716
717
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
718
719
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
720
721
    ) -> torch.Tensor:
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
722

723
724
    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
725
726
        return logits

727
728
729
730
731
732
733
734
735
736
737
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, weight scales, activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_local_experts,
            num_redundant_experts=0,
        )

738
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
739
740
        loader = AutoWeightsLoader(
            self,
741
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
742
743
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)