gpt_oss.py 27 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
14
15
from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_rank,
16
17
18
19
20
21
22
23
24
25
26
27
28
29
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

30
from .interfaces import SupportsPP
31
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
32
33
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
34
                    maybe_prefix)
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66


class OAIAttention(nn.Module):
    """GPT-OSS self-attention with YaRN rotary embedding and attention sinks.

    Q/K/V come from a single tensor-parallel ``QKVParallelLinear``; Q and K
    are rotated with a YaRN-scaled RoPE and passed, together with a per-head
    ``sinks`` parameter, to vLLM's ``Attention``. Even-indexed layers use the
    config's sliding window; odd-indexed layers get no sliding window.

    Args:
        config: model config (head sizes, RoPE and sliding-window settings).
        quant_config: optional quantization config for the projections and
            the attention layer.
        cache_config: optional KV-cache config for the attention layer.
        prefix: parameter-name prefix; the layer index is parsed from it.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type":
                "yarn",
                "factor":
                config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast":
                config.rope_scaling["beta_fast"],
                "beta_slow":
                config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # One sink value per attention head local to this TP rank; filled
        # from the checkpoint at load time. nn.Parameter defaults to
        # requires_grad=True regardless of the wrapped tensor's flag, so the
        # flag must be given to the Parameter itself to take effect.
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size),
            requires_grad=False)

        # Per-rank widths of the Q and K/V slices of the fused qkv output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Attend over ``hidden_states`` at ``positions``; return ``o_proj`` output."""
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # v is a strided view produced by split(); made contiguous,
        # presumably because the attention backend expects it.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145


class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block: a linear router feeding FusedMoE.

    Args:
        config: model config supplying expert counts and layer sizes.
        layer_idx: index of the decoder layer that owns this block.
        quant_config: quantization config forwarded to ``FusedMoE``.
        prefix: parameter-name prefix for this block.
    """

    def __init__(
        self,
        config: GptOssConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        n_experts = config.num_local_experts
        top_k = config.num_experts_per_tok

        self.layer_idx = layer_idx
        self.num_experts = n_experts
        self.experts_per_token = top_k
        # Global process count, or 1 when torch.distributed is not set up.
        # NOTE(review): this is the full world size (not the TP size), so
        # the divisibility assert below may be stricter than needed.
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
        else:
            self.world_size = 1
        # Produces one logit per expert for each token.
        self.router = torch.nn.Linear(hidden, n_experts)

        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(
            num_experts=n_experts,
            top_k=top_k,
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            apply_router_weight_on_input=False,
            has_bias=True,
            activation="swigluoai",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route every token in ``x`` and return the fused-expert output."""
        router_logits = self.router(x)
        return self.experts(hidden_states=x, router_logits=router_logits)
164
165
166
167
168
169
170


