gpt_oss.py 27.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (
14
    get_dp_group,
15
16
17
18
19
20
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
21
from vllm.model_executor.layers.fused_moe import FusedMoE
22
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
25
26
27
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
29
30
31
    ParallelLMHead,
    VocabParallelEmbedding,
)
32
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
33
from vllm.model_executor.models.utils import sequence_parallel_chunk
34
35
36
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

37
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
38
39
40
41
42
43
44
45
46
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
47
48
49
50
51
52


class OAIAttention(nn.Module):
    """Self-attention layer for GPT-OSS with learned per-head attention sinks.

    The layer owns a fused QKV projection, YaRN-scaled rotary embeddings, an
    ``Attention`` backend that is handed the ``sinks`` parameter, and a
    row-parallel output projection. Head counts are divided by the
    tensor-parallel world size to get the per-rank sizes.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        """Build the attention sub-modules.

        Args:
            config: Hugging Face GPT-OSS config (head counts, head_dim,
                rope settings, sliding window).
            quant_config: Optional quantization config forwarded to the
                linear layers and the attention backend.
            cache_config: Optional KV-cache config forwarded to ``Attention``.
            prefix: Dotted module path; the layer index is parsed from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # The rope type is fixed to YaRN; only the numeric knobs are read
        # from the checkpoint's rope_scaling dict.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type": "yarn",
                "factor": config.rope_scaling["factor"],
                "original_max_position_embeddings": config.rope_scaling[
                    "original_max_position_embeddings"
                ],
                "beta_fast": config.rope_scaling["beta_fast"],
                "beta_slow": config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One sink value per local attention head. ``requires_grad`` must be
        # given to nn.Parameter (whose default is True); passing it to
        # torch.empty had no effect and left the parameter trainable.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False,
        )

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer: even-indexed layers
        # get config.sliding_window, odd-indexed layers get None.
        sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, and project out.

        Args:
            hidden_states: Layer input; the last dimension is split into
                q/k/v after the fused projection.
            positions: Position ids handed to the rotary embedding.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # v is a slice of the fused projection output, so make it contiguous
        # before handing it to the attention backend.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
