bert.py 30.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Set
5
6
7
8
9

import torch
from torch import nn
from transformers import BertConfig

10
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.model_executor.layers.activation import get_act_fn
15
16
17
18
19
20
21
22
23
24
25
26
27
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    DispatchPooler,
    Pooler,
    PoolingMethod,
    PoolingParamsUpdate,
    PoolingType,
)
28
from vllm.model_executor.layers.quantization import QuantizationConfig
29
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
30
from vllm.sequence import IntermediateTensors
31
from vllm.tasks import PoolingTask
32
from vllm.v1.pool.metadata import PoolingMetadata
33

34
from .interfaces import SupportsCrossEncoding, SupportsQuant
35
from .interfaces_base import attn_type, default_pooling_type
36
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
37

38
39
40
41
42

class BertEmbedding(nn.Module):
    """BERT input embeddings: word + token-type + absolute position, then LayerNorm.

    Token type ids are not a separate forward argument; they are unpacked
    from (and cleared out of) ``input_ids`` by ``_decode_token_type_ids``.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # NOTE(review): forward() never reads this buffer (it uses the
        # caller-supplied position_ids); presumably kept for parity with the
        # HF module layout — confirm before removing.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed a flat batch of tokens.

        Args:
            input_ids: Token ids, possibly carrying a packed token-type bit.
                Modified in place: the token-type bit is cleared.
            position_ids: Absolute positions, one per token.
            inputs_embeds: Optional precomputed word embeddings; when given,
                the word-embedding lookup is skipped.

        Returns:
            LayerNorm(word + token_type + position) embeddings.
        """
        # Must run before the word-embedding lookup so the packed bit is
        # stripped from input_ids first.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        return self.LayerNorm(embeddings)


86
class BertPooler(Pooler):
    """CLS-token pooling followed by BERT's dense + tanh pooler head."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying CLS pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying CLS pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        """Apply the dense projection and tanh to pooled vectors."""
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool the CLS token and pass it through the pooler head.

        Returns a tensor or a list of tensors, mirroring what the CLS pooling
        method returned.
        """
        pooled_output = self.pooling(hidden_states, pooling_metadata)

        if isinstance(pooled_output, list):
            return [self._head(output) for output in pooled_output]
        return self._head(pooled_output)


