# gpt_oss.py 28.3 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
from collections.abc import Iterable
from typing import Optional

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

from .interfaces import SupportsEagle3, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69


class OAIAttention(nn.Module):
    """Self-attention layer of a GPT-OSS decoder block.

    Projects hidden states to fused Q/K/V with ``QKVParallelLinear``, applies
    YaRN rotary embeddings to Q and K, runs vLLM ``Attention`` with learned
    per-head attention sinks, and projects the result back with a
    ``RowParallelLinear`` output layer. Layers with an even index use the
    config's sliding window; odd-indexed layers attend without one.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # Rotary embedding over the full head dim with YaRN scaling; the
        # scaling parameters are read from the HF config's rope_scaling.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type":
                "yarn",
                "factor":
                config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast":
                config.rope_scaling["beta_fast"],
                "beta_slow":
                config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One learned sink value per local attention head; the values are
        # filled in by the model's weight loader. ``requires_grad`` must be
        # given to ``nn.Parameter`` itself: on the wrapped tensor it is
        # ignored, because ``nn.Parameter`` defaults to requires_grad=True.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False)

        # Per-rank widths of the Q and K/V slices in the fused qkv output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer (even indices).
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input activations whose last dim is
                ``hidden_size`` (assumed (num_tokens, hidden_size) — TODO
                confirm).
            positions: Token positions consumed by the rotary embedding.

        Returns:
            The attention output after the ``o_proj`` projection.
        """
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # split() returns views; presumably the attention backend needs a
        # contiguous value tensor.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
