# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
14
15
from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_rank,
16
17
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
18
19
20
21
22
23
24
25
26
27
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
from vllm.model_executor.models.utils import sequence_parallel_chunk
29
30
31
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

32
from .interfaces import SupportsEagle3, SupportsPP
33
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
34
35
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
36
                    maybe_prefix)
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68


class OAIAttention(nn.Module):
    """Self-attention layer with YaRN rotary embeddings and attention sinks.

    The query/key/value projections are fused into ``self.qkv`` and the
    per-rank head counts are the config's head counts divided by the
    tensor-parallel world size.  Layers with an even index use the
    sliding window from the config; odd layers attend without one.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Hugging Face GPT-OSS configuration.
            quant_config: Optional quantization settings forwarded to the
                linear layers and the attention backend.
            cache_config: Optional KV-cache settings forwarded to the
                attention backend.
            prefix: Dotted module path of this layer; the layer index is
                extracted from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # The rope type is forced to "yarn"; the remaining scaling
        # parameters are read from the model config.
        rope_scaling = {
            "rope_type":
            "yarn",
            "factor":
            config.rope_scaling["factor"],
            "original_max_position_embeddings":
            config.rope_scaling["original_max_position_embeddings"],
            "beta_fast":
            config.rope_scaling["beta_fast"],
            "beta_slow":
            config.rope_scaling["beta_slow"],
        }
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One sink value per local attention head.  ``requires_grad`` must
        # be given to ``nn.Parameter`` itself: passing it to ``torch.empty``
        # has no effect because ``nn.Parameter`` defaults to
        # ``requires_grad=True``.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False,
        )

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Apply fused QKV projection, rotary embedding, attention, o_proj.

        Args:
            hidden_states: Input activations; the last dimension is
                projected by ``self.qkv``.
            positions: Position ids passed to the rotary embedding.

        Returns:
            The output of ``self.o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # ``split`` returns views of the fused tensor; make v contiguous
        # before handing it to the attention backend.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
131
132
133
134
135
136


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block.

    A linear router produces per-expert logits that are handed, together
    with the activations, to a ``FusedMoE`` layer using the "swigluoai"
    activation with biases enabled.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        """Create the router and the fused expert layer.

        Args:
            vllm_config: Engine configuration; the HF model config, the
                quantization config and the parallel config are read
                from it.
            layer_idx: Index of the transformer layer owning this block.
            prefix: Dotted module path used to name the expert layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        # Fall back to a single process when torch.distributed is not set up.
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False,
                                has_bias=True,
                                activation="swigluoai",
                                is_sequence_parallel=self.is_sequence_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        Args:
            x: Token activations; dimension 0 is treated as the token
                dimension.

        Returns:
            The expert output with the same number of rows as the input.
        """
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        g = self.router(x)
        x = self.experts(hidden_states=x, router_logits=g)

        if self.is_sequence_parallel:
            # Reassemble the per-rank chunks along the token dimension and
            # drop any rows beyond the original token count (presumably
            # padding added by the chunking helper -- TODO confirm).
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
        return x
181
182
183
184
185
186


class TransformerBlock(torch.nn.Module):
    """One decoder layer: RMSNorm -> attention -> RMSNorm -> MoE MLP.

    The residual stream is threaded through the layer as a separate
    tensor and is combined inside the RMSNorm calls.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the attention, MLP and normalization sub-modules.

        Args:
            vllm_config: Engine configuration; the HF model config and the
                cache config are read from it.
            prefix: Dotted module path of this layer; the layer index is
                extracted from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(config,
                                 prefix=f"{prefix}.attn",
                                 cache_config=cache_config)
        self.mlp = MLPBlock(vllm_config,
                            self.layer_idx,
                            prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return the new activations and residual.

        Args:
            hidden_states: Input activations of this layer.
            positions: Position ids forwarded to the attention module.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer that is executed.

        Returns:
            A ``(output, residual)`` pair: the MLP output and the residual
            produced by the post-attention normalization.
        """
        # Self Attention
        if residual is None:
            # First layer: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual
224
225
226
227
228
229
230
231
232
233
234
235
236


@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder stack: embedding, transformer layers, final norm.

    Also implements checkpoint loading for both MXFP4-quantized and
    unquantized expert weights, slicing tensors for the current
    tensor-parallel / expert-parallel rank.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the embedding, the (pipeline-local) layers and the norm.

        Args:
            vllm_config: Engine configuration; the HF model config and the
                parallel config are read from it.
            prefix: Dotted module path of this model.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))
        # Indices of layers whose input is collected as an auxiliary
        # hidden state in ``forward``; empty by default.
        self.aux_hidden_state_layers = tuple[int, ...]()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layers owned by this pipeline rank.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: Hidden states and residual received from
                the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            On a non-last pipeline rank, an ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual``.  On the last rank, the
            normalized hidden states, or a ``(hidden_states,
            aux_hidden_states)`` pair when auxiliary states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.get_input_embeddings(input_ids)

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                # Record the layer input with the residual folded in.
                aux_hidden_states.append(x if residual is None else x +
                                         residual)
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": x,
                "residual": residual
            })
        x, _ = self.norm(x, residual)

        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4-quantized checkpoint into this rank's parameters.

        Note that ``ep_rank_end`` precedes ``ep_rank_start`` in this
        signature.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range.
            ep_rank_start: Inclusive start of this rank's expert range.
            heads_per_rank: Number of attention heads held by this rank.
            head_start: Index of the first attention head of this rank.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples describing weights fused into one parameter.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Block size used to round the per-rank intermediate size.
        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        def load_expert_tensor(target: str, tensor: torch.Tensor) -> None:
            # Shared tail of every expert branch: hand the already-sliced
            # tensor to the parameter's loader and record it as loaded.
            param = params_dict[target]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param,
                          tensor,
                          weight_name=target,
                          shard_id=None,
                          expert_id=None)
            loaded_params.add(target)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            # The "_scale" names must be tested before the plain weight
            # names because the latter are substrings of the former.
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]
                load_expert_tensor(name, narrow_weight)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
                load_expert_tensor(name, weight)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load a non-MXFP4 checkpoint into this rank's parameters.

        Expert weights are sliced for the current rank, permuted with
        ``permute(0, 2, 1)`` and copied directly into the parameters.

        Args:
            ep_rank_start: Inclusive start of this rank's expert range.
            ep_rank_end: Exclusive end of this rank's expert range.
            heads_per_rank: Number of attention heads held by this rank.
            head_start: Index of the first attention head of this rank.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples describing weights fused into one parameter.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # Extract gate and up projection parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :,
                                           2 * tp_rank_start:2 * tp_rank_end]

                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]

                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]

                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
                param = params_dict[name]
                param.copy_(weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, dispatching on the quantization method.

        Computes this rank's attention-head and expert ranges and forwards
        them to the MXFP4 loader when the config declares
        ``quant_method == "mxfp4"``, otherwise to the generic loader.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quant_method = (self.config.quantization_config['quant_method'] if
                        hasattr(self.config, "quantization_config") else None)
        if quant_method == "mxfp4":
            load_fn = self._load_weights_mxfp4
        else:
            load_fn = self._load_weights_other
        # The two loaders declare ep_rank_start / ep_rank_end in opposite
        # positional order, so pass everything by keyword to keep the
        # expert range from being swapped.
        return load_fn(
            ep_rank_start=ep_rank_start,
            ep_rank_end=ep_rank_end,
            heads_per_rank=heads_per_rank,
            head_start=head_start,
            weights=weights,
            stacked_params_mapping=stacked_params_mapping,
        )


635
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
    """GPT-OSS causal language model: decoder stack plus LM head."""

    # Checkpoint q/k/v projections are packed into the fused "qkv" module.
    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    # Renames applied to checkpoint weight names before loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",

            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",

            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",

            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config is read
                from it.
            prefix: Dotted module path of this model.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layers whose inputs are returned as auxiliary states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default auxiliary layer indices.

        The indices are ``2``, the middle layer and the third-from-last
        layer, computed from the number of layers held by ``self.model``.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the decoder stack and return its result unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights after applying the name mapper.

        Weights under ``lm_head.`` are skipped when the config ties the
        word embeddings.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of names reported as loaded by ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)