bert.py 23.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Set
5
from typing import Optional, Union
6
7
8
9
10

import torch
from torch import nn
from transformers import BertConfig

11
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.model_executor.layers.activation import get_act_fn
16
17
18
19
20
21
22
23
24
25
26
27
28
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    DispatchPooler,
    Pooler,
    PoolingMethod,
    PoolingParamsUpdate,
    PoolingType,
)
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
31
from vllm.sequence import IntermediateTensors
32
from vllm.tasks import PoolingTask
33
from vllm.v1.pool.metadata import PoolingMetadata
34

35
36
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
37
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
38

39
40
41
42
43

class BertEmbedding(nn.Module):
    """Input embedding layer for BERT.

    Sums word, token-type and position embeddings and normalizes the result
    with LayerNorm. Token type ids are not passed in explicitly; they are
    recovered from ``input_ids`` by ``_decode_token_type_ids``.
    """

    def __init__(self, config: BertConfig):
        """Build the three embedding tables and the LayerNorm.

        Raises:
            ValueError: if ``config.position_embedding_type`` is anything
                other than ``"absolute"``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.size = hidden_size

        # One lookup table per embedding kind, all of width ``hidden_size``.
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size, hidden_size)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, hidden_size
        )
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden_size
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # Buffer holding 0..max_position_embeddings-1 with a leading axis of 1.
        all_positions = torch.arange(config.max_position_embeddings)
        self.register_buffer("position_ids", all_positions.unsqueeze(0))

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the inputs.

        Args:
            input_ids: token ids, possibly carrying an encoded token type bit.
                Modified in place by the decoding step.
            position_ids: position index of every token.
            inputs_embeds: precomputed word embeddings; when given, the word
                embedding lookup is skipped.

        Returns:
            The LayerNorm-ed sum of word, token-type and position embeddings.
        """
        # Must happen before the word-embedding lookup: the helper strips the
        # token-type bit out of ``input_ids`` in place.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        pos_emb = self.position_embeddings(position_ids)
        type_emb = self.token_type_embeddings(token_type_ids)

        return self.LayerNorm(inputs_embeds + type_emb + pos_emb)


85
class BertPooler(Pooler):
    """CLS pooling followed by BERT's dense + tanh head."""

    def __init__(self, config: BertConfig):
        super().__init__()

        width = config.hidden_size
        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        """Apply the dense projection and then the tanh activation."""
        return self.activation(self.dense(pooled_output))

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Pool ``hidden_states`` and run the head on the pooled result.

        The pooling step may return a single tensor or a list of tensors;
        the head is applied to the tensor, or to each list element in turn.
        """
        pooled = self.pooling(hidden_states, pooling_metadata)

        if not isinstance(pooled, list):
            return self._head(pooled)
        return [self._head(item) for item in pooled]


119
class BertEncoder(nn.Module):
    """Stack of ``BertLayer`` blocks applied one after another."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        # Arguments shared by every layer; only the prefix differs per layer.
        shared_kwargs = dict(
            config=hf_config,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
        )
        layers = [
            BertLayer(prefix=f"{prefix}.layer.{idx}", **shared_kwargs)
            for idx in range(hf_config.num_hidden_layers)
        ]
        self.layer = nn.ModuleList(layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer in order."""
        for block in self.layer:
            hidden_states = block(hidden_states)
        return hidden_states


class BertLayer(nn.Module):
    """One encoder block: attention, intermediate projection, output."""

    def __init__(
        self,
        config: BertConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        layer_norm_eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )

        self.intermediate = BertIntermediate(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )

        self.output = BertOutput(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor):
        """Run the block; the attention output is also passed to the output
        module as its second (``input_tensor``) argument."""
        attn_output = self.attention(hidden_states)
        expanded = self.intermediate(attn_output)
        return self.output(expanded, attn_output)


class BertAttention(nn.Module):
    """Attention sub-block: self-attention followed by its output module.

    The output module receives both the attention result and the original
    ``hidden_states`` (used there as the residual input).
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # The submodule is registered under the attribute name ``self``, so
        # its prefix has to end in ".self". Using ".output" here would give
        # it the same prefix as the sibling ``self.output`` module and would
        # not match the module's real dotted path.
        self.self = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self",
        )

        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention, then the output module with the residual."""
        self_output = self.self(hidden_states)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Head counts are divided by the tensor-parallel world size, so
    ``num_heads``, ``num_kv_heads``, ``q_size`` and ``kv_size`` describe the
    share handled by one rank.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_attention_heads
        # Heads must split evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # BERT uses as many key/value heads as query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        # The hidden size must be an exact multiple of the head count.
        assert self.head_dim * self.total_num_heads == self.hidden_size
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = EncoderOnlyAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to fused QKV, split along the last axis, and attend."""
        fused, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)
        return self.attn(query, key, value)


