"vscode:/vscode.git/clone" did not exist on "076169f603a44b3a3377e59bad62d1cfc62cf98a"
bert.py 30.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Set
5
6
7
8
9

import torch
from torch import nn
from transformers import BertConfig

10
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.model_executor.layers.activation import get_act_fn
15
16
17
18
19
20
21
22
23
24
25
26
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    DispatchPooler,
    Pooler,
    PoolingMethod,
    PoolingParamsUpdate,
    PoolingType,
27
28
    TokenPoolerHeadOutput,
    TokenPoolingMethodOutput,
29
)
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
32
from vllm.sequence import IntermediateTensors
33
from vllm.tasks import PoolingTask
34
from vllm.v1.outputs import TokenPoolerOutput
35
from vllm.v1.pool.metadata import PoolingMetadata
36

37
from .interfaces import SupportsCrossEncoding, SupportsQuant
38
from .interfaces_base import attn_type, default_pooling_type
39
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
40

41
42
43
44
45

class BertEmbedding(nn.Module):
    """Sum word, absolute-position and token-type embeddings, then LayerNorm.

    Token-type ids are not a separate input: they are unpacked from the high
    bits of ``input_ids`` by ``_decode_token_type_ids`` at the start of
    ``forward``.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden

        self.word_embeddings = VocabParallelEmbedding(config.vocab_size, hidden)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, hidden
        )
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden
        )
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        default_positions = torch.arange(config.max_position_embeddings)
        self.register_buffer("position_ids", default_positions.unsqueeze(0))

        embedding_kind = getattr(config, "position_embedding_type", "absolute")
        self.position_embedding_type = embedding_kind
        if embedding_kind != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type" + " is supported"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Must run before the word lookup: it clears the packed token-type
        # bits from input_ids in place and returns them separately.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is not None:
            word_part = inputs_embeds
        else:
            word_part = self.word_embeddings(input_ids)

        position_part = self.position_embeddings(position_ids)
        type_part = self.token_type_embeddings(token_type_ids)

        return self.LayerNorm(word_part + type_part + position_part)


89
class BertPooler(Pooler):
    """CLS-token pooling followed by BERT's dense + tanh projection head."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Task support is decided entirely by the underlying CLS pooling method.
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def head(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Project pooled CLS vectors through dense + tanh.

        A list of per-request tensors is stacked into one batch first.
        """
        batch = (
            torch.stack(pooled_data) if isinstance(pooled_data, list) else pooled_data
        )
        return self.activation(self.dense(batch))

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        cls_vectors = self.pooling(hidden_states, pooling_metadata)
        return self.head(cls_vectors, pooling_metadata)
123
124


125
class BertEncoder(nn.Module):
    """A stack of ``num_hidden_layers`` BertLayer blocks applied in sequence."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        def make_layer(idx: int) -> "BertLayer":
            return BertLayer(
                config=hf_config,
                cache_config=cache_cfg,
                quant_config=quant_cfg,
                prefix=f"{prefix}.layer.{idx}",
            )

        self.layer = nn.ModuleList(
            [make_layer(idx) for idx in range(hf_config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        out = hidden_states
        for block in self.layer:
            out = block(out)
        return out


class BertLayer(nn.Module):
    """One transformer block: self-attention, then a two-projection feed-forward.

    The feed-forward output is added back onto the attention output inside
    ``BertOutput`` (residual connection + LayerNorm).
    """

    def __init__(
        self,
        config: BertConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        ffn_size = config.intermediate_size
        eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.intermediate = BertIntermediate(
            hidden_size=hidden,
            intermediate_size=ffn_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )
        self.output = BertOutput(
            hidden_size=hidden,
            intermediate_size=ffn_size,
            layer_norm_eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor):
        attended = self.attention(hidden_states)
        expanded = self.intermediate(attended)
        # `attended` is passed again as the residual input.
        return self.output(expanded, attended)


class BertAttention(nn.Module):
    """Self-attention sub-block: multi-head attention + output projection.

    ``self.self`` computes attention. ``self.output`` projects the result,
    adds the block input as a residual and applies LayerNorm.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # The prefix must mirror the attribute path (`<prefix>.self`). The
        # previous `.output` prefix registered qkv_proj/attn under
        # `...attention.output.*` even though they live at
        # `...attention.self.*`. That broke name-based lookups such as
        # per-layer quantization configs and the attention layer name.
        self.self = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self",
        )

        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        self_output = self.self(hidden_states)
        # hidden_states is the residual input for BertSelfOutput.
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Multi-head self-attention with a fused, tensor-parallel QKV projection.

    The key/value head count is set equal to the query head count. Heads are
    divided across tensor-parallel ranks, with at least one KV head per rank.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Heads must shard evenly across tensor-parallel ranks.
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        # hidden_size must split exactly into heads.
        assert self.head_dim * self.total_num_heads == self.hidden_size
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = EncoderOnlyAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        fused, _bias = self.qkv_proj(hidden_states)
        widths = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(widths, dim=-1)
        return self.attn(query, key, value)


class BertSelfOutput(nn.Module):
    """Project attention output, add the residual input, then LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        projected, _bias = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


class BertIntermediate(nn.Module):
    """Feed-forward up-projection (hidden -> intermediate) plus activation."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _bias = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)


class BertOutput(nn.Module):
    """Feed-forward down-projection, residual add, then LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        reduced, _bias = self.dense(hidden_states)
        return self.LayerNorm(reduced + input_tensor)


365
@support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
    """BERT embeddings followed by the transformer encoder.

    ``forward`` returns per-token hidden states. The checkpoint's separate
    ``query``/``key``/``value`` projections are loaded into the fused
    ``qkv_proj`` parameters.
    """

    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.embeddings = embedding_class(hf_config)
        self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embeddings.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # intermediate_tensors is accepted for interface compatibility; unused.
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
        )
        return self.encoder(embedded)

    def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load the q/k/v shards directly; return everything else for the caller.

        Returns:
            ``(other_weights, loaded_stacked_params)``. The first item holds
            the ``(name, tensor)`` pairs that are not q/k/v shards but do match
            a parameter of this module. The second holds the fused parameter
            names that received a shard. Names that match nothing are dropped.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]
        params_dict = dict(self.named_parameters())
        stacked_loaded = []
        passthrough = []

        for weight_name, tensor in weights:
            handled = False
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in weight_name:
                    continue
                # The rewritten name is kept even if it is not a known
                # parameter; later mappings and the fallback see it as-is.
                weight_name = weight_name.replace(shard_name, fused_name)
                if weight_name not in params_dict:
                    continue
                target = params_dict[weight_name]
                target.weight_loader(target, tensor, shard_id)
                stacked_loaded.append(weight_name)
                handled = True
                break
            if not handled and weight_name in params_dict:
                passthrough.append((weight_name, tensor))

        return passthrough, stacked_loaded

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        remaining, stacked_names = self._load_weights(weights)
        # This module has no pooler; skip any pooler.* checkpoint entries.
        loaded = AutoWeightsLoader(self, skip_prefixes=["pooler."]).load_weights(
            remaining
        )
        loaded.update(stacked_names)
        return loaded


class BertPoolingModel(BertModel):
    """BertModel that also owns BERT's CLS dense + tanh pooler (``self.pooler``)."""

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            embedding_class=embedding_class,
        )
        self.pooler = BertPooler(vllm_config.model_config.hf_config)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        remaining, stacked_names = self._load_weights(weights)
        # No skip_prefixes: this subclass owns pooler.* parameters.
        loaded = AutoWeightsLoader(self).load_weights(remaining)
        loaded.update(stacked_names)
        return loaded
468
469


470
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        pooler: A DispatchPooler serving the "token_embed" and "embed" tasks.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # Build the model before the pooler: subclasses such as
        # BertSpladeSparseEmbeddingModel read self.model in _build_pooler.
        self.model = self._build_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.pooler = self._build_pooler(pooler_config)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights whose names may or may not start with "model.".

        If no name carries the ``model.`` prefix, every name is prefixed with
        it so it resolves under ``self.model``. Names under ``lm_head.`` are
        skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        weights_list = list(weights)

        has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)
        # Bind `mapper` on both paths; None means names are used unchanged.
        # (It used to be bound only when remapping was needed, so checkpoints
        # already prefixed with "model." raised UnboundLocalError.)
        mapper = (
            None
            if has_model_prefix
            else WeightsMapper(orig_to_new_prefix={"": "model."})
        )

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        return BertModel(
            vllm_config=vllm_config, prefix=prefix, embedding_class=BertEmbedding
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        return DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
        )
534
535


536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
# Here we encode the token type ids together with the input ids.
# Since we use int32 for the input ids and the vocabulary size is far
# below 2**31, the high bits are free to carry additional information.
# For cross-encoder use cases the token type ids are only 0 or 1, which
# needs a single bit, so we store it at bit index TOKEN_TYPE_SHIFT (30),
# i.e. the 31st bit. We avoid the 32nd (sign) bit because setting it would
# produce a negative number, and negative values could be used to signal
# other things.
#
# The reason for all of this is that all the tensors passed as input to
# the forward function of a module marked with @support_torch_compile
# have to be persistent. So, to avoid adding more persistent tensors to
# the model runner, we encode more information in the same persistent
# tensor.
#
# Since the *ForClassification modules are outside of the compiled
# BertModel, we do the encoding there and then separate the information
# again in the Embedding layer. Because bit masks need only torch
# operations and no branching, this works with torch.compile.

TOKEN_TYPE_SHIFT = 30


560
561
562
def _encode_token_type_ids(
    input_ids: torch.Tensor, token_type_ids: torch.Tensor
) -> None:
    """Pack ``token_type_ids`` into the high bits of ``input_ids`` in place.

    ``input_ids`` may be right-padded beyond ``token_type_ids``. Only its first
    ``token_type_ids.shape[0]`` entries are modified.
    """
    n_typed = token_type_ids.shape[0]
    shifted_types = token_type_ids << TOKEN_TYPE_SHIFT
    input_ids[:n_typed].bitwise_or_(shifted_types)
565
566
567


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Split packed token types back out of ``input_ids``.

    Returns the token type ids (the bit at ``TOKEN_TYPE_SHIFT`` shifted down).
    That bit is also cleared from ``input_ids`` in place, leaving plain token
    ids.
    """
    ones = torch.ones_like(input_ids, dtype=torch.int32, device=input_ids.device)
    type_bit = ones << TOKEN_TYPE_SHIFT

    # Extract before clearing: the in-place AND below destroys the type bit.
    token_type_ids = input_ids.bitwise_and(type_bit) >> TOKEN_TYPE_SHIFT
    input_ids.bitwise_and_(type_bit.bitwise_not())

    return token_type_ids


581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
class BertMLMHead(nn.Module):
    """BERT masked-language-model head: dense -> GELU -> LayerNorm -> decoder.

    Maps hidden states of width ``hidden_size`` to ``vocab_size`` logits.
    """

    def __init__(
        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
    ):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)

    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
        """Share ``embeddings_weight`` as the decoder weight.

        ``nn.Module`` only accepts an ``nn.Parameter`` (or None) for a
        registered parameter slot, so assigning a plain tensor raised
        TypeError. A plain tensor is now wrapped in a Parameter that shares
        its storage. An existing Parameter is used as-is, so the tie is by
        identity.
        """
        if not isinstance(embeddings_weight, nn.Parameter):
            embeddings_weight = nn.Parameter(
                embeddings_weight, requires_grad=embeddings_weight.requires_grad
            )
        self.decoder.weight = embeddings_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.dense(hidden_states)
        x = self.activation(x)
        x = self.layer_norm(x)
        logits = self.decoder(x)
        return logits


class SPLADESparsePooler(Pooler):
    """SPLADE sparse pooling over per-token MLM logits.

    For each request in the flattened batch:

        logits = mlm_head(hidden_states)
        scores = log1p(relu(logits))
        pooled = max (or sum) of scores over the kept tokens  -> [V]

    Requests are separated using ``pooling_metadata.prompt_lens``. When
    ``remove_cls_sep`` is set and prompt token ids are available, a leading
    ``cls_token_id`` and a trailing ``sep_token_id`` are excluded before
    pooling. A request left with no tokens pools to a zero vector.
    """

    def __init__(
        self,
        mlm_head: nn.Module,
        cls_token_id: int | None = 101,
        sep_token_id: int | None = 102,
        pooling: str = "max",
        remove_cls_sep: bool = True,
    ):
        super().__init__()
        assert pooling in ("max", "sum")
        self.mlm_head = mlm_head
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pooling = pooling
        self.remove_cls_sep = remove_cls_sep

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token ids are needed to detect the CLS/SEP tokens.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor:
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2

        seq_lens: list[int] = pooling_metadata.prompt_lens.tolist()
        token_ids = pooling_metadata.prompt_token_ids
        strip_special = self.remove_cls_sep and token_ids is not None

        pooled: list[torch.Tensor] = []
        row = 0
        for req_idx, raw_len in enumerate(seq_lens):
            seq_len = int(raw_len)
            req_states = hidden_states[row : row + seq_len]
            row += seq_len

            lo, hi = 0, seq_len
            if strip_special:
                first_tok = token_ids[req_idx, 0].item()
                if self.cls_token_id is not None and first_tok == self.cls_token_id:
                    lo = 1
                if self.sep_token_id is not None:
                    last_tok = token_ids[req_idx, seq_len - 1].item()
                    if last_tok == self.sep_token_id:
                        hi = max(lo, seq_len - 1)

            if hi <= lo:
                vocab = int(self.mlm_head.decoder.out_features)
                pooled.append(req_states.new_zeros((vocab,)))
                continue

            scores = torch.log1p(torch.relu(self.mlm_head(req_states[lo:hi])))
            if self.pooling == "sum":
                reduced = scores.sum(dim=0)
            else:  # "max"
                reduced = scores.max(dim=0).values
            pooled.append(reduced.contiguous())

        return torch.stack(pooled, dim=0).contiguous()


@default_pooling_type("CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
    """BertEmbeddingModel with SPLADE sparse embeddings for the "embed" task.

    - ``self.mlm_head`` (BertMLMHead) turns hidden states into vocab logits.
    - ``self.pooler`` routes "embed" to SPLADESparsePooler over that head and
      "token_embed" to the standard token-embedding pooler.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
    ):
        # BertEmbeddingModel.__init__ calls the overridden _build_pooler, which
        # creates self.mlm_head. _splade_pooling is not set yet at that point,
        # so that first pooler uses the "max" default.
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Record the requested mode and rebuild the pooler around the existing
        # head. The old code allocated a second vocab-sized MLM head here and
        # discarded the first.
        self._splade_pooling = splade_pooling
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = self._build_pooler(pooler_config)

    def _get_or_create_mlm_head(self) -> BertMLMHead:
        """Return ``self.mlm_head``, building it from the HF config if absent."""
        if not hasattr(self, "mlm_head"):
            cfg = self.model.config
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )
        return self.mlm_head

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        cfg = self.model.config
        mlm_head = self._get_or_create_mlm_head()

        # Unset while the base constructor is running; fall back to "max".
        pooling_mode = getattr(self, "_splade_pooling", "max")

        return DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": SPLADESparsePooler(
                    mlm_head=mlm_head,
                    cls_token_id=getattr(cfg, "cls_token_id", None),
                    sep_token_id=getattr(cfg, "sep_token_id", None),
                    pooling=pooling_mode,  # "max" or "sum"
                    remove_cls_sep=True,
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load encoder weights into ``self.model`` and head weights into
        ``self.mlm_head``.

        Leading ``model.`` and ``bert.`` prefixes are stripped. Names starting
        with ``cls.predictions.`` go to the MLM head; those without a mapping
        are ignored. All other names go to ``BertModel.load_weights``.

        Returns:
            Fully-qualified names of the parameters that were loaded.
        """
        self._get_or_create_mlm_head()

        def _strip(name: str) -> str:
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        model_side: list[tuple[str, torch.Tensor]] = []
        mlm_side: list[tuple[str, torch.Tensor]] = []
        for k, w in weights:
            name = _strip(k)
            if name.startswith("cls.predictions."):
                mlm_side.append((name, w))
            else:
                model_side.append((name, w))

        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        if mlm_side:
            name_map = {
                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
                "cls.predictions.transform.LayerNorm.weight": (
                    "mlm_head.layer_norm.weight"
                ),
                "cls.predictions.transform.LayerNorm.bias": (
                    "mlm_head.layer_norm.bias"
                ),
                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
                # HF BertLMPredictionHead stores the decoder bias as a separate
                # `cls.predictions.bias` parameter tied to decoder.bias.
                # Checkpoints that dropped the tied duplicate only carry this
                # name.
                "cls.predictions.bias": "mlm_head.decoder.bias",
            }
            remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
            if remapped:
                loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
                loaded.update(loaded_mlm)

        return loaded


795
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
    """BERT sequence classifier / cross-encoder.

    Wraps a BertPoolingModel (encoder plus CLS dense + tanh pooler) and a
    linear classifier.

    Attributes:
        bert: BertPoolingModel producing hidden states and holding the CLS
            pooler.
        classifier: Linear layer mapping hidden size to ``num_labels``.
        pooler: DispatchPooler serving the "token_classify", "classify" and
            "score" tasks.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config

        self.num_labels = config.num_labels

        self.bert = BertPoolingModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
        )
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=model_config.head_dtype,
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        def cls_pooler(act_fn: str) -> ClassifierPooler:
            # CLS pooling via bert.pooler, then the shared classifier.
            return ClassifierPooler(
                pooling=self.bert.pooler,
                classifier=self.classifier,
                act_fn=act_fn,
            )

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config, classifier=self.classifier
                ),
                "classify": cls_pooler("classify"),
                "score": cls_pooler("score"),
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.bert.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        return AutoWeightsLoader(self).load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if token_type_ids is not None:
            # Token types ride in the high bits of input_ids (decoded again by
            # the embedding layer), so real ids must stay below the shift.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        return self.bert(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
871
872


873
@attn_type("encoder_only")
@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    """BERT with a per-token linear classifier.

    ``forward`` returns classifier logits for every token, computed in
    ``head_dtype``. No classifier is passed to the "token_classify" pooler,
    because ``forward`` already applies it.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        self.head_dtype = model_config.head_dtype
        self.num_labels = config.num_labels

        self.bert = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        token_pooler = Pooler.for_token_classify(pooler_config=pooler_config)
        self.pooler = DispatchPooler({"token_classify": token_pooler})

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.bert.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        return AutoWeightsLoader(self).load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if token_type_ids is not None:
            # Pack token types into input_ids' high bits; see TOKEN_TYPE_SHIFT.
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        encoded = self.bert(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        return self.classifier(encoded.to(self.head_dtype))