# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only GraniteMoeHybrid model."""

# Added by the IBM Team, 2025

from collections.abc import Iterable
from itertools import islice
from typing import Optional

import torch
from torch import nn
from transformers import GraniteMoeHybridConfig

from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class GraniteMoeHybridMambaDecoderLayer(nn.Module):
    """Decoder layer that mixes tokens with a ``MambaMixer2`` block.

    The layer applies, in order: input RMSNorm, the Mamba mixer, a scaled
    residual add, post-mixer RMSNorm, the optional MoE and/or shared MLP, and
    a second scaled residual add.
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the mixer, the optional feed-forward parts and the norms.

        Args:
            config: HF model config supplying sizes and multipliers.
            layer_idx: Index of this layer; not used by this class.
            model_config: Forwarded to ``MambaMixer2``.
            cache_config: Forwarded to ``MambaMixer2``.
            quant_config: Forwarded to the mixer, MoE and shared MLP.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.residual_multiplier = config.residual_multiplier

        self.mamba = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * config.hidden_size,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            n_groups=config.mamba_n_groups,
            num_heads=config.mamba_n_heads,
            head_dim=config.mamba_d_head,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        # The routed MoE only exists when the config declares experts.
        self.block_sparse_moe = None
        if getattr(config, "num_local_experts", 0) > 0:
            self.block_sparse_moe = GraniteMoeMoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )

        # The shared MLP only exists for a non-zero shared intermediate size.
        self.shared_mlp = None
        if getattr(config, "shared_intermediate_size", 0) != 0:
            self.shared_mlp = GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            hidden_states: Input activations.
            residual: Ignored; the incoming value is overwritten immediately.
            **kwargs: Accepted and ignored so this layer can be called with
                the same keyword arguments as the attention layer.

        Returns:
            ``(hidden_states, residual)`` where ``residual`` is the activation
            that entered the feed-forward sub-block.
        """
        # Mixer sub-block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # The mixer writes its result into a caller-provided buffer.
        output = torch.empty_like(hidden_states)
        self.mamba(hidden_states, output)
        hidden_states = residual + output * self.residual_multiplier

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            # With neither MoE nor shared MLP the normed states pass through.
            if self.block_sparse_moe is not None:
                hidden_states = self.block_sparse_moe(hidden_states)
        elif self.block_sparse_moe is None:
            hidden_states = self.shared_mlp(hidden_states)
        else:
            # Feed the MoE a copy since block_sparse_moe modifies its input
            # in-place and the shared MLP still needs the original values.
            moe_hidden_states = self.block_sparse_moe(hidden_states.clone())
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states, residual


