bert.py 25 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Set
5
from typing import Optional, Union
6
7
8
9
10

import torch
from torch import nn
from transformers import BertConfig

11
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.model_executor.layers.activation import get_act_fn
16
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
19
20
from vllm.model_executor.layers.pooler import (ClassifierPooler,
                                               DispatchPooler, Pooler,
21
22
                                               PoolingMethod,
                                               PoolingParamsUpdate,
23
                                               PoolingType)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
26
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
27
from vllm.sequence import IntermediateTensors
28
from vllm.tasks import PoolingTask
29
from vllm.v1.pool.metadata import PoolingMetadata
30

31
32
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
33
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
34

35
36
37
38
39
40
41
42
43

class BertEmbedding(nn.Module):
    """Sum of word, token-type and position embeddings followed by LayerNorm.

    Token type ids are not a separate input: ``forward`` extracts them from
    ``input_ids`` with ``_decode_token_type_ids``.
    """

    def __init__(self, config: BertConfig):
        """Build the three embedding tables and the LayerNorm from ``config``.

        Raises:
            ValueError: if ``config.position_embedding_type`` is anything
                other than ``"absolute"``.
        """
        super().__init__()

        hidden_size = config.hidden_size
        max_positions = config.max_position_embeddings

        self.size = hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden_size)
        self.position_embeddings = VocabParallelEmbedding(
            max_positions, hidden_size)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # Registered as a buffer; `forward` below uses the `position_ids`
        # argument it is given and does not read this buffer.
        self.register_buffer("position_ids",
                             torch.arange(max_positions).unsqueeze(0))

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return the normalized embedding for each token in ``input_ids``.

        ``input_ids`` is modified in place: decoding the token type ids
        clears the token-type bit before the word-embedding lookup.
        """
        # Must run before the word-embedding lookup, since it strips the
        # token-type bit out of `input_ids`.
        token_type_ids = _decode_token_type_ids(input_ids)

        word_part = self.word_embeddings(input_ids)
        type_part = self.token_type_embeddings(token_type_ids)
        position_part = self.position_embeddings(position_ids)

        return self.LayerNorm(word_part + type_part + position_part)


78
class BertPooler(Pooler):
    """CLS pooling followed by a dense layer and a tanh activation."""

    def __init__(self, config: BertConfig):
        """Create the CLS pooling method and the ``hidden -> hidden`` head."""
        super().__init__()

        hidden_size = config.hidden_size
        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        """Apply the dense projection and then tanh to one pooled tensor."""
        return self.activation(self.dense(pooled_output))

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Pool ``hidden_states`` and run the head on the result.

        The pooling method may return one tensor or a list of tensors; the
        head is applied to each element and the same container kind is
        returned.
        """
        pooled = self.pooling(hidden_states, pooling_metadata)

        if not isinstance(pooled, list):
            return self._head(pooled)
        return [self._head(item) for item in pooled]


113
114
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate every layer, naming them ``{prefix}.layer.{index}``."""
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        # Settings shared by every layer; only the prefix differs.
        shared_kwargs = dict(
            config=hf_config,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
        )

        layers = []
        for index in range(hf_config.num_hidden_layers):
            layers.append(
                BertLayer(prefix=f"{prefix}.layer.{index}", **shared_kwargs))
        self.layer = nn.ModuleList(layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through each layer in order."""
        output = hidden_states
        for block in self.layer:
            output = block(output)
        return output


class BertLayer(nn.Module):
    """One encoder layer: attention, then intermediate and output blocks."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the three sub-blocks under ``prefix``."""
        super().__init__()

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        layer_norm_eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.intermediate = BertIntermediate(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )
        self.output = BertOutput(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor):
        """Run attention, then the feed-forward pair.

        The attention output is passed to ``self.output`` a second time as
        its ``input_tensor`` argument.
        """
        attn_output = self.attention(hidden_states)
        expanded = self.intermediate(attn_output)
        return self.output(expanded, attn_output)


class BertAttention(nn.Module):
    """Self-attention followed by its output projection block.

    ``self.self`` computes the attention and ``self.output`` combines the
    result with the layer input (see ``BertSelfOutput``).
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build both sub-modules under ``prefix``.

        Args:
            hidden_size: model hidden dimension.
            num_attention_heads: total number of attention heads.
            layer_norm_eps: epsilon for the LayerNorm in the output block.
            cache_config: forwarded to the self-attention module.
            quant_config: forwarded to both sub-modules.
            prefix: dotted module path of this block.
        """
        super().__init__()

        # The prefix has to match the attribute name (`self`). Previously
        # this used f"{prefix}.output", the same prefix as the sibling
        # output block below, so the names derived from it (qkv_proj,
        # attn) did not reflect where the module actually lives.
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and apply the output block."""
        self_output = self.self(hidden_states)
        # `hidden_states` is passed as the output block's `input_tensor`.
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Fused QKV projection followed by encoder-only attention.

    Head counts are divided by the tensor-parallel world size; the number
    of KV heads equals the number of query heads before that division.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Derive per-rank head sizes and build ``qkv_proj`` and ``attn``."""
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # The heads must split evenly across tensor-parallel ranks.
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = self.total_num_heads

        # The hidden size must be an exact multiple of the head count.
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Widths of the q / k / v slices of this rank's fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.attn = EncoderOnlyAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, split along the last dim, and attend."""
        # The projection returns a pair; only the first element is used.
        qkv, _ = self.qkv_proj(hidden_states)
        slice_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.split(slice_sizes, dim=-1)
        return self.attn(query, key, value)