class BertSelfOutput(nn.Module):
    """Output projection of the attention block with residual + LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Project ``hidden_states``, add ``input_tensor`` and normalize."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


class BertIntermediate(nn.Module):
    """Feed-forward expansion: linear projection followed by an activation."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        # Activation is looked up by its name (``hidden_act``).
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` and apply the activation function."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)


class BertOutput(nn.Module):
    """Feed-forward contraction with residual connection and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Project ``hidden_states``, add ``input_tensor`` and normalize."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


359
@support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
    """BERT backbone: embedding layer followed by the encoder stack.

    Checkpoint weights named ``query`` / ``key`` / ``value`` are loaded into
    the fused ``qkv_proj`` parameter.
    """

    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.embeddings = embedding_class(hf_config)
        self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the word embeddings for ``input_ids``."""
        return self.embeddings.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the inputs and run the encoder.

        ``intermediate_tensors`` is accepted but not used by this method.
        """
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
        )
        return self.encoder(embedded)

    def _load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> tuple[list[tuple[str, torch.Tensor]], list[str]]:
        """Load the fused QKV shards and collect everything else.

        Returns:
            A pair ``(other_weights, loaded_stacked_params)``: the
            ``(name, tensor)`` entries that still need loading (only names
            that exist among this module's parameters are kept), and the
            names of the fused parameters that were loaded here.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        known_params = dict(self.named_parameters())
        stacked_loaded: list[str] = []
        leftovers: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    # Rewrite the checkpoint name to the fused parameter name.
                    # The rewritten name is kept even when it turns out not to
                    # be a known parameter.
                    name = name.replace(shard_name, fused_name)
                    if name in known_params:
                        target = known_params[name]
                        target.weight_loader(target, tensor, shard_id)
                        stacked_loaded.append(name)
                        break
            else:
                # No shard was loaded for this entry; keep it for the generic
                # loader if it names one of our parameters.
                if name in known_params:
                    leftovers.append((name, tensor))

        return leftovers, stacked_loaded

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights, skipping any under the ``pooler.`` prefix.

        Returns:
            The names of the parameters that were loaded.
        """
        remaining, stacked_names = self._load_weights(weights)

        auto_loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
        loaded = auto_loader.load_weights(remaining)
        loaded.update(stacked_names)
        return loaded


436
@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
    """``BertModel`` that additionally owns a ``BertPooler`` head.

    Unlike the parent, ``load_weights`` does not skip ``pooler.`` weights.
    """

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            embedding_class=embedding_class,
        )

        self.pooler = BertPooler(vllm_config.model_config.hf_config)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights, including those of the pooler head.

        Returns:
            The names of the parameters that were loaded.
        """
        remaining, stacked_names = self._load_weights(weights)

        loaded = AutoWeightsLoader(self).load_weights(remaining)
        loaded.update(stacked_names)
        return loaded
463
464


465
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        pooler: An instance of Pooler used for pooling operations.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.model = self._build_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.pooler = self._build_pooler(pooler_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the word embeddings of the wrapped model for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output unchanged."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, skipping any under the ``lm_head.`` prefix.

        If no weight name starts with ``"model."``, every name is prefixed
        with ``"model."`` so it lines up with the ``self.model`` submodule.
        Otherwise the names are used as they are.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        # Materialize once: the iterable is scanned here and consumed again
        # by the loader below.
        weights_list = list(weights)

        # Default to no remapping. Without this default, a checkpoint that
        # already uses the "model." prefix left ``mapper`` unassigned and the
        # call below raised UnboundLocalError.
        mapper = None
        has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)
        if not has_model_prefix:
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        """Construct the wrapped ``BertModel`` with the ``BertEmbedding`` layer."""
        return BertModel(
            vllm_config=vllm_config, prefix=prefix, embedding_class=BertEmbedding
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Construct a pooler that handles the "encode" and "embed" tasks."""
        return DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
        )
529
530


531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
# Here we encode the token type ids together with the input ids.
# Since we use int 32 for the input IDs and the vocabulary size
# is way lower than 2**31, there is room to encode additional
# bits. At the same time, for cross-encoder use cases, the
# token type ids are only 0 or 1, requiring only 1 bit.
# This means that we can store the token type ids in the 31st
# bit. We avoid the 32nd bit because that would produce a negative
# number, which could be used to signal other things.
#
# The reason for all of this is that all the tensors that are
# passed as input to the forward function of a module marked
# with @support_torch_compile have to be persistent. So to
# avoid adding more persistent tensors in the model runner, we
# encode more information in the same persistent tensor.
#
# Since the *ForClassification module is outside of the BertModel
# which is compiled, we can do the encoding here and then separate
# the information again in the Embedding layer. Since with bit masks
# we can do this entirely with torch operations and without branching,
# it works with torch compile.

# Number of bits a token type id is shifted left by before it is OR-ed
# into an input id (and shifted right by again when it is decoded).
TOKEN_TYPE_SHIFT = 30


555
556
557
def _encode_token_type_ids(
    input_ids: torch.Tensor, token_type_ids: torch.Tensor
) -> None:
    """OR the token type ids into the high bit of ``input_ids``, in place.

    Only the first ``token_type_ids.shape[0]`` entries of ``input_ids`` are
    touched, because ``input_ids`` can be padded to the right.
    """
    num_tokens = token_type_ids.shape[0]
    shifted_types = token_type_ids << TOKEN_TYPE_SHIFT
    # Slicing yields a view, so the in-place OR writes through to input_ids.
    input_ids[:num_tokens].bitwise_or_(shifted_types)
560
561
562


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Extract the token type ids stored in ``input_ids`` and clear them.

    Args:
        input_ids: ids carrying the token type at bit ``TOKEN_TYPE_SHIFT``.
            Modified in place: that bit is cleared.

    Returns:
        A tensor of the token type ids, one per entry of ``input_ids``.
    """
    # Mask with only the token-type bit set, one entry per input id.
    type_bit = (
        torch.ones_like(input_ids, dtype=torch.int32, device=input_ids.device)
        << TOKEN_TYPE_SHIFT
    )

    # Read the bit out first, before it is cleared below.
    token_type_ids = (input_ids & type_bit) >> TOKEN_TYPE_SHIFT

    # Clear the token-type bit in place; the remaining bits are kept.
    input_ids.bitwise_and_(~type_bit)

    return token_type_ids


576
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
    """BERT with a sequence-level classification head.

    Wraps a ``BertPoolingModel`` and a linear classifier. The classifier is
    not applied in ``forward``; it is handed to the "classify" and "score"
    poolers together with the BERT pooler.

    Attributes:
        bert: The ``BertPoolingModel`` used for forward operations.
        classifier: Linear layer mapping hidden size to ``num_labels``.
        pooler: Dispatcher over the "encode", "classify" and "score" tasks.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.num_labels = hf_config.num_labels
        self.bert = BertPoolingModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
        )
        self.classifier = nn.Linear(
            hf_config.hidden_size,
            hf_config.num_labels,
            dtype=model_config.head_dtype,
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        # "classify" and "score" share the same pooling and classifier and
        # differ only in the activation function that is selected.
        encode_pooler = Pooler.for_encode(pooler_config)
        classify_pooler = ClassifierPooler(
            pooling=self.bert.pooler,
            classifier=self.classifier,
            act_fn=ClassifierPooler.act_fn_for_seq_cls(model_config),
        )
        score_pooler = ClassifierPooler(
            pooling=self.bert.pooler,
            classifier=self.classifier,
            act_fn=ClassifierPooler.act_fn_for_cross_encoder(model_config),
        )
        self.pooler = DispatchPooler(
            {
                "encode": encode_pooler,
                "classify": classify_pooler,
                "score": score_pooler,
            }
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the word embeddings of the wrapped model for ``input_ids``."""
        return self.bert.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load all checkpoint weights through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped BERT model and return its output.

        When ``token_type_ids`` is given, it is first packed into
        ``input_ids`` (in place), which therefore must not be None.
        """
        if token_type_ids is not None:
            # The packed bit must lie above every valid vocabulary id.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        return self.bert(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
656
657
658
659
660
661
662
663
664
665
666


@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    """BERT with a token-level classification head.

    ``forward`` applies the linear classifier to the backbone's hidden
    states, after casting them to the head dtype.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.head_dtype = model_config.head_dtype
        self.num_labels = hf_config.num_labels
        self.bert = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
        )
        self.classifier = nn.Linear(
            hf_config.hidden_size, hf_config.num_labels, dtype=self.head_dtype
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)})

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the word embeddings of the wrapped model for ``input_ids``."""
        return self.bert.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load all checkpoint weights through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run BERT and classify its hidden states.

        When ``token_type_ids`` is given, it is first packed into
        ``input_ids`` (in place), which therefore must not be None.
        """
        if token_type_ids is not None:
            # The packed bit must lie above every valid vocabulary id.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        encoded = self.bert(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

        # Cast to the head dtype before applying the classifier.
        return self.classifier(encoded.to(self.head_dtype))