137
138
139
140
141


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block: a linear router plus FusedMoE."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        """Create the router and the fused expert layer.

        Args:
            vllm_config: Engine config; the HF model config, quantization
                config and parallel config are read from it.
            layer_idx: Index of the transformer layer owning this block.
            prefix: Dotted module path used to name the expert sub-module.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config

        self.is_sequence_parallel = (
            vllm_config.parallel_config.use_sequence_parallel_moe
        )

        self.layer_idx = layer_idx
        self.num_experts = hf_config.num_local_experts
        self.experts_per_token = hf_config.num_experts_per_tok

        # Fall back to a single-process world when torch.distributed has not
        # been initialised.
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
        else:
            self.world_size = 1

        self.router = torch.nn.Linear(
            hf_config.hidden_size, hf_config.num_local_experts
        )
        assert hf_config.intermediate_size % self.world_size == 0

        self.experts = FusedMoE(
            num_experts=hf_config.num_local_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        With sequence-parallel MoE enabled, the input is chunked before
        routing, and the expert output is all-gathered along dim 0 and then
        trimmed back to the original number of rows.
        """
        token_count = x.shape[0]
        seq_parallel = self.is_sequence_parallel

        local_tokens = sequence_parallel_chunk(x) if seq_parallel else x

        router_logits = self.router(local_tokens)
        expert_out = self.experts(
            hidden_states=local_tokens, router_logits=router_logits
        )

        if not seq_parallel:
            return expert_out

        gathered = tensor_model_parallel_all_gather(expert_out.contiguous(), 0)
        # Presumably the chunking pads the token dimension; drop any rows
        # beyond the original count.
        return gathered[:token_count]
187
188
189
190
191


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: RMSNorm -> attention -> RMSNorm -> MoE MLP.

    The residual stream is carried separately from ``hidden_states`` and is
    threaded through the two RMSNorm calls.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the attention, MLP and norm sub-modules.

        Args:
            vllm_config: Engine config; the HF model config and the cache
                config are read from it.
            prefix: Dotted module path; the layer index is parsed from it.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(
            config, prefix=f"{prefix}.attn", cache_config=cache_config
        )
        self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            hidden_states: Layer input.
            positions: Position ids forwarded to the attention layer.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet, in which case the input itself
                becomes the residual.

        Returns:
            ``(output, residual)``: the MLP output and the updated residual.
            The caller is responsible for combining them.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns both the normalised
            # tensor and the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual
226
227
228
229
230
231
232
233
234
235
236
237


@support_torch_compile
class GptOssModel(nn.Module):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Engine config; the HF model config and the parallel
                config are kept on the instance.
            prefix: Dotted module path used to name the layer stack.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        # make_layers also returns the [start_layer, end_layer) range that
        # forward() iterates over.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )
        # Layer indices whose inputs forward() collects as auxiliary hidden
        # states; empty until set from outside.
        self.aux_hidden_state_layers: tuple[int, ...] = ()
257

258
259
260
261
262
263
264
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        embeddings = self.embedding(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> (
        torch.Tensor
        | IntermediateTensors
        | tuple[torch.Tensor, list[torch.Tensor]]
    ):
        """Run this pipeline stage's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position ids forwarded to every layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` from the
                previous pipeline stage; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            * On a non-last pipeline rank: ``IntermediateTensors`` holding
              ``hidden_states`` and ``residual`` for the next stage.
            * On the last rank: the normalised hidden states, or the tuple
              ``(hidden_states, aux_hidden_states)`` if any auxiliary hidden
              states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.get_input_embeddings(input_ids)

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i in self.aux_hidden_state_layers:
                # Capture the layer *input* with the residual folded in.
                aux_hidden_states.append(x if residual is None else x + residual)
            x, residual = layer(x, positions, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": x, "residual": residual})

        x, _ = self.norm(x, residual)

        if aux_hidden_states:
            return x, aux_hidden_states
        return x

294
    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load a checkpoint whose MoE expert tensors are stored as MXFP4.

        Expert tensors are sliced per rank before being handed to the
        parameter's ``weight_loader``: along the expert dimension when expert
        parallelism is enabled, otherwise along the intermediate dimension
        using the flattened TP rank.

        Note that ``ep_rank_end`` precedes ``ep_rank_start`` in this
        signature.

        Args:
            ep_rank_end: Exclusive end of this rank's expert range.
            ep_rank_start: Inclusive start of this rank's expert range.
            heads_per_rank: Number of attention heads owned by this rank.
            head_start: First attention head owned by this rank.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples for projections fused into one parameter.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts
        intermediate_size = self.config.intermediate_size

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
        )

        # Per-rank share of the intermediate dimension, rounded up to a
        # whole number of MXFP4 blocks.
        total_blocks = intermediate_size // mxfp4_block
        per_rank_intermediate_size = cdiv(total_blocks, tp_size) * mxfp4_block

        # Slicing bounds for the current rank; the last rank is clamped.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min(
            tp_rank_start + per_rank_intermediate_size, intermediate_size
        )

        # TP index expressions, one per expert tensor layout.
        gate_up_index = (
            slice(None),
            slice(2 * tp_rank_start, 2 * tp_rank_end),
            Ellipsis,
        )
        gate_up_bias_index = (
            slice(None),
            slice(2 * tp_rank_start, 2 * tp_rank_end),
        )
        down_scale_index = (
            Ellipsis,
            slice(tp_rank_start // mxfp4_block, tp_rank_end // mxfp4_block),
        )
        down_weight_index = (
            Ellipsis,
            slice(tp_rank_start // 2, tp_rank_end // 2),
        )

        def shard(tensor: torch.Tensor, tp_index: tuple) -> torch.Tensor:
            # EP keeps whole experts; otherwise take this rank's TP slice.
            if use_ep:
                return tensor[ep_rank_start:ep_rank_end, ...]
            return tensor[tp_index]

        def load_expert_param(param_name: str, tensor: torch.Tensor) -> None:
            # Hand an already-sliced expert tensor to the parameter's loader.
            target = params_dict[param_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(
                target,
                tensor,
                weight_name=param_name,
                shard_id=None,
                expert_id=None,
            )
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            # The "_scale" names must be tested before the plain weight
            # names, which are substrings of them.
            if ".w13_weight_scale" in name:
                # MLP gate and up projection weight scales.
                load_expert_param(name, shard(weight, gate_up_index))
                continue

            if ".w2_weight_scale" in name:
                # MLP down projection weight scales.
                load_expert_param(name, shard(weight, down_scale_index))
                continue

            if ".w13_weight" in name:
                # Flatten (E, 2 * N, block_size, entry_per_block) to
                # (E, 2 * N, -1); shouldn't trigger a copy for contiguous
                # input. Since the weight is shuffled, we can slice directly.
                flat = weight.view(
                    num_experts, 2 * intermediate_size, -1
                ).contiguous()
                load_expert_param(name, shard(flat, gate_up_index))
                continue

            if ".w2_weight" in name:
                # Same flatten for the down projection, but since 2 mx4
                # values are packed in 1 uint8, divide by 2.
                flat = weight.view(
                    num_experts, -1, intermediate_size // 2
                ).contiguous()
                load_expert_param(name, shard(flat, down_weight_index))
                continue

            if ".w13_bias" in name:
                # MLP gate and up projection biases.
                load_expert_param(name, shard(weight, gate_up_bias_index))
                continue

            if ".w2_bias" in name:
                # MLP down projection bias.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 keeps the bias, to avoid duplication;
                    # the other ranks load zeros.
                    weight.zero_()
                weight_loader(
                    param, weight, weight_name=name, shard_id=None, expert_id=None
                )
                loaded_params.add(name)
                continue

            if "sinks" in name:
                # Attention sinks are distributed across ranks by head.
                param = params_dict[name]
                param.data.copy_(weight.narrow(0, head_start, heads_per_rank))
                loaded_params.add(name)
                continue

            # Everything else: either a shard of a stacked parameter or a
            # plain parameter loaded under its own name.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is not None:
                param_name, shard_name, shard_id = stacked
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
            elif name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            else:
                continue
            loaded_params.add(name)

        return loaded_params
483
484

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load a checkpoint whose MoE expert tensors are not MXFP4.

        Expert tensors are sliced per rank and copied straight into the
        parameter: along the expert dimension when expert parallelism is
        enabled, otherwise along the intermediate dimension using the
        flattened TP rank. Expert weight matrices are additionally permuted
        with ``(0, 2, 1)`` before the copy.

        Args:
            ep_rank_start: Inclusive start of this rank's expert range.
            ep_rank_end: Exclusive end of this rank's expert range.
            heads_per_rank: Number of attention heads owned by this rank.
            head_start: First attention head owned by this rank.
            weights: ``(name, tensor)`` pairs from the checkpoint.
            stacked_params_mapping: ``(param_name, shard_name, shard_id)``
                triples for projections fused into one parameter.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        # In MoE, we need to flatten the tensor parallel size across the data
        # parallel size when EP is disabled.
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
            tp_size=get_tensor_model_parallel_world_size(),
            dp_size=get_dp_group().world_size,
            dp_rank=get_dp_group().rank_in_group,
        )

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)

        # Slicing bounds for the current rank; the last rank is clamped.
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min(
            tp_rank_start + per_rank_intermediate_size, intermediate_size
        )
        gate_up_span = slice(2 * tp_rank_start, 2 * tp_rank_end)
        down_span = slice(tp_rank_start, tp_rank_end)

        def expert_slice(tensor: torch.Tensor) -> torch.Tensor:
            # Keep only the experts owned by this EP rank.
            return tensor[ep_rank_start:ep_rank_end, ...]

        def copy_into(param_name: str, tensor: torch.Tensor) -> None:
            # Copy an already-sliced tensor into the named parameter.
            params_dict[param_name].copy_(tensor)
            loaded_params.add(param_name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # MLP gate and up projection weights.
                if use_ep:
                    piece = expert_slice(weight)
                else:
                    piece = weight[:, :, gate_up_span]
                copy_into(name, piece.permute(0, 2, 1).contiguous())
                continue

            if ".w2_weight" in name:
                # MLP down projection weights.
                if use_ep:
                    piece = expert_slice(weight)
                else:
                    piece = weight[:, down_span, :]
                copy_into(name, piece.permute(0, 2, 1).contiguous())
                continue

            if ".w13_bias" in name:
                # MLP gate and up projection biases.
                if use_ep:
                    piece = expert_slice(weight)
                else:
                    piece = weight[:, gate_up_span]
                copy_into(name, piece)
                continue

            if ".w2_bias" in name:
                # MLP down projection bias.
                if use_ep:
                    weight = expert_slice(weight)
                elif tp_rank != 0:
                    # Only rank 0 keeps the bias, to avoid duplication;
                    # the other ranks load zeros.
                    weight.zero_()
                copy_into(name, weight)
                continue

            if "sinks" in name:
                # Attention sinks are distributed across ranks by head.
                param = params_dict[name]
                param.data.copy_(weight.narrow(0, head_start, heads_per_rank))
                loaded_params.add(name)
                continue

            # Everything else: either a shard of a stacked parameter or a
            # plain parameter loaded under its own name.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is not None:
                param_name, shard_name, shard_id = stacked
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
            elif name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
            else:
                continue
            loaded_params.add(name)

        return loaded_params

595
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Compute this rank's head/expert ranges and dispatch to a loader.

        The MXFP4 loader is used when the HF config carries a
        ``quantization_config`` whose ``quant_method`` is ``"mxfp4"``;
        otherwise the generic loader is used.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        # Experts per rank.
        # NOTE(review): this reads `.rank` while the DP group is queried via
        # `.rank_in_group` in the loaders; confirm `.rank` is the index
        # within the EP group when other parallel dimensions are in use.
        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quant_method = (
            self.config.quantization_config["quant_method"]
            if hasattr(self.config, "quantization_config")
            else None
        )

        # The two loaders declare ep_rank_start / ep_rank_end in opposite
        # positional order. Passing them positionally swapped the bounds for
        # the generic loader and produced an empty expert slice under EP, so
        # every argument is passed by keyword.
        loader = (
            self._load_weights_mxfp4
            if quant_method == "mxfp4"
            else self._load_weights_other
        )
        return loader(
            ep_rank_start=ep_rank_start,
            ep_rank_end=ep_rank_end,
            heads_per_rank=heads_per_rank,
            head_start=head_start,
            weights=weights,
            stacked_params_mapping=stacked_params_mapping,
        )
640
641


642
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
    """GPT-OSS decoder with an LM head and a logits processor."""

    # Checkpoint q/k/v projections are fused into the single "qkv" module.
    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    # Renames applied to checkpoint tensor names before loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",
            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",
            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Engine config; the HF model config is read from it.
            prefix: Dotted module path prepended to sub-module names.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.config = hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the backbone's factory for intermediate tensors.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Choose which layer indices the backbone captures inputs for."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default capture layers: 2, the midpoint, and last - 2."""
        layer_count = len(self.model.layers)
        middle = layer_count // 2
        return (2, middle, layer_count - 3)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        embeddings = self.model.get_input_embeddings(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the backbone and return whatever it returns."""
        backbone_out = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return backbone_out

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Turn hidden states into logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping built by ``FusedMoE``.

        Entries are (param_name, weight_name, expert_id, shard_id) and cover
        weights, weight scales and activation scales.
        """
        expert_count = self.config.num_local_experts
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=expert_count,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``lm_head.`` tensors are skipped when the config ties the word
        embeddings.
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None
        auto_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return auto_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)