class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
    """Decoder layer that mixes tokens with ``GraniteMoeHybridAttention``.

    The layer applies, in order: input RMSNorm, self-attention, a scaled
    residual add, post-attention RMSNorm, the optional MoE and/or shared MLP,
    and a second scaled residual add.
    """

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the optional feed-forward parts and the norms.

        Args:
            config: HF model config supplying sizes and multipliers.
            layer_idx: Index of this layer; not used by this class.
            model_config: Accepted for signature parity with the Mamba layer;
                not used by this class.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to attention, MoE and shared MLP.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.residual_multiplier = config.residual_multiplier

        self.self_attn = GraniteMoeHybridAttention(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # The routed MoE only exists when the config declares experts.
        self.block_sparse_moe = None
        if getattr(config, "num_local_experts", 0) > 0:
            self.block_sparse_moe = GraniteMoeMoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )

        # The shared MLP only exists for a non-zero shared intermediate size.
        self.shared_mlp = None
        if getattr(config, "shared_intermediate_size", 0) != 0:
            self.shared_mlp = GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Token positions handed to the attention module.
            hidden_states: Input activations.
            residual: Ignored; the incoming value is overwritten immediately.

        Returns:
            ``(hidden_states, residual)`` where ``residual`` is the activation
            that entered the feed-forward sub-block.
        """
        # Attention sub-block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            # With neither MoE nor shared MLP the normed states pass through.
            if self.block_sparse_moe is not None:
                hidden_states = self.block_sparse_moe(hidden_states)
        elif self.block_sparse_moe is None:
            hidden_states = self.shared_mlp(hidden_states)
        else:
            # Feed the MoE a copy since block_sparse_moe modifies its input
            # in-place and the shared MLP still needs the original values.
            moe_hidden_states = self.block_sparse_moe(hidden_states.clone())
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states, residual


class GraniteMoeHybridAttention(nn.Module):
    """Self-attention block with a fused QKV projection and optional RoPE."""

    def __init__(
        self,
        config: GraniteMoeHybridConfig,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF model config supplying head counts, bias flag,
                attention multiplier and position-embedding settings.
            model_config: Not used by this class.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and ``Attention``.
            prefix: Parameter-name prefix of this module.

        Raises:
            AssertionError: If the head counts are incompatible with the
                tensor-parallel world size.
        """
        super().__init__()
        self.causal = True
        self.hidden_size = config.hidden_size
        self.attention_bias = config.attention_bias
        self.attention_multiplier = config.attention_multiplier
        self.total_num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads

        # TensorParallel logic
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible by "
            f"the tensor-parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_key_value_heads ({self.total_num_kv_heads}) must be "
                f"divisible by the tensor-parallel size ({tp_size})"
            )
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"the tensor-parallel size ({tp_size}) must be divisible by "
                f"num_key_value_heads ({self.total_num_kv_heads})"
            )
        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=self.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Rotary embeddings are only built when the config asks for "rope";
        # otherwise forward() leaves query/key untouched.
        if config.position_embedding_type == "rope":
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=config.max_position_embeddings,
                base=int(config.rope_theta),
                # A missing attribute and an explicit None both mean no scaling.
                rope_scaling=getattr(config, "rope_scaling", None),
                is_neox_style=True,
            )
        else:
            self.rotary_emb = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.attention_multiplier,
            num_kv_heads=self.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, optionally apply RoPE, attend, project back.

        Args:
            positions: Token positions, used only when rotary embeddings exist.
            hidden_states: Input activations.

        Returns:
            The output-projected attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection using the per-rank head counts.
        q_size = self.num_heads * self.head_dim
        kv_size = self.num_key_value_heads * self.head_dim
        query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)

        if self.rotary_emb is not None:
            query, key = self.rotary_emb(positions, query, key)

        hidden_states = self.attn(query, key, value)
        del query, key, value

        # o_proj returns a tuple; only its first element is used.
        hidden_states = self.o_proj(hidden_states)[0]
        return hidden_states


# Maps each entry of `config.layer_types` to the decoder layer class that
# GraniteMoeHybridModel instantiates for that layer index.
ALL_DECODER_LAYER_TYPES = {
    "attention": GraniteMoeHybridAttentionDecoderLayer,
    "mamba": GraniteMoeHybridMambaDecoderLayer,
}