class BertSelfOutput(nn.Module):
    """Dense projection of the attention output, residual add, LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the ``hidden -> hidden`` projection and the LayerNorm."""
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


class BertIntermediate(nn.Module):
    """Up-projection to the intermediate size followed by an activation."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projection and look up the activation ``hidden_act``."""
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act(dense(hidden_states))``."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)


class BertOutput(nn.Module):
    """Down-projection from the intermediate size, residual add, LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the ``intermediate -> hidden`` projection and LayerNorm."""
        super().__init__()

        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


329
@support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
    """Embedding layer plus encoder stack, without a pooling head.

    ``load_weights`` skips checkpoint entries under ``pooler.``.
    """

    is_pooling_model = True

    # Checkpoint shards that are fused into one `qkv_proj` parameter.
    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        """Build the embeddings (via ``embedding_class``) and the encoder."""
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.embeddings = embedding_class(hf_config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   prefix=f"{prefix}.encoder")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Call the embedding module on ``input_ids`` alone."""
        # NOTE(review): the default `BertEmbedding.forward` also requires
        # `position_ids`, so with that embedding class this call looks
        # like it would raise a TypeError -- confirm intended behaviour.
        return self.embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the encoder output.

        When ``inputs_embeds`` is given it is passed to the encoder as is
        and the embedding layer is bypassed. ``intermediate_tensors`` is
        accepted but not used here.
        """
        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            position_ids=positions)
        else:
            hidden_states = inputs_embeds
        return self.encoder(hidden_states)

    def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load the fused QKV shards and collect the remaining weights.

        Returns:
            A pair ``(other_weights, loaded_stacked_params)``: the
            ``(name, tensor)`` entries still to be loaded (only names that
            are parameters of this module), and the names of the fused
            parameters that received a shard.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        stacked_loaded = []
        remaining = []
        known_params = dict(self.named_parameters())

        for name, tensor in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    # The renamed key is kept even when it is not a known
                    # parameter; later mapping entries and the fall-through
                    # branch below see the renamed value.
                    name = name.replace(shard_name, fused_name)
                    if name in known_params:
                        target = known_params[name]
                        target.weight_loader(target, tensor, shard_id)
                        stacked_loaded.append(name)
                        break
            else:
                # No shard was loaded for this entry.
                if name in known_params:
                    remaining.append((name, tensor))

        return remaining, stacked_loaded

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load all weights except those under ``pooler.``.

        Returns:
            The set of parameter names that were loaded.
        """
        remaining, stacked_loaded = self._load_weights(weights)

        auto_loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
        loaded_names = auto_loader.load_weights(remaining)
        loaded_names.update(stacked_loaded)
        return loaded_names


408
@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
    """``BertModel`` with a ``BertPooler`` attached as ``self.pooler``.

    Unlike the base class, ``load_weights`` does not skip ``pooler.``
    entries.
    """

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        """Build the base model, then add the pooler."""
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=embedding_class)

        self.pooler = BertPooler(vllm_config.model_config.hf_config)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load all weights, pooler included.

        Returns:
            The set of parameter names that were loaded.
        """
        remaining, stacked_loaded = self._load_weights(weights)

        auto_loader = AutoWeightsLoader(self)
        loaded_names = auto_loader.load_weights(remaining)
        loaded_names.update(stacked_loaded)
        return loaded_names
437
438


