qwen2.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Junyang Lin's avatar
Junyang Lin committed
4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
6
7
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
Junyang Lin's avatar
Junyang Lin committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
26
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
27

28
from collections.abc import Iterable
29
from itertools import islice
30
from typing import Any
Junyang Lin's avatar
Junyang Lin committed
31
32
33
34
35

import torch
from torch import nn
from transformers import Qwen2Config

36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Junyang Lin's avatar
Junyang Lin committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
41
from vllm.model_executor.layers.attention import (
    Attention,
42
43
    EncoderOnlyAttention,
)
Junyang Lin's avatar
Junyang Lin committed
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
46
47
48
49
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
50
from vllm.model_executor.layers.logits_processor import LogitsProcessor
51
from vllm.model_executor.layers.quantization import QuantizationConfig
52
from vllm.model_executor.layers.rotary_embedding import get_rope
Junyang Lin's avatar
Junyang Lin committed
53
from vllm.model_executor.layers.vocab_parallel_embedding import (
54
55
56
    ParallelLMHead,
    VocabParallelEmbedding,
)
57
from vllm.model_executor.model_loader.weight_utils import (
58
59
60
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
61
from vllm.sequence import IntermediateTensors
62
from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta
63
from vllm.v1.attention.backend import AttentionType
Junyang Lin's avatar
Junyang Lin committed
64

65
66
67
68
69
70
71
from .interfaces import (
    EagleModelMixin,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
72
73
74
75
76
77
78
79
80
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
81

Junyang Lin's avatar
Junyang Lin committed
82
83
84
85
86
87
88

class Qwen2MLP(nn.Module):
    """Feed-forward block of a Qwen2 decoder layer.

    The gate and up projections share one ``MergedColumnParallelLinear``
    (two output shards of ``intermediate_size`` each). ``SiluAndMul`` is
    applied to that fused output before the row-parallel down projection.
    Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two equally sized output shards: [gate, up].
        fused_out_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_out_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        # Only the fused SiLU-and-multiply activation is wired up here.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Fused gate/up projection, ``act_fn``, then the down projection.

        The parallel linear layers return ``(output, bias)`` pairs; only the
        output is used.
        """
        projected, _unused_bias = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _unused_bias = self.down_proj(activated)
        return out


class Qwen2Attention(nn.Module):
    """Qwen2 self-attention with separate query / key-value head counts.

    Q, K and V come from a single fused ``QKVParallelLinear`` (with bias).
    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many of them as ranks, and
    replicated otherwise. When ``qk_norm`` is set, a per-head RMSNorm is
    applied to Q and K before the rotary embedding. ``attn_type`` selects
    ``EncoderOnlyAttention`` for ENCODER_ONLY and ``Attention`` otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position: int = 4096 * 32,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        dual_chunk_attention_config: dict[str, Any] | None = None,
        qk_norm: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        tp_world = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Query heads must divide evenly over the TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world == 0
        self.num_heads = self.total_num_heads // tp_world

        # Fewer KV heads than ranks: each KV head is replicated over several
        # ranks. Otherwise the KV heads are partitioned across ranks.
        # Either way, one count must divide the other.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_world:
            assert tp_world % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world)

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = dual_chunk_attention_config
        self.qk_norm = qk_norm

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Optional per-head QK normalization (used in BAGEL and some other models).
        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Extra kwargs are forwarded to the attention layer only when a
        # dual-chunk attention config is present (and truthy).
        extra_attn_kwargs: dict[str, Any] = {}
        if dual_chunk_attention_config:
            extra_attn_kwargs["layer_idx"] = extract_layer_index(prefix)
            extra_attn_kwargs["dual_chunk_attention_config"] = (
                dual_chunk_attention_config
            )

        if attn_type == AttentionType.ENCODER_ONLY:
            attn_cls = EncoderOnlyAttention
        else:
            attn_cls = Attention

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
            **extra_attn_kwargs,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fused QKV projection, optional QK-norm, RoPE, attention, output proj.

        Q/K/V are split out of the fused projection along the last dimension
        using the per-rank ``q_size`` / ``kv_size`` widths.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        if self.qk_norm:
            # Normalize each head independently (before RoPE). The leading
            # dimension is the token count; heads are split out of the last
            # dimension and merged back after normalization.
            n_tok = q.shape[0]
            q_per_head = q.view(n_tok, self.num_heads, self.head_dim)
            k_per_head = k.view(n_tok, self.num_kv_heads, self.head_dim)
            q_per_head = self.q_norm(q_per_head)
            k_per_head = self.k_norm(k_per_head)
            q = q_per_head.view(n_tok, self.q_size)
            k = k_per_head.view(n_tok, self.kv_size)

        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v)
        projected, _ = self.o_proj(context)
        return projected


class Qwen2DecoderLayer(nn.Module):
    """One Qwen2 transformer block: pre-norm self-attention followed by an MLP.

    ``forward`` threads a separate ``residual`` tensor through the layer.
    When a residual is present, each RMSNorm is called with it and returns an
    updated ``(hidden_states, residual)`` pair.
    """

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Supplies a rope_theta default of 1_000_000 on the config
        # (presumably only when the config does not set one; see helper).
        set_default_rope_theta(config, default_theta=1000000)

        dca_config = getattr(config, "dual_chunk_attention_config", None)

        # Qwen2 is decoder-only, so attention is causal by default. An HF
        # config with `is_causal=False` switches to bidirectional attention,
        # as used by some embedding models
        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct).
        attn_type = (
            AttentionType.DECODER
            if getattr(config, "is_causal", True)
            else AttentionType.ENCODER_ONLY
        )

        # QK normalization is opt-in via the config
        # (used in BAGEL and some other models).
        use_qk_norm = getattr(config, "qk_norm", False)

        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_parameters=config.rope_parameters,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dca_config,
            qk_norm=use_qk_norm,
            rms_norm_eps=config.rms_norm_eps,
        )
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(mlp_output, residual)``.

        ``residual is None`` marks the first layer on this path. The
        un-normalized input then becomes the residual stream.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
def qwen_2_model_invariants(
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
):
    """Shape invariants for ``Qwen2Model.forward``.

    These are translated to runtime assertions for unbacked dynamic shapes
    and are compiled away for backed ones.

    All token-count dimensions must agree:
      * ``positions.size()[-1]`` (positions is ``(seq_len,)``, or
        ``(3, seq_len)`` when M-RoPE is enabled),
      * ``input_ids.size()[0]``,
      * ``intermediate_tensors["hidden_states"].size()[0]``,
      * ``inputs_embeds.size()[0]``.

    ``positions`` is always supplied, so it is the reference count.
    ``input_ids`` may be ``None`` when the caller provides ``inputs_embeds``
    instead (``Qwen2Model.forward`` accepts ``input_ids=None``), so it is only
    checked when present. When ``input_ids`` is given, these constraints are
    equivalent to comparing every other count against ``input_ids``.
    """
    num_tokens = positions.size()[-1]
    if input_ids is not None:
        torch._check(input_ids.size()[0] == num_tokens)
    if intermediate_tensors is not None:
        torch._check(intermediate_tensors["hidden_states"].size()[0] == num_tokens)
    if inputs_embeds is not None:
        torch._check(inputs_embeds.size()[0] == num_tokens)

    # Hidden dimensions should match (hidden_size):
    # inputs_embeds.size()[1] vs intermediate_tensors["hidden_states"].size()[1].
    if inputs_embeds is not None and intermediate_tensors is not None:
        torch._check(
            inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
        )


347
348
349
350
351
352
353
354
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    },
    shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module, EagleModelMixin):
    """Qwen2 decoder stack: token embedding, decoder layers and final norm.

    The module is pipeline-parallel aware:
      * The embedding exists on the first PP rank, and also on the last rank
        when word embeddings are tied.
      * The final norm exists only on the last rank.
      * Missing pieces are ``PPMissingLayer`` placeholders.
      * Only layers ``[start_layer, end_layer)`` run on this rank.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config.get_text_config()
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # TODO (@robertgshaw2): see if this can be moved out
        # Sliding window applied to only a subset of layers is not supported,
        # so an interleaved config must window every layer.
        if is_interleaved(vllm_config.model_config.hf_text_config):
            assert config.max_window_layers == config.num_hidden_layers, (
                "Sliding window for some but not all layers is not supported. "
                "This model uses sliding window but `max_window_layers` = "
                f"{config.max_window_layers} is less than `num_hidden_layers` = "
                f"{config.num_hidden_layers}. Please open an issue "
                "to discuss this feature."
            )

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        # The last rank also needs the embedding when it doubles as lm_head.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Inputs:
          * First PP rank: ``inputs_embeds`` when given, else the embedding
            of ``input_ids``.
          * Other ranks: ``intermediate_tensors`` (required) supplies
            ``hidden_states`` and ``residual``.

        Returns:
          * Non-last ranks: an ``IntermediateTensors`` with ``hidden_states``
            and ``residual``.
          * Last rank: the final-normed hidden states, or a
            ``(hidden_states, aux_hidden_states)`` tuple when auxiliary
            hidden states were collected.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Auxiliary hidden states for EAGLE-style drafting. Which indices are
        # recorded is decided by EagleModelMixin._maybe_add_hidden_state.
        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            hidden_states, residual = layer(positions, hidden_states, residual)
            self._maybe_add_hidden_state(
                aux_hidden_states, idx + 1, hidden_states, residual
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF checkpoint tensors into this (possibly PP-sharded) model.

        Per-projection checkpoint tensors (``q/k/v_proj``, ``gate/up_proj``)
        are routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters
        by calling their ``weight_loader`` with a shard id.

        Skipped entries:
          * rotary ``inv_freq`` buffers,
          * biases with no matching parameter (GPTQ),
          * names that ``maybe_remap_kv_scale_name`` maps to ``None``,
          * parameters that belong to another pipeline rank.

        Returns:
            The set of (mapped) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scalar scales are used as-is; otherwise the first element is taken.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked (fused) parameter: load it directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
525

Junyang Lin's avatar
Junyang Lin committed
526

527
528
529
class Qwen2ForCausalLM(
    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
    """Qwen2 causal language model: ``Qwen2Model`` plus an LM head.

    With tied word embeddings, the LM head on the last pipeline rank is the
    model's own ``embed_tokens`` module. Other ranks hold a
    ``PPMissingLayer`` placeholder instead of a head.
    """

    # Fused vLLM parameter name -> checkpoint sub-module names packed into it
    # (presumably consumed by LoRA / quantization tooling; TODO confirm).
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        text_config = vllm_config.model_config.hf_config.get_text_config()
        quant = vllm_config.quant_config

        self.config = text_config
        self.quant_config = quant

        self.model = Qwen2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        pp_group = get_pp_group()
        if not pp_group.is_last_rank:
            self.lm_head = PPMissingLayer()
        elif text_config.tie_word_embeddings:
            # Reuse the input embedding module as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                text_config.vocab_size,
                text_config.hidden_size,
                quant_config=quant,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(text_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner ``Qwen2Model``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; return the set of loaded parameter names.

        With tied embeddings ``lm_head`` is ``embed_tokens``, so checkpoint
        entries under ``lm_head.`` are skipped.
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)