bert.py 27.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Iterable, Optional, Set, Tuple
4
5
6
7
8

import torch
from torch import nn
from transformers import BertConfig

9
from vllm.attention import Attention, AttentionType
10
from vllm.compilation.decorators import support_torch_compile
11
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
12
from vllm.distributed import get_tensor_model_parallel_world_size
13
from vllm.forward_context import get_forward_context
14
15
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                   get_act_fn)
16
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
17
                                               MergedColumnParallelLinear,
18
19
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
21
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
                                               PoolingType)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
from vllm.model_executor.layers.rotary_embedding import get_rope
24
25
26
27
28
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
29
30
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)
31

32
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
33
from .utils import WeightsMapper, maybe_prefix
34

35
36
37
38
39
40
41
42
43

class BertEmbedding(nn.Module):
    """Input embedding block of a BERT encoder.

    Sums word embeddings and token-type embeddings, adds learned absolute
    position embeddings when ``config.position_embedding_type`` is
    ``"absolute"``, and normalizes the sum with LayerNorm. With ``"rotary"``
    nothing positional is added here; in this file the rotary variants apply
    RoPE to q/k inside ``BertSelfAttention`` instead.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        pos_kind = config.position_embedding_type
        self.position_embedding_type = pos_kind
        if pos_kind == "rotary":
            # No learned position table: positions are handled in attention.
            self.position_embeddings = None
            self.position_ids = None
        elif pos_kind == "absolute":
            self.position_embeddings = VocabParallelEmbedding(
                config.max_position_embeddings, hidden)
            # Parameter of shape (1, max_position_embeddings) that forward()
            # never reads. NOTE(review): presumably exists so a checkpoint's
            # "embeddings.position_ids" entry has a load target — confirm.
            self.position_ids = nn.Parameter(
                torch.empty((1, config.max_position_embeddings)))
        else:
            raise ValueError(
                "Only 'absolute' and 'rotary' position_embedding_type is "
                "supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and return the LayerNorm-ed embedding sum.

        Args:
            input_ids: token ids to embed.
            seq_lens: accepted for interface compatibility; unused here.
            position_ids: ids looked up in ``position_embeddings`` when the
                embedding type is ``"absolute"``; ignored otherwise.
            token_type_ids: segment ids; defaults to zeros with the same
                shape as ``input_ids``.
        """
        word_embeds = self.word_embeddings(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_ids.size(),
                                         dtype=torch.long,
                                         device=word_embeds.device)
        embeddings = word_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)
        return self.LayerNorm(embeddings)


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
class BertPooler(nn.Module):
    """BERT pooling head: dense projection + tanh over the leading row."""

    def __init__(self, config: BertConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pool by projecting index 0 of ``hidden_states``'s leading dim.

        NOTE(review): index 0 is presumably the first ([CLS]) token of the
        sequence the caller passes in — confirm against the caller.
        """
        leading_row = hidden_states[0, :]
        return self.activation(self.dense(leading_row))


108
@support_torch_compile
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer modules run in order.

    The forward signature (names and ``torch.Tensor`` annotations) is kept
    as-is because it is wrapped by ``support_torch_compile``.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 bias: bool = True,
                 rotary_kwargs: Optional[dict] = None,
                 prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        def make_layer(idx):
            # Each layer's prefix mirrors its module path "<prefix>.layer.N".
            return BertLayer(config=hf_config,
                             cache_config=vllm_config.cache_config,
                             quant_config=vllm_config.quant_config,
                             bias=bias,
                             rotary_kwargs=rotary_kwargs,
                             prefix=f"{prefix}.layer.{idx}")

        self.layer = nn.ModuleList(
            [make_layer(idx) for idx in range(hf_config.num_hidden_layers)])

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer, passing ``positions``."""
        out = hidden_states
        for bert_layer in self.layer:
            out = bert_layer(positions, out)
        return out


class BertLayer(nn.Module):
    """One encoder layer: self-attention block, feed-forward, output block.

    The feed-forward uses the gated variant (``BertGatedIntermediate``) when
    ``config.hidden_act`` is ``"silu"`` or ``"gelu_and_mul"`` and the plain
    ``BertIntermediate`` otherwise.
    """

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = True,
                 rotary_kwargs: Optional[dict] = None,
                 prefix: str = ""):
        super().__init__()

        self.attention = BertAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=config.layer_norm_eps,
            cache_config=cache_config,
            quant_config=quant_config,
            bias=bias,
            rotary_kwargs=rotary_kwargs,
            prefix=f"{prefix}.attention")

        use_gated_mlp = config.hidden_act in ("silu", "gelu_and_mul")
        intermediate_cls = (BertGatedIntermediate
                            if use_gated_mlp else BertIntermediate)
        self.intermediate = intermediate_cls(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate")

        self.output = BertOutput(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 layer_norm_eps=config.layer_norm_eps,
                                 bias=bias,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.output")

    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor):
        """Attention, then MLP; the attention output is the MLP residual."""
        attn_out = self.attention(positions, hidden_states)
        return self.output(self.intermediate(attn_out), attn_out)


class BertAttention(nn.Module):
    """Attention sub-block: ``BertSelfAttention`` then ``BertSelfOutput``.

    ``BertSelfOutput`` projects the attention output and adds it to the
    block's input before LayerNorm.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = True,
        rotary_kwargs: Optional[dict] = None,
        prefix: str = "",
    ):
        super().__init__()

        # The prefix must mirror the module path ("<prefix>.self"). It is
        # forwarded to the qkv_proj/attn sub-layers, so using ".output" here
        # (as before) labelled them "attention.output.qkv_proj" etc. while the
        # parameters actually live under "attention.self.*".
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      bias=bias,
                                      rotary_kwargs=rotary_kwargs,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     bias=bias,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Self-attend over ``hidden_states``; it is also the residual."""
        self_output = self.self(positions, hidden_states)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Heads are split across the tensor-parallel group; K/V have as many heads
    as Q. When ``rotary_kwargs`` is provided, a rotary embedding built by
    ``get_rope(**rotary_kwargs)`` is applied to q and k. Attention runs with
    ``AttentionType.ENCODER_ONLY``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = True,
        rotary_kwargs: Optional[dict] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel size ({tp_size})")

        self.num_heads = self.total_num_heads // tp_size
        # Plain multi-head attention: as many KV heads as query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size, (
            f"hidden_size ({self.hidden_size}) must be divisible by "
            f"num_attention_heads ({self.total_num_heads})")

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj")

        # Only the RoPE-based variants pass rotary_kwargs.
        if rotary_kwargs:
            self.rotary_emb = get_rope(**rotary_kwargs)
        else:
            self.rotary_emb = None

        self.attn = Attention(num_heads=self.num_heads,
                              head_size=self.head_dim,
                              scale=self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, optionally apply RoPE, and run attention."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Explicit None check: truthiness of an nn.Module is not a reliable
        # "is configured" test (containers define __len__).
        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)

        return self.attn(q, k, v)