@support_torch_compile
class GraniteMoeHybridModel(nn.Module):
    """Embedding, decoder-layer stack and final norm of GraniteMoeHybrid."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, the per-layer decoder stack and the final norm.

        Args:
            vllm_config: Source of the HF config plus the model, cache,
                quantization and LoRA configs.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Extra embedding rows reserved for LoRA-added tokens, if LoRA is on.
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        def get_layer(prefix: str):
            # The layer index is the last dotted component of the prefix.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[config.layer_types[layer_idx]]
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (the embedding multiplier is not applied)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Token positions passed to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank but the first.
            inputs_embeds: Pre-computed embeddings, used as-is (the embedding
                multiplier is only applied to embeddings looked up here).

        Returns:
            The normed hidden states on the last pipeline rank; on earlier
            ranks an ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` for the next rank.

        Raises:
            RuntimeError: If a non-first rank receives no intermediate tensors.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
                hidden_states = hidden_states * self.embedding_multiplier
            residual = None
        else:
            if intermediate_tensors is None:
                raise RuntimeError("Intermediate tensors may not be None!")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Only run the layers assigned to this pipeline rank; entries outside
        # [start_layer, end_layer) are placeholders for other ranks' layers.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions, hidden_states=hidden_states, residual=residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, remapping HF names to this module's names.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters of this module that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def _load(n, p):
            # Skip layers on other devices.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p)
            loaded_params.add(n)

        def _load_shard(n, p, shard_id):
            # Skip layers on other devices.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p, shard_id)
            loaded_params.add(n)

        def _load_expert(n, p, name, shard_id, expert_id):
            # Skip layers on other devices.
            if is_pp_missing_parameter(n, self):
                return
            param = params_dict[n]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id)
            loaded_params.add(n)

        for n, p in weights:
            if "A_log" in n:
                n = n.replace("A_log", "A")

            # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215
            # Mapping different experts' layout:
            #  from HF (input_linear, output_linear, router)
            #  to vLLM (experts.w13_*({e}.w1, {e}.w3), experts.w2_*({e}.w2), gate)
            # The renaming and parameter loading logic is the same for weight
            # and weight_scale tensors so we can reuse them without issues.
            if n.endswith(
                (
                    ".block_sparse_moe.input_linear.weight",
                    ".block_sparse_moe.input_linear.weight_scale",
                )
            ):
                w13_param_name = n.replace(".input_linear.", ".experts.w13_")
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    # Each expert's slice is split in half along dim 0:
                    # first half loaded as w1, second half as w3.
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    _load_expert(
                        w13_param_name,
                        w1_param,
                        w1_name,
                        shard_id="w1",
                        expert_id=e,
                    )
                    _load_expert(
                        w13_param_name,
                        w3_param,
                        w3_name,
                        shard_id="w3",
                        expert_id=e,
                    )
            elif n.endswith(
                (
                    ".block_sparse_moe.output_linear.weight",
                    ".block_sparse_moe.output_linear.weight_scale",
                )
            ):
                w2_param_name = n.replace(".output_linear.", ".experts.w2_")
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        ".block_sparse_moe.output_linear.weight",
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    _load_expert(
                        w2_param_name,
                        p[e],
                        w2_name,
                        shard_id="w2",
                        expert_id=e,
                    )
            elif n.endswith(".block_sparse_moe.router.layer.weight"):
                gate_name = n.replace(
                    ".block_sparse_moe.router.layer.weight",
                    ".block_sparse_moe.gate.weight",
                )
                _load(gate_name, p)
            else:
                # q/k/v projections are loaded as shards of the fused
                # qkv_proj; anything else is loaded under its own name.
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name in n:
                        _load_shard(
                            n.replace(weight_name, param_name), p, shard_id=shard_id
                        )
                        break
                else:
                    _load(n, p)

        return loaded_params


class GraniteMoeHybridForCausalLM(
    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
):
    """GraniteMoeHybrid backbone with an LM head and logits processor."""

    # Checkpoint q/k/v projections are packed into the fused qkv_proj.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the Mamba2 state dtypes derived from the given config.

        Args:
            vllm_config: vLLM config

        Returns:
            The dtype pair computed by
            ``MambaStateDtypeCalculator.mamba2_state_dtype`` from the model
            dtype and the two Mamba cache dtype settings.
        """
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_n_groups,
            num_heads=hf_config.mamba_n_heads,
            head_dim=hf_config.mamba_d_head,
            state_size=hf_config.mamba_d_state,
            conv_kernel=hf_config.mamba_d_conv,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Source of the HF config plus the model, LoRA,
                scheduler and quantization configs.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        self.quant_config = vllm_config.quant_config
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = GraniteMoeHybridModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (
            lora_config.lora_vocab_padding_size
            if lora_config
            else DEFAULT_VOCAB_PADDING_SIZE
        )
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Tied embeddings: the LM head reuses the input embedding matrix.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Run the backbone and return whatever it returns.

        Args:
            input_ids: Token ids.
            positions: Token positions.
            intermediate_tensors: Output of the previous pipeline rank, if any.
            inputs_embeds: Optional pre-computed input embeddings.
            **kwargs: Accepted and ignored.

        Returns:
            The backbone output: hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` for the next rank.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the LM head and logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of names reported as loaded by the loader.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)