class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: normed attention, then a normed MoE MLP.

    Args:
        config: model config.
        cache_config: KV-cache config forwarded to the attention layer.
        quant_config: quantization config forwarded to the MoE MLP.
        prefix: parameter-name prefix; the layer index is parsed from it.
    """

    def __init__(
        self,
        config: GptOssConfig,
        cache_config: CacheConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        # NOTE(review): quant_config is not forwarded to OAIAttention, so
        # the attention layer is built with its default quant_config=None.
        # Confirm this is intended for every quantization scheme in use.
        self.attn = OAIAttention(config,
                                 prefix=f"{prefix}.attn",
                                 cache_config=cache_config)
        self.mlp = MLPBlock(config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the MoE MLP.

        Args:
            hidden_states: output of the previous layer (or the embeddings).
            positions: token positions, used by the attention's rotary
                embedding.
            residual: running residual from the previous layer, or ``None``
                for the very first layer (first pipeline rank).

        Returns:
            ``(mlp_output, residual)``. The MLP output is returned separately
            from the residual; the next layer's norm call receives both.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument, RMSNorm returns (normed, residual);
            # presumably it folds hidden_states into residual first.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual
206
207
208
209
210
211
212
213
214
215
216
217
218


@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism only layers ``[start_layer, end_layer)`` run
    on this rank, and activations cross ranks as ``IntermediateTensors``
    holding ``hidden_states`` and ``residual``. ``load_weights`` dispatches
    to an MXFP4 loader or an unquantized loader based on the HF config's
    ``quantization_config``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                self.config,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank starts from ``inputs_embeds`` when given,
        otherwise from the embedding of ``input_ids``. Later ranks resume
        from ``intermediate_tensors``. Non-final ranks return
        ``IntermediateTensors``; the final rank returns the normalized
        hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.get_input_embeddings(input_ids)

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": x,
                "residual": residual
            })
        x, _ = self.norm(x, residual)
        return x

    def _load_non_expert_weight(
        self,
        name: str,
        weight: torch.Tensor,
        params_dict: dict[str, nn.Parameter],
        heads_per_rank: int,
        head_start: int,
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> Optional[str]:
        """Load one checkpoint tensor that is not a FusedMoE expert tensor.

        Shared by both checkpoint loaders. Handles, in order:
        1. attention sinks, sliced to this rank's heads;
        2. q/k/v projections, routed into the fused ``qkv`` parameter via
           ``stacked_params_mapping``;
        3. any other tensor whose name matches a parameter of this model.

        Returns:
            The parameter name that was loaded, or ``None`` if ``name`` has
            no matching parameter and was skipped.
        """
        if "sinks" in name:
            # Handle attention sinks (distributed across ranks): keep the
            # contiguous run of heads_per_rank heads owned by this rank.
            param = params_dict[name]
            param.data.copy_(weight.narrow(0, head_start, heads_per_rank))
            return name

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            return name

        # Handle all other weights with potential renaming
        if name not in params_dict:
            return None
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, weight)
        return name

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4-quantized checkpoint.

        Expert tensors (packed weights, block scales, biases) are sliced for
        this rank and passed to the parameter's ``weight_loader``:
        - with expert parallelism enabled, by expert range
          ``[ep_rank_start, ep_rank_end)``;
        - otherwise, along the intermediate dimension for tensor parallelism.

        All other tensors go through ``_load_non_expert_weight``.

        NOTE: this method takes ``(ep_rank_end, ep_rank_start)``, the reverse
        of ``_load_weights_other``; always pass them by keyword.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Split the intermediate dimension in whole 32-element MXFP4 blocks.
        # This keeps each rank's slice of the block scales aligned with its
        # slice of the packed weights.
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            # ".w13_weight_scale" / ".w2_weight_scale" must be tested before
            # ".w13_weight" / ".w2_weight", which are substrings of them.
            if ".w13_weight_scale" in name:
                # Gate/up projection block scales.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
            elif ".w2_weight_scale" in name:
                # Down projection block scales: one scale per mxfp4_block
                # elements along the last dimension.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]
            elif ".w13_weight" in name:
                # Gate/up projection weights: flatten from
                # (E, 2 * N, block_size, entry_per_block) to (E, 2 * N, -1).
                # This shouldn't trigger a copy for contiguous input.
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()
                # Gate and up rows are already shuffled in the checkpoint,
                # so one contiguous slice covers both for this rank.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
            elif ".w2_weight" in name:
                # Down projection weights: same flatten, but two MXFP4
                # values are packed in one uint8, hence the divide by 2.
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]
            elif ".w13_bias" in name:
                # Gate/up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]
            elif ".w2_bias" in name:
                # Down projection bias.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank == 0:
                    narrow_weight = weight
                else:
                    # Only TP rank 0 keeps the bias, to avoid duplicating it.
                    # Presumably this is because MLPBlock all-reduces expert
                    # outputs (reduce_results=True). zeros_like avoids
                    # mutating the caller's tensor in place.
                    narrow_weight = torch.zeros_like(weight)
            else:
                loaded_name = self._load_non_expert_weight(
                    name, weight, params_dict, heads_per_rank, head_start,
                    stacked_params_mapping)
                if loaded_name is not None:
                    loaded_params.add(loaded_name)
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param,
                          narrow_weight,
                          weight_name=name,
                          shard_id=None,
                          expert_id=None)
            loaded_params.add(name)

        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an unquantized (non-MXFP4) checkpoint.

        Expert weights and biases are sliced for this rank and copied
        straight into the FusedMoE parameters:
        - with expert parallelism enabled, by expert range
          ``[ep_rank_start, ep_rank_end)``;
        - otherwise, along the intermediate dimension.

        Expert weight matrices have their last two dims swapped on the way
        in. All other tensors go through ``_load_non_expert_weight``.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Gate/up projection weights. The 2 * intermediate dim is
                # last in the checkpoint and is moved to dim 1 below.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :,
                                           2 * tp_rank_start:2 * tp_rank_end]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            elif ".w2_weight" in name:
                # Down projection weights (intermediate dim sliced on dim 1).
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            elif ".w13_bias" in name:
                # Gate/up projection biases.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]
            elif ".w2_bias" in name:
                # Down projection bias.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank == 0:
                    narrow_weight = weight
                else:
                    # Only TP rank 0 keeps the bias, to avoid duplicating it.
                    # zeros_like avoids mutating the caller's tensor in place.
                    narrow_weight = torch.zeros_like(weight)
            else:
                loaded_name = self._load_non_expert_weight(
                    name, weight, params_dict, heads_per_rank, head_start,
                    stacked_params_mapping)
                if loaded_name is not None:
                    loaded_params.add(loaded_name)
                continue

            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)

        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` into this rank's shard of the model.

        Computes this rank's attention-head and expert ranges, then
        dispatches to ``_load_weights_mxfp4`` when the HF config declares
        ``quant_method == "mxfp4"``, or to ``_load_weights_other`` otherwise.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        hf_quant_config = getattr(self.config, "quantization_config", None)
        quant_method = (hf_quant_config["quant_method"]
                        if hf_quant_config is not None else None)

        if quant_method == "mxfp4":
            loader = self._load_weights_mxfp4
        else:
            loader = self._load_weights_other
        # The two loaders declare ep_rank_start/ep_rank_end in opposite
        # orders. Pass everything by keyword so neither receives them
        # swapped; swapped bounds give empty expert slices under EP.
        return loader(ep_rank_start=ep_rank_start,
                      ep_rank_end=ep_rank_end,
                      heads_per_rank=heads_per_rank,
                      head_start=head_start,
                      weights=weights,
                      stacked_params_mapping=stacked_params_mapping)


613
class GptOssForCausalLM(nn.Module, SupportsPP):
    """GPT-OSS causal language model: ``GptOssModel`` plus a vocab-parallel head.

    ``hf_to_vllm_mapper`` renames Hugging Face checkpoint keys to this
    module's parameter names. It is applied through ``AutoWeightsLoader``
    in ``load_weights``.
    """

    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",

            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",

            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",

            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # load_weights skips "lm_head." checkpoint entries when the
            # embeddings are tied. Share the embedding matrix so the head
            # is not left uninitialized.
            self.lm_head.weight = self.model.embedding.weight
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the backbone; see ``GptOssModel.forward``."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Rename checkpoint keys via ``hf_to_vllm_mapper`` and load them.

        ``lm_head.`` entries are skipped when embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)