gpt_oss.py 28.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (
14
    get_dp_group,
15
    get_ep_group,
16
    get_pcp_group,
17
18
19
20
21
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
22
from vllm.model_executor.layers.fused_moe import FusedMoE
23
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
24
from vllm.model_executor.layers.layernorm import RMSNorm
25
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
26
27
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
29
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
30
from vllm.model_executor.layers.vocab_parallel_embedding import (
31
32
33
    ParallelLMHead,
    VocabParallelEmbedding,
)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
from vllm.model_executor.models.utils import sequence_parallel_chunk
36
from vllm.platforms import current_platform
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils.math_utils import cdiv
39

40
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
41
42
43
44
45
46
47
48
49
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
50
51
52
53
54
55


class OAIAttention(nn.Module):
    """GPT-OSS self-attention with YaRN rotary embeddings and attention sinks.

    Heads are sharded across the tensor-parallel group. Even-indexed layers use
    ``config.sliding_window``; odd-indexed layers get no sliding window.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Global (unsharded) geometry taken from the HF config.
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # The rope type is pinned to "yarn"; the numeric knobs come from the
        # checkpoint's rope_parameters.
        rope_cfg = config.rope_parameters
        yarn_parameters = {
            "rope_theta": rope_cfg["rope_theta"],
            "rope_type": "yarn",
            "factor": rope_cfg["factor"],
            "original_max_position_embeddings": rope_cfg[
                "original_max_position_embeddings"
            ],
            "beta_fast": rope_cfg["beta_fast"],
            "beta_slow": rope_cfg["beta_slow"],
        }
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            dtype=torch.float32,
            rope_parameters=yarn_parameters,
            is_neox_style=True,
        )

        tp_world = get_tensor_model_parallel_world_size()
        local_heads = config.num_attention_heads // tp_world
        local_kv_heads = config.num_key_value_heads // tp_world

        # One sink value per attention head owned by this rank; the values are
        # copied in by the model's weight loader.
        self.sinks = torch.nn.Parameter(
            torch.empty(local_heads, requires_grad=False)
        )

        # Per-rank widths of the q and k/v slices in the fused qkv output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_world
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_world
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = local_heads
        self.num_local_key_value_heads = local_kv_heads

        # Alternate layers: even indices are windowed, odd indices are not.
        if self.layer_idx % 2 == 0:
            window = config.sliding_window
        else:
            window = None

        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` with rotary embeddings at ``positions``.

        Returns the output of ``o_proj`` applied to the attention result.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value.contiguous())
        projected, _ = self.o_proj(context)
        return projected


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block: a linear router plus FusedMoE.

    With sequence-parallel MoE enabled, each rank processes only its chunk of
    the tokens. The expert outputs are then all-gathered and trimmed back to
    the original token count.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        moe_quant_config = vllm_config.quant_config
        self.is_sequence_parallel = (
            vllm_config.parallel_config.use_sequence_parallel_moe
        )

        self.layer_idx = layer_idx
        self.num_experts = hf_config.num_local_experts
        self.hidden_size = hf_config.hidden_size
        self.experts_per_token = hf_config.num_experts_per_tok
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
        else:
            self.world_size = 1
        # Router: one logit per expert for every token.
        self.router = torch.nn.Linear(
            hf_config.hidden_size, hf_config.num_local_experts
        )
        assert hf_config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=hf_config.num_local_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=moe_quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` to its top-k experts and return the combined expert output."""
        total_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        if current_platform.is_rocm():
            # On ROCm the router projection goes through rocm_unquantized_gemm;
            # only the first hidden_size columns of x are fed to it.
            router_logits = rocm_unquantized_gemm(
                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
            )
        else:
            router_logits = self.router(x)
        x = self.experts(hidden_states=x, router_logits=router_logits)

        if not self.is_sequence_parallel:
            return x
        # Reassemble the per-rank chunks and drop any rows past the original
        # token count.
        gathered = tensor_model_parallel_all_gather(x.contiguous(), 0)
        return gathered[:total_tokens]


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: pre-norm attention, then a pre-norm MoE MLP.

    The residual stream is threaded through explicitly. After the first call,
    RMSNorm is invoked as ``norm(hidden_states, residual)`` and returns
    ``(normed, new_residual)``. The residual addition therefore happens inside
    that call rather than in this module. The value returned as ``output`` is
    the MLP output *without* the residual added back.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config,
            prefix=f"{prefix}.attn",
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            hidden_states: output of the previous layer (or the embeddings).
            positions: token positions used for rotary embeddings.
            residual: running residual stream, or ``None`` when this is the
                first layer executed on the first pipeline stage.

        Returns:
            ``(mlp_output, residual)``. The caller (next layer or final norm)
            is responsible for combining the two.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the raw input seeds the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual


@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder backbone: token embedding, transformer blocks, final norm.

    Also owns checkpoint loading for the backbone. This includes slicing MoE
    expert tensors for tensor or expert parallelism, and slicing attention
    sinks across tensor-parallel ranks.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        # forward() only runs layers in [start_layer, end_layer) as returned by
        # make_layers for this pipeline stage.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
                quant_config=self.quant_config,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
        # Indices of layers whose input hidden state (x + residual) forward()
        # additionally returns; empty unless overwritten externally.
        self.aux_hidden_state_layers = tuple[int, ...]()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """Run the layers owned by this pipeline stage.

        The first PP rank starts from ``inputs_embeds`` (if given) or from the
        embeddings of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks. On the last rank it
            returns the final-normed hidden states. When any layer in
            ``aux_hidden_state_layers`` was executed, it instead returns
            ``(hidden_states, aux_hidden_states)``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                aux_hidden_states.append(x if residual is None else x + residual)
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})
        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x

    def _moe_tp_size_and_rank(self) -> tuple[int, int]:
        """Return ``(size, rank)`` of the group the MoE intermediate dim is split over.

        In MoE, the tensor parallel size is flattened across the data parallel
        and PCP groups when EP is disabled. Expert shards must therefore be
        computed over that flattened group, not the attention TP group.
        """
        return FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
            pcp_size=get_pcp_group().world_size,
            pcp_rank=get_pcp_group().rank_in_group,
        )

    @staticmethod
    def _load_non_expert_weight(
        params_dict: dict[str, nn.Parameter],
        name: str,
        weight: torch.Tensor,
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> str | None:
        """Load a weight that is neither an MoE expert tensor nor a sink.

        Checkpoint names matching an entry of ``stacked_params_mapping`` (e.g.
        q/k/v projections) are renamed to their fused parameter. They are
        loaded with the parameter's ``weight_loader`` and the shard id.

        Returns:
            The parameter name the tensor was loaded into, or ``None`` when a
            non-stacked tensor has no matching parameter (it is skipped).
        """
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            return name

        # Handle all other weights with potential renaming
        if name not in params_dict:
            return None
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, weight)
        return name

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4-quantized checkpoint.

        Args:
            ep_rank_end: exclusive end of this rank's expert range (EP only).
            ep_rank_start: inclusive start of this rank's expert range (EP only).
            heads_per_rank: attention heads owned by this TP rank.
            head_start: index of this rank's first attention head.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                entries for fused projections.

        With EP enabled, expert tensors are sliced along the expert dim.
        Otherwise they are sliced along the intermediate dim in whole MXFP4
        blocks.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Elements per MXFP4 scale block; w2 scales are sliced in these units.
        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_size, tp_rank = self._moe_tp_size_and_rank()

        # Split the intermediate dim in whole blocks so no scale block straddles
        # two ranks; the last rank's shard may be shorter (clamped below).
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        def load_expert_param(param_name: str, tensor: torch.Tensor) -> None:
            # All expert tensors go through the parameter's loader with no
            # shard/expert id: the slice above already selects this rank's part.
            param = params_dict[param_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                tensor,
                weight_name=param_name,
                shard_id=None,
                expert_id=None,
            )
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # NOTE: the *_scale checks must come before the plain ".w13_weight"
            # / ".w2_weight" checks, which are substrings of the scale names.
            if ".w13_weight_scale" in name:
                # Gate and up projection scales: gate and up together span
                # 2 * intermediate_size along dim 1, hence the factor of 2.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                load_expert_param(name, narrow_weight)
                continue

            if ".w2_weight_scale" in name:
                # Down projection scales: last dim counts MXFP4 blocks.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[
                        ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                    ]
                load_expert_param(name, narrow_weight)
                continue

            if ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
                load_expert_param(name, narrow_weight)
                continue

            if ".w2_weight" in name:
                # Handle MLP down projection weights.
                # Same flatten, but 2 mx4 values are packed in 1 uint8, so
                # the intermediate dim (and its bounds) are halved.
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
                load_expert_param(name, narrow_weight)
                continue

            if ".w13_bias" in name:
                # Gate and up projection biases (2 * intermediate_size wide).
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                load_expert_param(name, narrow_weight)
                continue

            if ".w2_bias" in name:
                # Down projection bias is not split over TP. The MoE output is
                # reduced across ranks (reduce_results=True), so only rank 0
                # keeps it to avoid adding it once per rank. Use a fresh zero
                # tensor instead of mutating the checkpoint tensor in place.
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    weight = torch.zeros_like(weight)
                load_expert_param(name, weight)
                continue

            if "sinks" in name:
                # Attention sinks are per head; keep this TP rank's heads.
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            loaded_name = self._load_non_expert_weight(
                params_dict, name, weight, stacked_params_mapping
            )
            if loaded_name is not None:
                loaded_params.add(loaded_name)
        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load a non-MXFP4 checkpoint (expert weights stored unquantized).

        Takes the same arguments as ``_load_weights_mxfp4``. Expert weights
        are sliced per rank and their last two dims are swapped before being
        copied into the FusedMoE parameters.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel
        tp_size, tp_rank = self._moe_tp_size_and_rank()

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Gate and up projection weights: the (2x) intermediate dim is
                # the last dim in the checkpoint.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
                # Checkpoint matrices are transposed relative to the parameter.
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                params_dict[name].copy_(narrow_weight)
                loaded_params.add(name)
                continue

            if ".w2_weight" in name:
                # Down projection weights: intermediate dim is dim 1 here.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                params_dict[name].copy_(narrow_weight)
                loaded_params.add(name)
                continue

            if ".w13_bias" in name:
                # Gate and up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
                params_dict[name].copy_(narrow_weight)
                loaded_params.add(name)
                continue

            if ".w2_bias" in name:
                # Down projection bias: only TP rank 0 keeps it (the MoE output
                # is reduced across ranks); others load zeros. Avoid mutating
                # the checkpoint tensor in place.
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    weight = torch.zeros_like(weight)
                params_dict[name].copy_(weight)
                loaded_params.add(name)
                continue

            if "sinks" in name:
                # Attention sinks are per head; keep this TP rank's heads.
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            loaded_name = self._load_non_expert_weight(
                params_dict, name, weight, stacked_params_mapping
            )
            if loaded_name is not None:
                loaded_params.add(loaded_name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` into this backbone.

        Computes this rank's attention-head and expert ranges. It then
        dispatches to the MXFP4 loader when
        ``config.quantization_config["quant_method"] == "mxfp4"``, and to the
        generic loader otherwise.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        # Contiguous block of experts owned by this rank under EP. Use the rank
        # *within* the EP group: the global rank differs from it whenever the
        # EP group does not start at global rank 0 (e.g. pipeline parallelism).
        # Assumes num_local_experts is divisible by the EP size.
        ep_group = get_ep_group()
        experts_per_rank = self.config.num_local_experts // ep_group.world_size
        ep_rank_start = ep_group.rank_in_group * experts_per_rank
        ep_rank_end = ep_rank_start + experts_per_rank

        quantization_config = getattr(self.config, "quantization_config", None)
        quant_method = (
            quantization_config["quant_method"]
            if quantization_config is not None
            else None
        )

        loader = (
            self._load_weights_mxfp4
            if quant_method == "mxfp4"
            else self._load_weights_other
        )
        # Keyword arguments: both loaders take (ep_rank_end, ep_rank_start, ...),
        # an order that is easy to swap by accident when passed positionally.
        return loader(
            ep_rank_end=ep_rank_end,
            ep_rank_start=ep_rank_start,
            heads_per_rank=heads_per_rank,
            head_start=head_start,
            weights=weights,
            stacked_params_mapping=stacked_params_mapping,
        )


657
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    """GPT-OSS causal LM: the ``GptOssModel`` backbone plus an LM head.

    Checkpoint names are rewritten by ``hf_to_vllm_mapper`` before loading.
    """

    # The fused qkv_proj parameter is built from these checkpoint modules.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # HF checkpoint name -> vLLM parameter name rewrites.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Pipeline-parallel placeholder tensors come from the backbone.
        backbone = self.model
        self.make_empty_intermediate_tensors = backbone.make_empty_intermediate_tensors

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which layer indices the backbone reports auxiliary states for."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE-3 layers: the 3rd, the middle, and the 3rd-from-last."""
        depth = len(self.model.layers)
        return (2, depth // 2, depth - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the backbone; see ``GptOssModel.forward`` for the return forms."""
        backbone = self.model
        return backbone(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return FusedMoE ``(param_name, weight_name, expert_id, shard_id)`` tuples.

        Covers params for weights, weight scales and activation scales.
        """
        expert_count = self.config.num_local_experts
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=expert_count,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF checkpoint weights, skipping ``lm_head.*`` when embeddings are tied."""
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)