class BertSelfOutput(nn.Module):
    """Row-parallel projection of attention output, residual add, LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = RowParallelLinear(input_size=hidden_size,
                                       output_size=hidden_size,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        projected, _unused_bias = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


class BertIntermediate(nn.Module):
    """Feed-forward up-projection (column-parallel) plus activation."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = ColumnParallelLinear(input_size=hidden_size,
                                          output_size=intermediate_size,
                                          bias=bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.dense")
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act(dense(hidden_states))``."""
        expanded, _unused_bias = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)


343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
class BertGatedIntermediate(nn.Module):
    """Gated feed-forward used by the NomicBert and GTE variants.

    A merged gate/up projection (two ``intermediate_size`` halves) is
    followed by an act-and-multiply function, giving an output of width
    ``intermediate_size``.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.act_fn = get_act_and_mul_fn(hidden_act)
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act_and_mul(gate_up_proj(hidden_states))``."""
        fused, _unused_bias = self.gate_up_proj(hidden_states)
        return self.act_fn(fused)


369
370
371
372
373
374
class BertOutput(nn.Module):
    """Feed-forward down-projection, residual add, and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 bias: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.dense = RowParallelLinear(input_size=intermediate_size,
                                       output_size=hidden_size,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.dense")

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        reduced, _unused_bias = self.dense(hidden_states)
        return self.LayerNorm(reduced + input_tensor)


395
class BertModel(nn.Module, SupportsQuant):
    """BERT backbone: embeddings, a ``BertEncoder`` stack, optional pooler.

    For BertModel, all linear layers have bias (``bias=True``, the default).
    For NomicBertModel, all linear layers do not have bias (it is built with
    ``bias=False``). ``rotary_kwargs``, when given, enables rotary position
    embeddings inside self-attention.
    """

    # Checkpoint tensors that are fused into a single vLLM parameter.
    packed_modules_mapping = {
        "qkv_proj": ["query", "key", "value"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
                 bias: bool = True,
                 rotary_kwargs: Optional[dict] = None,
                 add_pooling_layer: bool = False):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   bias=bias,
                                   rotary_kwargs=rotary_kwargs,
                                   prefix=f"{prefix}.encoder")
        # Only built on request (BertForSequenceClassification asks for it).
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the inputs (unless ``inputs_embeds`` is given) and encode.

        ``intermediate_tensors`` is accepted for interface compatibility and
        is not used. When embedding from ``input_ids``, the current forward
        context's attention metadata must expose ``seq_lens_tensor``.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            attn_metadata = get_forward_context().attn_metadata
            assert hasattr(attn_metadata, "seq_lens_tensor")
            hidden_states = self.embeddings(
                input_ids=input_ids,
                seq_lens=attn_metadata.seq_lens_tensor,
                position_ids=position_ids,
                token_type_ids=token_type_ids)
        return self.encoder(position_ids, hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Separate q/k/v and gate/up tensors are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters via their
        ``weight_loader`` with the matching shard id. Pooler weights are
        dropped when the model has no pooler, and bias tensors with no
        matching parameter are skipped.

        Returns:
            Names (as in ``self.named_parameters()``) of the loaded params.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if self.pooler is None and "pooler" in name:
                continue

            # Rename on the first matching shard and stop, so the renamed
            # name is never re-matched by a later entry ("up_proj" is a
            # substring of "gate_up_proj"). shard_id may legitimately be 0,
            # hence the explicit None sentinel.
            shard_id = None
            for param_name, weight_name, sid in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = sid
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params
483
484


485
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
    """Embedding model built on ``BertModel``.

    Wraps a backbone and a pooler; subclasses customize it by overriding
    ``_build_model``, ``_build_pooler`` and ``hf_to_vllm_mapper``.

    Attributes:
        model: the BertModel used for forward passes.
        _pooler: the Pooler used to turn hidden states into embeddings.
    """
    # Checkpoints may nest the backbone under "model."; strip that prefix.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        pooler_config = model_config.pooler_config
        self.config = model_config.hf_config
        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self._pooler = self._build_pooler(pooler_config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the backbone (``positions`` -> ``position_ids``)."""
        return self.model(input_ids=input_ids,
                          position_ids=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool ``hidden_states`` with the configured pooler."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename checkpoint weights, drop ``lm_head.*``, load the rest."""
        renamed = self.hf_to_vllm_mapper.apply(weights)
        self.model.load_weights((key, tensor) for key, tensor in renamed
                                if not key.startswith("lm_head."))

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Construct the backbone; overridden by model-specific subclasses."""
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=BertEmbedding)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Default pooling: CLS token, normalized, no softmax."""
        return Pooler.from_config_with_defaults(pooler_config,
                                                pooling_type=PoolingType.CLS,
                                                normalize=True,
                                                softmax=False)
542
543


544
545
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                    SupportsQuant):
    """BERT sequence-classification / cross-encoder model.

    Runs a ``BertModel`` built with a ``BertPooler`` and scores its output
    with a linear ``classifier`` head through ``CrossEncodingPooler``.

    Attributes:
        bert: the BertModel backbone (built with a pooling layer).
        classifier: linear layer from hidden_size to num_labels.
        _pooler: CrossEncodingPooler combining the BERT pooler and classifier.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.default_activation_function = \
            get_cross_encoder_activation_function(hf_config)

        self.num_labels = hf_config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding,
                              add_pooling_layer=True)
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels)
        self._pooler = CrossEncodingPooler(hf_config, self.classifier,
                                           self.bert.pooler)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``bert.*`` into the backbone and ``classifier*`` locally.

        Non-``bert.`` weights are buffered while the backbone consumes the
        stream; of those, only names starting with ``classifier`` are loaded
        and the rest are ignored.
        """
        bert_prefix = "bert."
        deferred = []

        def backbone_stream():
            for key, tensor in weights:
                if not key.startswith(bert_prefix):
                    deferred.append((key, tensor))
                    continue
                yield (key[len(bert_prefix):], tensor)

        self.bert.load_weights(backbone_stream())

        own_params = dict(self.named_parameters())
        for key, tensor in deferred:
            if not key.startswith("classifier"):
                continue
            target = own_params[key]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score ``hidden_states`` with the cross-encoding pooler."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the backbone, forwarding ``token_type_ids``."""
        return self.bert(input_ids=input_ids,
                         position_ids=positions,
                         inputs_embeds=inputs_embeds,
                         intermediate_tensors=intermediate_tensors,
                         token_type_ids=token_type_ids)
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725


class NomicBertEmbeddingModel(BertEmbeddingModel):
    """Embedding model for NomicBert (``NomicBertConfig``) checkpoints.

    Rewrites the NomicBert config in place into the BERT field names that
    ``BertModel`` reads, then builds a bias-free BertModel with rotary
    position embeddings (dynamic RoPE scaling) and a SiLU-gated MLP.
    """

    # Substring renames from NomicBert checkpoint names to BertModel names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "attn.Wqkv": "attention.self.qkv_proj",
            "attn.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc11": "intermediate.up_proj",
            "mlp.fc12": "intermediate.gate_proj",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        })

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Adapt the NomicBert config and build the matching BertModel."""
        cfg = vllm_config.model_config.hf_config

        assert cfg.__class__.__name__ == "NomicBertConfig"
        assert cfg.activation_function == "swiglu"

        # The model is built with bias=False below, so the checkpoint must
        # not carry biases on these layers.
        assert not cfg.mlp_fc1_bias
        assert not cfg.mlp_fc2_bias
        assert not cfg.qkv_proj_bias

        # Expose NomicBert settings under the BERT attribute names.
        cfg.layer_norm_eps = cfg.layer_norm_epsilon
        cfg.position_embedding_type = "rotary"
        cfg.intermediate_size = cfg.n_inner
        cfg.hidden_act = "silu"
        cfg.hidden_size = cfg.n_embd
        cfg.num_hidden_layers = cfg.n_layer

        per_head = cfg.hidden_size // cfg.num_attention_heads
        rotary_kwargs = dict(
            head_size=per_head,
            rotary_dim=getattr(cfg, "rotary_emb_dim", per_head),
            max_position=cfg.max_trained_positions,
            base=cfg.rotary_emb_base,
            rope_scaling=dict(rope_type="dynamic",
                              factor=cfg.rotary_scaling_factor),
        )

        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         bias=False,
                         rotary_kwargs=rotary_kwargs,
                         embedding_class=BertEmbedding)


class GteEmbeddingModel(BertEmbeddingModel):
    """Embedding model for GTE (``GteConfig``) checkpoints.

    Builds a BertModel with rotary position embeddings and a gated GELU MLP
    whose ``gate_up_proj`` has no bias. The checkpoint's fused
    ``mlp.up_gate_proj`` tensor is split into up/gate halves before loading.
    """
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "attention.qkv_proj": "attention.self.qkv_proj",
            "attention.o_proj": "attention.output.dense",
            'attn_ln': "attention.output.LayerNorm",
            'mlp.down_proj': "output.dense",
            'mlp_ln': "output.LayerNorm",
        })

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Adapt the GTE config in place and build the matching BertModel."""
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.position_embedding_type == "rope"
        assert config.hidden_act == "gelu"

        # Translate to the values BertEmbedding / BertLayer understand.
        config.position_embedding_type = "rotary"
        config.hidden_act = "gelu_and_mul"

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
        }

        model = BertModel(vllm_config=vllm_config,
                          prefix=prefix,
                          rotary_kwargs=rotary_kwargs,
                          embedding_class=BertEmbedding)

        # GteModel only gate_up_proj does not have bias.
        # Hack method learned from vllm/model_executor/models/glm.py: both
        # attributes belong on the linear layer itself. (Previously
        # skip_bias_add was set on the BertGatedIntermediate wrapper, which
        # never reads it.)
        for layer in model.encoder.layer:
            gate_up_proj = layer.intermediate.gate_up_proj
            gate_up_proj.bias = None
            gate_up_proj.skip_bias_add = True
        return model

    def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Split fused ``mlp.up_gate_proj`` tensors into up and gate parts.

        The tensor is chunked in two along dim 0: the first half becomes
        ``intermediate.up_proj`` and the second ``intermediate.gate_proj``.
        Other weights pass through unchanged.
        """
        n = "mlp.up_gate_proj"
        for name, weight in weights:
            if n in name:
                up, gate = weight.chunk(2, dim=0)
                yield name.replace(n, "intermediate.up_proj"), up
                yield name.replace(n, "intermediate.gate_proj"), gate
            else:
                yield name, weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename, split the fused MLP weight, and load into the backbone."""
        weights = self.hf_to_vllm_mapper.apply(weights)
        weights = self.split_up_gate_proj(weights)
        self.model.load_weights(weights)