120
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer modules applied in order."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.layer = nn.ModuleList(
            [
                BertLayer(
                    config=config,
                    cache_config=cache_config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layer.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run hidden states through every layer sequentially."""
        for layer in self.layer:
            hidden_states = layer(hidden_states)
        return hidden_states


class BertLayer(nn.Module):
    """One BERT transformer layer: self-attention block, then the FFN block."""

    def __init__(
        self,
        config: BertConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=config.layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )

        self.intermediate = BertIntermediate(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )

        self.output = BertOutput(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_norm_eps=config.layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor):
        """Attention sub-block, then up-projection, then down-projection with residual."""
        attn_output = self.attention(hidden_states)
        intermediate_output = self.intermediate(attn_output)
        # BertOutput adds attn_output back as the residual before LayerNorm.
        output = self.output(intermediate_output, attn_output)
        return output


class BertAttention(nn.Module):
    """Self-attention followed by output projection, residual add and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # The prefix must match the attribute name ("self") so that the
        # qkv_proj / attn prefixes reflect the real module path; previously
        # this passed f"{prefix}.output", colliding with BertSelfOutput's path.
        self.self = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self",
        )

        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend, then project and add ``hidden_states`` as the residual."""
        self_output = self.self(hidden_states)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Multi-head self-attention with a fused tensor-parallel QKV projection.

    Heads are split evenly across tensor-parallel ranks (asserted), and the
    attention computation is delegated to ``EncoderOnlyAttention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        # Plain multi-head attention: as many KV heads as query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = EncoderOnlyAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v along the last dim and run encoder-only attention."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        output = self.attn(q, k, v)
        return output


class BertSelfOutput(nn.Module):
    """Row-parallel output projection of attention, residual add, LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return LayerNorm(dense(hidden_states) + input_tensor)."""
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertIntermediate(nn.Module):
    """Column-parallel FFN up-projection followed by the configured activation."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return act(dense(hidden_states))."""
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    """Row-parallel FFN down-projection, residual add and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return LayerNorm(dense(hidden_states) + input_tensor)."""
        hidden_states, _ = self.dense(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


360
@support_torch_compile
@default_pooling_type("CLS")
class BertModel(nn.Module, SupportsQuant):
    """BERT backbone: embeddings followed by the encoder stack.

    ``forward`` returns the last layer's per-token hidden states; no pooling
    head is applied here.
    """

    is_pooling_model = True

    # Checkpoints store separate query/key/value weights; they are loaded
    # into the fused qkv_proj parameter (see _load_weights).
    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(self.config)
        self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return word embeddings only (no position or token-type terms)."""
        return self.embeddings.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed the tokens and run them through the encoder.

        ``intermediate_tensors`` is accepted for interface compatibility and
        is not used.
        """
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
        )
        return self.encoder(hidden_states)

    def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load query/key/value shards into fused qkv_proj parameters.

        Returns:
            ``(other_weights, loaded_stacked_params)``: the not-yet-loaded
            weights whose names exist in this module's parameters, and the
            names of the fused parameters that received a shard. Weights
            whose names match no parameter are dropped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        loaded_stacked_params = []
        other_weights = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_stacked_params.append(name)
                break
            else:
                # Not a q/k/v shard: defer to AutoWeightsLoader if known.
                if name in params_dict:
                    other_weights.append((name, loaded_weight))

        return other_weights, loaded_stacked_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights; checkpoint ``pooler.*`` entries are skipped."""
        other_weights, loaded_stacked_params = self._load_weights(weights)

        loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
        loaded_params = loader.load_weights(other_weights)
        loaded_params.update(loaded_stacked_params)
        return loaded_params


class BertPoolingModel(BertModel):
    """BertModel plus the BERT CLS pooler head (``self.pooler``)."""

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        embedding_class: type[nn.Module] = BertEmbedding,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            embedding_class=embedding_class,
        )

        config = vllm_config.model_config.hf_config
        self.pooler = BertPooler(config)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load backbone and pooler weights (``pooler.*`` is not skipped here)."""
        other_weights, loaded_stacked_params = self._load_weights(weights)

        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(other_weights)
        loaded_params.update(loaded_stacked_params)
        return loaded_params
463
464


465
@default_pooling_type("CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: The BertModel (built by ``_build_model``) used for forward.
        pooler: The pooler built by ``_build_pooler``; by default a
            DispatchPooler serving the "token_embed" and "embed" tasks.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.model = self._build_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.pooler = self._build_pooler(pooler_config)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return word embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the backbone's per-token hidden states."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, adding the ``model.`` prefix when the checkpoint lacks it.

        ``lm_head.*`` weights are skipped.
        """
        weights_list = list(weights)

        has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)

        # Previously `mapper` was only bound in the no-prefix branch, so a
        # checkpoint already using "model.*" names raised UnboundLocalError.
        mapper = None
        if not has_model_prefix:
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        """Construct the backbone; subclasses may override."""
        return BertModel(
            vllm_config=vllm_config, prefix=prefix, embedding_class=BertEmbedding
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Construct the task-dispatching pooler; subclasses may override."""
        return DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
        )
529
530


531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
# Here we encode the token type ids together with the input ids.
# Since we use int32 for the input IDs and the vocabulary size is
# far below 2**30 (the classification forwards assert
# vocab_size < 1 << TOKEN_TYPE_SHIFT), there is room to encode
# additional bits. At the same time, for cross-encoder use cases,
# the token type ids are only 0 or 1, requiring only 1 bit.
# This means that we can store the token type ids in bit 30 (the
# 31st bit). We avoid bit 31 (the 32nd, sign bit) because setting it
# would produce a negative number, which could be used to signal
# other things.
#
# The reason for all of this is that all the tensors that are
# passed as input to the forward function of a module marked
# with @support_torch_compile have to be persistent. So to
# avoid adding more persistent tensors in the model runner, we
# encode more information in the same persistent tensor.
#
# Since the *ForSequenceClassification / *ForTokenClassification
# modules are outside of the compiled BertModel, we can do the
# encoding there and then separate the information again in the
# embedding layer. Since with bit masks we can do this entirely with
# torch operations and without branching, it works with torch.compile.

# Bit index inside each input id that carries the token type id.
TOKEN_TYPE_SHIFT = 30


555
556
557
def _encode_token_type_ids(
    input_ids: torch.Tensor, token_type_ids: torch.Tensor
) -> None:
    """Pack ``token_type_ids`` into bit ``TOKEN_TYPE_SHIFT`` of ``input_ids`` in place.

    ``input_ids`` may be right-padded and therefore longer than
    ``token_type_ids``; only its first ``token_type_ids.shape[0]`` entries
    are modified. Nothing is returned.
    """
    input_ids[: token_type_ids.shape[0]].bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT)