439
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        pooler: An instance of Pooler used for pooling operations.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped model and the pooler.

        ``vllm_config.model_config.pooler_config`` must not be ``None``.
        """
        super().__init__()

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self.pooler = self._build_pooler(pooler_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids=input_ids,
                          positions=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, adding a ``model.`` prefix if needed.

        If no weight name starts with ``model.``, every name is remapped
        under ``model.`` so that it matches ``self.model``. Entries under
        ``lm_head.`` are skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        # Materialize once: the iterable is scanned here and then consumed
        # again by the loader.
        weights_list = list(weights)

        has_model_prefix = any(
            name.startswith("model.") for name, _ in weights_list)

        # `mapper` must be bound on both paths. Previously it was only
        # assigned when the prefix was missing, so an already-prefixed
        # checkpoint raised UnboundLocalError at the call below.
        mapper = None
        if not has_model_prefix:
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Create the wrapped ``BertModel`` with the default embedding."""
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=BertEmbedding)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Create a dispatcher with ``"encode"`` and ``"embed"`` poolers."""
        return DispatchPooler({
            "encode": Pooler.for_encode(pooler_config),
            "embed": Pooler.for_embed(pooler_config),
        })
501
502


503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
# Here we encode the token type ids together with the input ids.
# Since we use int32 for the input IDs and the vocabulary size
# is way lower than 2**31, there is room to encode additional
# bits. At the same time, for cross-encoder use cases, the
# token type ids are only 0 or 1, requiring only 1 bit.
# This means that we can store the token type ids in the 31st
# bit. We avoid the 32nd bit because that would produce a negative
# number, which could be used to signal other things.
#
# The reason for all of this is that all the tensors that are
# passed as input to the forward function of a module marked
# with @support_torch_compile have to be persistent. So to
# avoid adding more persistent tensors in the model runner, we
# encode more information in the same persistent tensor.
#
# Since the *ForClassification module is outside of the BertModel
# which is compiled, we can do the encoding here and then separate
# the information again in the embedding layer. Since with bit masks
# we can do this entirely with torch operations and without branching,
# it works with torch compile.

# Zero-based index of the bit in each input id that carries the token
# type id.
TOKEN_TYPE_SHIFT = 30


def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    """OR the token type ids into bit ``TOKEN_TYPE_SHIFT`` of ``input_ids``.

    ``input_ids`` is modified in place. Only its first
    ``token_type_ids.shape[0]`` entries are touched, so ``input_ids`` may
    be longer (right-padded).

    Args:
        input_ids: integer tensor of token ids; updated in place.
        token_type_ids: integer tensor of token type ids (expected to be
            0 or 1 so that a single bit suffices).
    """
    # input_ids can be padded to the right
    num_tokens = token_type_ids.shape[0]

    # Cast before shifting: an in-place op cannot widen its destination,
    # so OR-ing e.g. int64 type ids into int32 input ids would fail.
    type_bits = token_type_ids.to(input_ids.dtype) << TOKEN_TYPE_SHIFT
    input_ids[:num_tokens].bitwise_or_(type_bits)


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Extract the token type ids and strip them from ``input_ids``.

    ``input_ids`` is modified in place: bit ``TOKEN_TYPE_SHIFT`` is
    cleared, leaving the plain token ids.

    Returns:
        A new tensor shaped like ``input_ids`` holding the value of bit
        ``TOKEN_TYPE_SHIFT`` (0 or 1) for every element.
    """
    # Scalar masks give the same result as full-size mask tensors without
    # allocating them on every call.
    type_bit = 1 << TOKEN_TYPE_SHIFT

    token_type_ids = input_ids.bitwise_and(type_bit) >> TOKEN_TYPE_SHIFT

    input_ids.bitwise_and_(~type_bit)

    return token_type_ids


548
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                    SupportsQuant):
    """Bert with a linear classification head applied through the pooler.

    Attributes:
        bert: A BertPoolingModel used for forward operations.
        classifier: Linear layer mapping the hidden size to ``num_labels``.
        pooler: A DispatchPooler with ``"encode"``, ``"classify"`` and
            ``"score"`` entries.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the classifier and the task poolers.

        ``vllm_config.model_config.pooler_config`` must not be ``None``.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.num_labels = hf_config.num_labels
        self.bert = BertPoolingModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "bert"),
                                     embedding_class=BertEmbedding)
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels,
                                    dtype=model_config.head_dtype)

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        def classifier_pooler(act_fn):
            # "classify" and "score" share the backbone pooler and the
            # classifier; they differ only in the activation function.
            return ClassifierPooler(
                pooling=self.bert.pooler,
                classifier=self.classifier,
                act_fn=act_fn,
            )

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            classifier_pooler(
                ClassifierPooler.act_fn_for_seq_cls(model_config)),
            "score":
            classifier_pooler(
                ClassifierPooler.act_fn_for_cross_encoder(model_config)),
        })

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone."""
        return self.bert.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load every weight through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its output.

        When ``token_type_ids`` is given, it is packed into ``input_ids``
        in place before the backbone is called; ``input_ids`` must then
        not be ``None``.
        """
        if token_type_ids is not None:
            # The packed bit must lie above every valid token id.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        backbone_inputs = dict(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        return self.bert(**backbone_inputs)
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648


@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    """Bert with a linear classifier applied to the backbone output.

    Attributes:
        bert: A BertModel used for forward operations.
        classifier: Linear layer mapping the hidden size to ``num_labels``.
        pooler: A DispatchPooler with a single ``"encode"`` entry.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the classifier and the pooler.

        ``vllm_config.model_config.pooler_config`` must not be ``None``.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.head_dtype = model_config.head_dtype
        self.num_labels = hf_config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding)
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels,
                                    dtype=self.head_dtype)

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)})

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone."""
        return self.bert.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load every weight through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and classify its output.

        When ``token_type_ids`` is given, it is packed into ``input_ids``
        in place before the backbone is called; ``input_ids`` must then
        not be ``None``. The backbone output is cast to ``head_dtype``
        before the classifier is applied.
        """
        if token_type_ids is not None:
            # The packed bit must lie above every valid token id.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        encoded = self.bert(input_ids=input_ids,
                            positions=positions,
                            inputs_embeds=inputs_embeds,
                            intermediate_tensors=intermediate_tensors)

        return self.classifier(encoded.to(self.head_dtype))