132
133
134
135
136
137


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward layer of a GPT-OSS decoder block.

    A dense ``torch.nn.Linear`` router produces one logit per expert for
    each token. ``FusedMoE`` then routes every token to its top-k experts,
    which have biases and use the ``swigluoai`` activation. When sequence-
    parallel MoE is enabled, the input is chunked with
    ``sequence_parallel_chunk`` before routing. The expert outputs are then
    all-gathered along dim 0 and trimmed back to the original token count.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts)
        # NOTE(review): dist.get_world_size() is the size of the default
        # (global) process group, which can exceed the tensor-parallel size
        # when pipeline/data parallelism is also used — confirm this
        # divisibility check is the intended one.
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False,
                                has_bias=True,
                                activation="swigluoai",
                                is_sequence_parallel=self.is_sequence_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the routed experts to ``x``.

        Args:
            x: Token activations; dim 0 indexes tokens.

        Returns:
            Expert outputs with the same number of tokens as ``x``.
        """
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        router_logits = self.router(x)
        x = self.experts(hidden_states=x, router_logits=router_logits)

        if self.is_sequence_parallel:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            # The gathered result can be longer than the input (chunking
            # presumably pads), so drop any trailing rows.
            x = x[:num_tokens]
        return x
182
183
184
185
186
187


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: normed attention followed by normed MoE.

    The residual stream is threaded through the RMSNorm layers. When called
    with a residual, each norm takes and returns a ``(hidden_states,
    residual)`` pair. ``forward`` returns the un-normalized MLP output
    together with the residual, for the next layer (or the final norm) to
    consume.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        # NOTE(review): quant_config is not forwarded here, so the attention
        # projections are built with quant_config=None — confirm intended.
        self.attn = OAIAttention(config,
                                 prefix=f"{prefix}.attn",
                                 cache_config=cache_config)
        self.mlp = MLPBlock(vllm_config,
                            self.layer_idx,
                            prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MoE for one layer.

        Args:
            hidden_states: Output of the previous layer (or embeddings).
            positions: Token positions for the rotary embedding.
            residual: Running residual, or ``None`` for the first layer, in
                which case ``hidden_states`` itself becomes the residual.

        Returns:
            ``(mlp_output, residual)``.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual
225
226
227
228
229
230
231
232
233
234
235
236
237


@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder stack: token embedding, transformer layers, final norm.

    Only layers ``[start_layer, end_layer)`` are materialized on this
    pipeline-parallel rank. Ranks other than the last hand ``hidden_states``
    and ``residual`` to the next rank as ``IntermediateTensors``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))
        # Indices of layers whose incoming hidden states ``forward`` also
        # returns; empty by default and assigned from outside this class.
        self.aux_hidden_state_layers = tuple[int, ...]()
        # MOE_NN env var (default 1, parsed with int()). When it is not 1,
        # ``_load_weights_other`` permutes unquantized expert weights
        # (dims 0, 2, 1) before copying them into the parameters.
        self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder.

        Returns:
            ``IntermediateTensors`` on ranks other than the last pipeline
            rank. On the last rank, the final normalized hidden states,
            paired with the list of collected auxiliary hidden states when
            any layer in ``aux_hidden_state_layers`` ran here.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.get_input_embeddings(input_ids)

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                # Record the un-normalized stream value entering layer i.
                aux_hidden_states.append(x if residual is None else x +
                                         residual)
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": x,
                "residual": residual
            })
        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x

    @staticmethod
    def _load_dense_weight(
        name: str,
        weight: torch.Tensor,
        params_dict: dict[str, torch.nn.Parameter],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> Optional[str]:
        """Load one checkpoint tensor that is neither an MoE tensor nor sinks.

        A name containing a ``stacked_params_mapping`` shard name (e.g.
        ``.q_proj``) is renamed to the fused parameter and loaded with its
        shard id. Any other name is loaded under its own name.

        Returns:
            The parameter name that was loaded, or ``None`` when no parameter
            of that name exists in ``params_dict`` and the tensor is skipped.
        """
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            return name

        # Handle all other weights with potential renaming.
        if name not in params_dict:
            return None
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, weight)
        return name

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4-quantized checkpoint into this rank's parameters.

        MoE tensors (w13/w2 weights, their scales, and their biases) are
        sliced to this rank. Under expert parallelism the slice is along the
        expert dim; otherwise it is along the intermediate dim for tensor
        parallelism. Each slice goes to its parameter's ``weight_loader``.
        Attention sinks are narrowed to this rank's heads. Everything else
        goes through ``_load_dense_weight``.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range.
            ep_rank_start: Inclusive start of this rank's expert range.
            heads_per_rank: Number of attention heads owned by this TP rank.
            head_start: Index of this rank's first attention head.
            weights: ``(name, tensor)`` checkpoint entries.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples describing fused parameters.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Each MXFP4 scale entry covers this many elements along the
        # intermediate dim (see the w2 scale slicing below).
        mxfp4_block = 32

        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        # Round each TP shard up to a whole number of scale blocks so shard
        # boundaries never split a block.
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Intermediate-dim range [tp_rank_start, tp_rank_end) owned by this
        # rank; the last rank's range may be shorter.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        def load_expert_tensor(param_name: str,
                               tensor: torch.Tensor) -> None:
            # Hand a (possibly sliced) MoE tensor to its parameter's loader.
            param = params_dict[param_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param,
                          tensor,
                          weight_name=param_name,
                          shard_id=None,
                          expert_id=None)
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            # The *_scale checks must come before the plain *_weight checks
            # because ".w13_weight" is a substring of ".w13_weight_scale"
            # (likewise for w2).
            if ".w13_weight_scale" in name:
                # MLP gate and up projection weight scales.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight_scale" in name:
                # MLP down projection weight scales: one scale column per
                # mxfp4_block intermediate elements.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_weight" in name:
                # MLP gate and up projection weights.
                # Flatten (E, 2 * N, block_size, entry_per_block) to
                # (E, 2 * N, -1); this should not copy a contiguous input.
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # The gate/up rows are shuffled in the checkpoint, so this
                # rank's part is one contiguous range twice the TP slice.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight" in name:
                # MLP down projection weights. Same flattening; two MXFP4
                # values are packed into one uint8, hence the halved sizes.
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_bias" in name:
                # MLP gate and up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_bias" in name:
                # MLP down projection bias.
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 keeps the bias, presumably because expert
                    # outputs are summed across TP ranks (MLPBlock builds
                    # FusedMoE with reduce_results=True), so the bias would
                    # otherwise be added tp_size times. Note: zero_() mutates
                    # the checkpoint tensor in place.
                    weight.zero_()
                load_expert_tensor(name, weight)
                continue
            elif "sinks" in name:
                # Attention sinks are per head; keep this rank's heads.
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            loaded_name = self._load_dense_weight(name, weight, params_dict,
                                                  stacked_params_mapping)
            if loaded_name is not None:
                loaded_params.add(loaded_name)
        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load a non-MXFP4 checkpoint into this rank's parameters.

        MoE tensors are sliced to this rank and copied straight into the
        parameters with ``copy_``. Under expert parallelism the slice is
        along the expert dim; otherwise it is along the intermediate dim for
        tensor parallelism. Attention sinks are narrowed to this rank's
        heads. Everything else goes through ``_load_dense_weight``.

        Args:
            ep_rank_start: Inclusive start of this rank's expert range.
            ep_rank_end: Exclusive end of this rank's expert range.
            heads_per_rank: Number of attention heads owned by this TP rank.
            head_start: Index of this rank's first attention head.
            weights: ``(name, tensor)`` checkpoint entries.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples describing fused parameters.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Intermediate-dim range [tp_rank_start, tp_rank_end) owned by this
        # rank; the last rank's range may be shorter.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # MLP gate and up projection weights; the intermediate dim is
                # the last dim here.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :,
                                           2 * tp_rank_start:2 * tp_rank_end]

                if not self.use_nn_moe:
                    narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()

                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # MLP down projection weights; the intermediate dim is dim 1.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]

                if not self.use_nn_moe:
                    narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()

                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # MLP gate and up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # MLP down projection bias.
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 keeps the bias (presumably because expert
                    # outputs are summed across TP ranks). zero_() mutates
                    # the checkpoint tensor in place.
                    weight.zero_()
                param = params_dict[name]
                param.copy_(weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Attention sinks are per head; keep this rank's heads.
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            loaded_name = self._load_dense_weight(name, weight, params_dict,
                                                  stacked_params_mapping)
            if loaded_name is not None:
                loaded_params.add(loaded_name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, dispatching on the quantization method.

        Computes this rank's attention-head range (tensor parallel) and
        expert range (expert parallel). Then delegates to the MXFP4 loader
        when ``quantization_config['quant_method'] == "mxfp4"``, otherwise
        to the unquantized loader.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quant_method = (self.config.quantization_config['quant_method'] if
                        hasattr(self.config, "quantization_config") else None)

        # Pass everything by keyword: the two loaders declare the EP bounds
        # in opposite orders, and a positional call silently swaps them,
        # turning every expert slice into an empty weight[end:start].
        if quant_method == "mxfp4":
            return self._load_weights_mxfp4(
                ep_rank_end=ep_rank_end,
                ep_rank_start=ep_rank_start,
                heads_per_rank=heads_per_rank,
                head_start=head_start,
                weights=weights,
                stacked_params_mapping=stacked_params_mapping)
        return self._load_weights_other(
            ep_rank_start=ep_rank_start,
            ep_rank_end=ep_rank_end,
            heads_per_rank=heads_per_rank,
            head_start=head_start,
            weights=weights,
            stacked_params_mapping=stacked_params_mapping)


642
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
    """GPT-OSS causal language model: ``GptOssModel`` plus an LM head.

    ``hf_to_vllm_mapper`` renames Hugging Face checkpoint keys to this
    implementation's parameter names. It covers attention module names, the
    token embedding, and the MoE expert tensors (MXFP4 blocks/scales,
    unquantized weights, and biases).
    """

    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",

            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",

            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",

            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Choose the layers whose incoming hidden states ``forward`` returns.
        """
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return default auxiliary layers for EAGLE-3.

        The defaults are layer 2, the middle layer, and the third-from-last
        layer.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the decoder.

        Returns whatever ``GptOssModel.forward`` returns: hidden states,
        ``(hidden_states, aux_hidden_states)``, or ``IntermediateTensors`` on
        ranks other than the last pipeline rank.
        """
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load a Hugging Face checkpoint through ``hf_to_vllm_mapper``.

        ``lm_head.*`` tensors are skipped when the config ties word
        embeddings.

        Returns:
            Names of the parameters that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)