560
561
562


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Extract the token-type bit from ``input_ids`` and clear it in place.

    Returns:
        A tensor shaped like ``input_ids`` holding 0/1 token type ids.
        After the call ``input_ids`` contains only the plain token ids.
    """
    # Single-bit mask at TOKEN_TYPE_SHIFT, built with tensor ops only so the
    # function stays branch-free under torch.compile.
    ids_mask = (
        torch.ones_like(input_ids, dtype=torch.int32, device=input_ids.device)
        << TOKEN_TYPE_SHIFT
    )
    tokens_mask = ids_mask.bitwise_not()

    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT

    input_ids.bitwise_and_(tokens_mask)

    return token_type_ids


576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
class BertMLMHead(nn.Module):
    """Masked-LM prediction head: dense -> GELU -> LayerNorm -> vocab logits.

    Args:
        hidden_size: Width of the incoming hidden states.
        vocab_size: Number of logits produced per token.
        layer_norm_eps: Epsilon passed to the LayerNorm.
    """

    def __init__(
        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
    ):
        super().__init__()
        # Submodule registration order defines the state_dict key order.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)

    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
        """Make the decoder share ``embeddings_weight`` as its weight.

        nn.Module attribute assignment only accepts an nn.Parameter (or None)
        for an existing parameter slot, so callers should pass a Parameter.
        """
        self.decoder.weight = embeddings_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map hidden states to per-token vocabulary logits."""
        transformed = self.layer_norm(self.activation(self.dense(hidden_states)))
        return self.decoder(transformed)


class SPLADESparsePooler(Pooler):
    """
    SPLADE sparse pooling:
    logits = mlm_head(hidden_states)
            -> log1p(relu(logits))
            -> (max|sum over L)
            -> [V]

    ``hidden_states`` is a flat (total_tokens, hidden) tensor; each request's
    span is sliced out using ``pooling_metadata.prompt_lens``. When
    ``remove_cls_sep`` is set and prompt token ids are available, a leading
    token equal to ``cls_token_id`` and a trailing token equal to
    ``sep_token_id`` are excluded before pooling. A request with no tokens
    left pools to an all-zero vector of size V.
    """

    def __init__(
        self,
        mlm_head: nn.Module,
        cls_token_id: int | None = 101,
        sep_token_id: int | None = 102,
        pooling: str = "max",
        remove_cls_sep: bool = True,
    ):
        super().__init__()
        assert pooling in ("max", "sum")
        self.mlm_head = mlm_head
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pooling = pooling
        self.remove_cls_sep = remove_cls_sep

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Token ids are needed to detect the [CLS]/[SEP] positions.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor:
        """Return a (num_requests, V) tensor of pooled SPLADE scores."""
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2

        lens: list[int] = pooling_metadata.prompt_lens.tolist()
        if not lens:
            # torch.stack([]) would raise; return an empty (0, V) batch.
            vocab_size = int(self.mlm_head.decoder.out_features)
            return hidden_states.new_zeros((0, vocab_size))

        token_ids = pooling_metadata.prompt_token_ids
        offset = 0
        pooled_list: list[torch.Tensor] = []

        for i, seq_len in enumerate(lens):
            L = int(seq_len)
            hs = hidden_states[offset : offset + L]
            offset += L

            start_idx = 0
            end_idx = L
            if self.remove_cls_sep and token_ids is not None:
                if (
                    self.cls_token_id is not None
                    and token_ids[i, 0].item() == self.cls_token_id
                ):
                    start_idx = 1
                if (
                    self.sep_token_id is not None
                    and token_ids[i, L - 1].item() == self.sep_token_id
                ):
                    end_idx = max(start_idx, L - 1)

            if end_idx <= start_idx:
                V = int(self.mlm_head.decoder.out_features)
                pooled_list.append(hs.new_zeros((V,)))
                continue

            logits_i = self.mlm_head(hs[start_idx:end_idx])
            scores_i = torch.log1p(torch.relu(logits_i))

            if self.pooling == "sum":
                pooled_i = scores_i.sum(dim=0)
            else:  # "max"
                pooled_i = scores_i.max(dim=0).values

            pooled_list.append(pooled_i.contiguous())

        return torch.stack(pooled_list, dim=0).contiguous()


