gpt_oss.py 26.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
14
15
16
17
18
19
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
20
21
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
22
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
23
24
25
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
27
28
29
    ParallelLMHead,
    VocabParallelEmbedding,
)
30
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
31
from vllm.model_executor.models.utils import sequence_parallel_chunk
32
33
34
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

35
from .interfaces import SupportsEagle3, SupportsPP
36
37
38
39
40
41
42
43
44
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
45
46
47
48
49
50


class OAIAttention(nn.Module):
    """Self-attention layer for GPT-OSS with a learned per-head ``sinks`` vector.

    The layer owns a fused QKV projection, rotary embeddings configured with
    ``rope_type="yarn"``, an ``Attention`` module that receives ``sinks``, and
    an output projection. Head counts and the q/kv widths are divided by the
    tensor-parallel world size.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        """Build the attention sub-modules.

        Args:
            config: Model config; head sizes, RoPE settings and the sliding
                window are read from it.
            quant_config: Forwarded to the linear layers and ``Attention``.
            cache_config: Forwarded to ``Attention``.
            prefix: Module name prefix; the layer index is extracted from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        rope_cfg = config.rope_scaling
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type": "yarn",
                "factor": rope_cfg["factor"],
                "original_max_position_embeddings": rope_cfg[
                    "original_max_position_embeddings"
                ],
                "beta_fast": rope_cfg["beta_fast"],
                "beta_slow": rope_cfg["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # `requires_grad` must be given to nn.Parameter itself: the Parameter
        # constructor defaults to requires_grad=True and ignores the flag of
        # the tensor it wraps, so passing it to torch.empty() had no effect.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False,
        )

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer (the even-indexed
        # ones); odd-indexed layers get no per-layer window.
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection.

        Args:
            hidden_states: Input activations fed to the fused QKV projection.
            positions: Token positions passed to the rotary embedding.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv(hidden_states)
        # The fused projection is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
135
136
137
138
139


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block: a linear router feeding ``FusedMoE``."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        """Create the router and the fused expert layer.

        Args:
            vllm_config: Source of the HF model config, the quantization config
                and the parallel config.
            layer_idx: Index of the enclosing transformer layer.
            prefix: Module name prefix; the experts are registered under
                ``{prefix}.experts``.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config

        self.is_sequence_parallel = (
            vllm_config.parallel_config.use_sequence_parallel_moe
        )

        self.layer_idx = layer_idx
        self.num_experts = hf_config.num_local_experts
        self.experts_per_token = hf_config.num_experts_per_tok

        # Fall back to a single-process world when torch.distributed is not set up.
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
        else:
            self.world_size = 1

        self.router = torch.nn.Linear(
            hf_config.hidden_size, hf_config.num_local_experts
        )
        assert hf_config.intermediate_size % self.world_size == 0

        self.experts = FusedMoE(
            num_experts=hf_config.num_local_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the experts.

        With sequence-parallel MoE enabled, the input is chunked first, the
        expert output is all-gathered along dim 0, and the result is trimmed
        back to the original number of rows.
        """
        token_count = x.shape[0]

        if not self.is_sequence_parallel:
            return self.experts(hidden_states=x, router_logits=self.router(x))

        local = sequence_parallel_chunk(x)
        local = self.experts(hidden_states=local, router_logits=self.router(local))
        gathered = tensor_model_parallel_all_gather(local.contiguous(), 0)
        # Keep only the first `token_count` rows of the gathered tensor.
        return gathered[:token_count]
185
186
187
188
189


class TransformerBlock(torch.nn.Module):
    """One decoder layer: RMSNorm -> attention -> RMSNorm -> MoE MLP."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the attention, MLP and normalization sub-modules.

        Args:
            vllm_config: Source of the HF model config and cache config.
            prefix: Module name prefix; the layer index is extracted from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config, prefix=f"{prefix}.attn", cache_config=cache_config
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(mlp_output, residual)``.

        Args:
            hidden_states: Layer input.
            positions: Token positions forwarded to the attention module.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer handled by this rank.

        Returns:
            A pair of the MLP output and the residual to pass to the next
            layer (the return was previously mis-annotated as a single tensor).
        """
        # Self Attention
        if residual is None:
            # No residual yet: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with two tensors, the norm returns an updated pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual
224
225
226
227
228
229
230
231
232
233
234
235


@support_torch_compile
class GptOssModel(nn.Module):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the embedding table, the decoder layers and the final norm.

        Args:
            vllm_config: Source of the HF model config and parallel config.
            prefix: Module name prefix; layers are created under
                ``{prefix}.layers``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.parallel_config = vllm_config.parallel_config

        # (A former `self.config.hidden_size = self.config.hidden_size`
        # self-assignment was a no-op and has been dropped.)
        hidden_size = self.config.hidden_size

        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hidden_size
        )
        # Layer indices whose inputs `forward` also collects; empty by default.
        self.aux_hidden_state_layers: tuple[int, ...] = ()
255

256
257
258
259
260
261
262
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of applying ``self.embedding`` to ``input_ids``."""
        token_embeddings = self.embedding(input_ids)
        return token_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this pipeline rank's slice of decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. Other ranks read
        ``hidden_states`` and ``residual`` from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` on every rank but the last. On the last
            rank, the normalized hidden states, or a
            ``(hidden_states, aux_hidden_states)`` pair when any layer index
            in ``self.aux_hidden_state_layers`` was visited.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            residual = None
            x = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )

        collected = []
        for layer_idx in range(self.start_layer, self.end_layer):
            block = self.layers[layer_idx]
            if layer_idx in self.aux_hidden_state_layers:
                # Record the layer's input (x plus residual, when there is one).
                collected.append(x if residual is None else x + residual)
            x, residual = block(x, positions, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})

        x, _ = self.norm(x, residual)
        if collected:
            return x, collected
        return x

292
    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load checkpoint tensors whose expert weights are stored as MXFP4.

        Expert tensors (``w13``/``w2`` weights, scales and biases) are sliced
        to ``[ep_rank_start, ep_rank_end)`` along dim 0 when expert parallelism
        is enabled; otherwise they are sliced to this tensor-parallel rank's
        share of the intermediate dimension, sized in whole 32-element blocks.
        Attention sinks are narrowed to this rank's heads. All other tensors
        go through ``stacked_params_mapping`` or are loaded under their own
        name if such a parameter exists.

        NOTE: ``ep_rank_end`` comes *before* ``ep_rank_start`` in this
        signature; pass both by keyword to avoid swapping them.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range.
            ep_rank_start: Inclusive start of this rank's expert range.
            heads_per_rank: Number of attention heads held by this rank.
            head_start: Index of this rank's first attention head.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples for weights fused into one parameter.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
        per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        expert_rows = slice(ep_rank_start, ep_rank_end)
        # Gate and up projections are stored together, hence the factor 2.
        gate_up_cols = slice(2 * tp_rank_start, 2 * tp_rank_end)

        def load_expert_tensor(param_name: str, tensor: torch.Tensor) -> None:
            # Shared tail of every expert branch: hand the slice to the loader.
            param = params_dict[param_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                tensor,
                weight_name=param_name,
                shard_id=None,
                expert_id=None,
            )
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            # The "*_weight_scale" checks must precede the "*_weight" checks,
            # because the latter are substrings of the former.
            expert_tensor = None
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[:, gate_up_cols, ...]
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[
                        ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                    ]
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[:, gate_up_cols, ...]
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[:, gate_up_cols]
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                elif tp_rank != 0:
                    # Only TP rank 0 loads the real bias (avoids duplication);
                    # other ranks load zeros. A fresh tensor is used so the
                    # checkpoint tensor supplied by the caller is not modified.
                    expert_tensor = torch.zeros_like(weight)
                else:
                    expert_tensor = weight
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            if expert_tensor is not None:
                load_expert_tensor(name, expert_tensor)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params
476
477

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load checkpoint tensors for the non-MXFP4 path.

        Expert tensors are sliced to ``[ep_rank_start, ep_rank_end)`` along
        dim 0 when expert parallelism is enabled; otherwise they are sliced to
        this tensor-parallel rank's share of the intermediate dimension.
        Expert weights additionally have their last two dimensions swapped
        before being copied into the parameter. Attention sinks are narrowed
        to this rank's heads. All other tensors go through
        ``stacked_params_mapping`` or are loaded under their own name if such
        a parameter exists.

        Args:
            ep_rank_start: Inclusive start of this rank's expert range.
            ep_rank_end: Exclusive end of this rank's expert range.
            heads_per_rank: Number of attention heads held by this rank.
            head_start: Index of this rank's first attention head.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples for weights fused into one parameter.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)

        expert_rows = slice(ep_rank_start, ep_rank_end)
        # Gate and up projections are stored together, hence the factor 2.
        gate_up_cols = slice(2 * tp_rank_start, 2 * tp_rank_end)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            expert_tensor = None
            if ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[:, :, gate_up_cols]
                # Swap the last two dims before copying into the parameter.
                expert_tensor = expert_tensor.permute(0, 2, 1).contiguous()
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[:, tp_rank_start:tp_rank_end, :]
                expert_tensor = expert_tensor.permute(0, 2, 1).contiguous()
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                else:
                    expert_tensor = weight[:, gate_up_cols]
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    expert_tensor = weight[expert_rows, ...]
                elif tp_rank != 0:
                    # Only TP rank 0 loads the real bias (avoids duplication);
                    # other ranks load zeros. A fresh tensor is used so the
                    # checkpoint tensor supplied by the caller is not modified.
                    expert_tensor = torch.zeros_like(weight)
                else:
                    expert_tensor = weight
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue

            if expert_tensor is not None:
                params_dict[name].copy_(expert_tensor)
                loaded_params.add(name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

583
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Compute this rank's head and expert ranges, then load the weights.

        Dispatches to ``_load_weights_mxfp4`` when the HF config's
        ``quantization_config["quant_method"]`` is ``"mxfp4"``, and to
        ``_load_weights_other`` otherwise (including when the config has no
        ``quantization_config``).

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        # Contiguous range of experts owned by this expert-parallel rank.
        ep_group = get_ep_group()
        ep_size = ep_group.world_size
        ep_rank = ep_group.rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quant_method = (
            self.config.quantization_config["quant_method"]
            if hasattr(self.config, "quantization_config")
            else None
        )
        if quant_method == "mxfp4":
            loader = self._load_weights_mxfp4
        else:
            loader = self._load_weights_other

        # The two loaders declare ep_rank_start / ep_rank_end in opposite
        # orders. Passing them positionally handed `_load_weights_other` a
        # swapped (end, start) pair, producing empty expert slices under
        # expert parallelism. Keyword arguments are correct for both.
        return loader(
            ep_rank_start=ep_rank_start,
            ep_rank_end=ep_rank_end,
            heads_per_rank=heads_per_rank,
            head_start=head_start,
            weights=weights,
            stacked_params_mapping=stacked_params_mapping,
        )
628
629


630
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
    """GPT-OSS decoder (``GptOssModel``) plus LM head and logits processor."""

    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    # Renames applied to checkpoint tensor names before they are loaded.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Create the decoder, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config; also forwarded to the
                decoder.
            prefix: Module name prefix for the ``model`` and ``lm_head``
                sub-modules.
        """
        super().__init__()
        self.vllm_config = vllm_config
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Expose the decoder's factory directly on the wrapper.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set the layer indices whose inputs the decoder also collects."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return layer 2, the middle layer and the third-from-last layer."""
        layer_count = len(self.model.layers)
        middle = layer_count // 2
        third_from_last = layer_count - 3
        return (2, middle, third_from_last)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the decoder."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder and return whatever it returns."""
        decoder_out = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return decoder_out

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the logits processor to ``hidden_states`` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``lm_head.`` tensors are skipped when the config ties the word
        embeddings. Names are remapped with ``hf_to_vllm_mapper``.
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None
        auto_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return auto_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)