@default_pooling_type("CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
    """
    BertEmbeddingModel + SPLADE sparse embedding.
    - Make logits by self.mlm_head
    - pooler: SPLADESparsePooler(mlm_head...)
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
    ):
        # BertEmbeddingModel.__init__ calls self._build_pooler(), which
        # creates self.mlm_head. Reuse that head instead of allocating a
        # second (vocab_size x hidden_size) one here.
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        if not hasattr(self, "mlm_head"):
            cfg = vllm_config.model_config.hf_config
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        # The first pooler was built with the default "max" mode (this
        # attribute did not exist yet); rebuild it with the requested mode.
        self._splade_pooling = splade_pooling
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = self._build_pooler(pooler_config)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Build the pooler; creates ``self.mlm_head`` on first use."""
        cfg = self.model.config

        if not hasattr(self, "mlm_head"):
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        pooling_mode = getattr(self, "_splade_pooling", "max")

        # NOTE(review): CLS/SEP stripping is only active when the HF config
        # defines these attributes; otherwise they are None.
        cls_id = getattr(cfg, "cls_token_id", None)
        sep_id = getattr(cfg, "sep_token_id", None)

        return DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": SPLADESparsePooler(
                    mlm_head=self.mlm_head,
                    cls_token_id=cls_id,
                    sep_token_id=sep_id,
                    pooling=pooling_mode,  # "max" or "sum"
                    remove_cls_sep=True,
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone weights into ``self.model`` and MLM weights into ``mlm_head``.

        Leading ``model.`` / ``bert.`` prefixes are stripped. ``cls.predictions.*``
        entries go to the MLM head via a fixed name map; everything else goes
        to the backbone.
        """
        if not hasattr(self, "mlm_head"):
            cfg = self.model.config
            self.mlm_head = BertMLMHead(
                hidden_size=cfg.hidden_size,
                vocab_size=cfg.vocab_size,
                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
            )

        def _strip(name: str) -> str:
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        model_side: list[tuple[str, torch.Tensor]] = []
        mlm_side: list[tuple[str, torch.Tensor]] = []

        for k, w in weights:
            name = _strip(k)
            if name.startswith("cls.predictions."):
                mlm_side.append((name, w))
            else:
                model_side.append((name, w))

        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        if mlm_side:
            name_map = {
                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
                ("cls.predictions.transform.LayerNorm.weight"): (
                    "mlm_head.layer_norm.weight"
                ),
                ("cls.predictions.transform.LayerNorm.bias"): (
                    "mlm_head.layer_norm.bias"
                ),
                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
                # HF's BERT LM head keeps the decoder bias as the separate
                # parameter `cls.predictions.bias` (tied to decoder.bias);
                # checkpoints that drop tied keys only contain this one.
                "cls.predictions.bias": "mlm_head.decoder.bias",
            }
            # Dict keyed by target so a bias present under both names is
            # loaded once.
            remapped: dict[str, torch.Tensor] = {}
            for n, w in mlm_side:
                target = name_map.get(n)
                if target is not None:
                    remapped[target] = w
            # NOTE(review): if the checkpoint omits cls.predictions.decoder.weight
            # (tied to the word embeddings in HF), the decoder weight is left
            # unloaded here — confirm SPLADE checkpoints include it.
            if remapped:
                loaded_mlm = AutoWeightsLoader(self).load_weights(
                    list(remapped.items())
                )
                loaded.update(loaded_mlm)

        return loaded


790
@default_pooling_type("CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
    """BERT with a linear head for sequence classification and cross-encoder scoring.

    Attributes:
        bert: BertPoolingModel (backbone plus the CLS dense+tanh pooler head).
        classifier: Linear layer from hidden size to ``num_labels`` logits.
        pooler: DispatchPooler serving "token_classify" (per-token classifier),
            and "classify" / "score" (``bert.pooler`` followed by
            ``classifier``).
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.num_labels = config.num_labels
        self.bert = BertPoolingModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
        )
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=vllm_config.model_config.head_dtype,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config, classifier=self.classifier
                ),
                "classify": ClassifierPooler(
                    pooling=self.bert.pooler,
                    classifier=self.classifier,
                    act_fn="classify",
                ),
                "score": ClassifierPooler(
                    pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
                ),
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return word embeddings for ``input_ids``."""
        return self.bert.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(weights)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return backbone hidden states; pooling/classification happens in ``pooler``.

        When ``token_type_ids`` is given it is packed into ``input_ids`` in
        place (see ``_encode_token_type_ids``), because only ``input_ids`` is
        passed into the compiled backbone.
        """
        if token_type_ids is not None:
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        return self.bert(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
866
867


868
@attn_type("encoder_only")
@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    """BERT with a linear per-token classification head.

    ``forward`` returns classifier logits for every token; the pooler only
    handles the "token_classify" task.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.bert = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config=pooler_config
                ),
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return word embeddings for ``input_ids``."""
        return self.bert.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(weights)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return per-token logits computed in ``head_dtype``.

        When ``token_type_ids`` is given it is packed into ``input_ids`` in
        place before the compiled backbone runs.
        """
        if token_type_ids is not None:
            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        hidden_states = self.bert(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)