bert.py 19.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
from collections.abc import Iterable
from typing import Optional
5
6
7
8
9

import torch
from torch import nn
from transformers import BertConfig

10
from vllm.attention import Attention, AttentionType
11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
13
from vllm.distributed import get_tensor_model_parallel_world_size
14
from vllm.forward_context import get_forward_context
15
from vllm.model_executor.layers.activation import get_act_fn
16
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
19
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
20
                                               PoolingType)
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
23
24
25
26
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
27
28
from vllm.transformers_utils.config import (
    get_cross_encoder_activation_function)
29

30
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
31
from .utils import WeightsMapper, maybe_prefix
32

33
34
35
36
37
38
39
40
41

class BertEmbedding(nn.Module):
    """Input stage of BERT: word + token-type + position embeddings.

    The three embedding lookups are summed and passed through a LayerNorm.
    """

    def __init__(self, config: BertConfig):
        """Create the embedding tables and the output LayerNorm.

        Args:
            config: Source of ``hidden_size``, ``vocab_size``,
                ``max_position_embeddings``, ``type_vocab_size``,
                ``layer_norm_eps`` and ``position_embedding_type``.

        Raises:
            ValueError: If ``config.position_embedding_type`` is not
                ``"absolute"``.
        """
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden

        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, hidden)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        # Uninitialised (1, max_position_embeddings) parameter. forward()
        # below never reads it -- presumably registered so checkpoints that
        # carry a ``position_ids`` entry can be loaded; TODO confirm.
        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)))

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the tokens and normalise the result.

        Args:
            input_ids: Token ids to look up in the word table.
            seq_lens: Accepted for interface compatibility; not used here.
            position_ids: Indices into the position table.
            token_type_ids: Indices into the token-type table. When ``None``,
                an all-zero long tensor shaped like ``input_ids`` is used.

        Returns:
            LayerNorm of the sum of the three embeddings.
        """
        word_part = self.word_embeddings(input_ids)
        position_part = self.position_embeddings(position_ids)

        if token_type_ids is None:
            # Default every token to type 0, on the same device as the
            # word embeddings.
            token_type_ids = torch.zeros(input_ids.size(),
                                         dtype=torch.long,
                                         device=word_part.device)
        type_part = self.token_type_embeddings(token_type_ids)

        return self.LayerNorm(word_part + type_part + position_part)


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class BertPooler(nn.Module):
    """Dense + tanh head applied to the first row of the hidden states."""

    def __init__(self, config: BertConfig):
        """Build a square ``hidden_size`` linear layer and a tanh activation."""
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(hidden_states[0, :]))``.

        "Pooling" here just means selecting the hidden state at index 0 of
        the first dimension (the first token) and projecting it.
        """
        cls_state = hidden_states[0, :]
        return self.activation(self.dense(cls_state))


99
@support_torch_compile
class BertEncoder(nn.Module):
    """Sequential stack of ``BertLayer`` modules."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate ``num_hidden_layers`` layers from the vLLM config.

        Args:
            vllm_config: Supplies the HF model config, the cache config and
                the quantization config forwarded to every layer.
            prefix: Name prefix; layer ``i`` receives ``"{prefix}.layer.{i}"``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        # Arguments shared by every layer; only the prefix differs.
        shared_kwargs = {
            "config": hf_config,
            "cache_config": vllm_config.cache_config,
            "quant_config": vllm_config.quant_config,
        }
        self.layer = nn.ModuleList([
            BertLayer(prefix=f"{prefix}.layer.{idx}", **shared_kwargs)
            for idx in range(hf_config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through each layer in order."""
        for block in self.layer:
            hidden_states = block(hidden_states)
        return hidden_states


class BertLayer(nn.Module):
    """One encoder block: attention, intermediate projection, output."""

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the three sub-modules of the block.

        Args:
            config: Supplies hidden/intermediate sizes, head count,
                activation name and LayerNorm epsilon.
            cache_config: Forwarded to the attention sub-module.
            quant_config: Forwarded to every sub-module.
            prefix: Name prefix; sub-modules get ``.attention``,
                ``.intermediate`` and ``.output`` appended.
        """
        super().__init__()
        hidden = config.hidden_size
        ffn_width = config.intermediate_size
        eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.intermediate = BertIntermediate(
            hidden_size=hidden,
            intermediate_size=ffn_width,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )
        self.output = BertOutput(
            hidden_size=hidden,
            intermediate_size=ffn_width,
            layer_norm_eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor):
        """Run attention, then the feed-forward pair.

        The attention result is passed to ``self.output`` a second time as
        its ``input_tensor`` argument.
        """
        attended = self.attention(hidden_states)
        return self.output(self.intermediate(attended), attended)


class BertAttention(nn.Module):
    """Self-attention followed by its output projection.

    Wraps a ``BertSelfAttention`` (stored as ``self.self``) and a
    ``BertSelfOutput`` (stored as ``self.output``).
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the self-attention and output sub-modules.

        Args:
            hidden_size: Model width.
            num_attention_heads: Total number of attention heads.
            layer_norm_eps: Epsilon for the output LayerNorm.
            cache_config: Forwarded to the self-attention sub-module.
            quant_config: Forwarded to both sub-modules.
            prefix: Name prefix of this module; sub-modules receive
                ``"{prefix}.self"`` and ``"{prefix}.output"``.
        """
        super().__init__()

        # The prefix must mirror the attribute name (``self``). It was
        # previously ``.output``, which collided with the prefix handed to
        # ``self.output`` below and did not match this module's real path.
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and apply the output projection.

        ``hidden_states`` is also passed to ``self.output`` as its
        ``input_tensor`` argument.
        """
        self_output = self.self(hidden_states)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Fused QKV projection followed by encoder-only attention."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Derive per-rank head counts and build the projection/attention.

        Args:
            hidden_size: Model width; must be divisible by
                ``num_attention_heads``.
            num_attention_heads: Total head count; must be divisible by the
                tensor-parallel world size.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projection and to ``Attention``.
            prefix: Name prefix; ``.qkv_proj`` and ``.attn`` are appended.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % tp_size == 0

        # Heads are split evenly across tensor-parallel ranks.
        self.num_heads = self.total_num_heads // tp_size
        # Same number of key/value heads as query heads.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.attn = Attention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, split along the last dim, and attend."""
        fused, _ = self.qkv_proj(hidden_states)
        slice_widths = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(slice_widths, dim=-1)
        return self.attn(query, key, value)


class BertSelfOutput(nn.Module):
    """Dense projection with a residual connection and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build a biased ``hidden_size -> hidden_size`` row-parallel linear
        layer (named ``"{prefix}.dense"``) and a LayerNorm.
        """
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


class BertIntermediate(nn.Module):
    """Up-projection to the intermediate width followed by an activation."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build a biased ``hidden_size -> intermediate_size`` column-parallel
        linear layer (named ``"{prefix}.dense"``) and look up the activation
        registered under ``hidden_act``.
        """
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act(dense(hidden_states))``."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)


class BertOutput(nn.Module):
    """Down-projection with a residual connection and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build a biased ``intermediate_size -> hidden_size`` row-parallel
        linear layer (named ``"{prefix}.dense"``) and a LayerNorm.
        """
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)


317
class BertModel(nn.Module, SupportsQuant):
    """BERT backbone: embeddings, encoder stack and an optional pooler."""

    # Checkpoint "query"/"key"/"value" tensors are packed into "qkv_proj".
    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
                 add_pooling_layer: bool = False):
        """Assemble the backbone.

        Args:
            vllm_config: Supplies the HF config and is forwarded to the
                encoder.
            prefix: Name prefix; the encoder gets ``"{prefix}.encoder"``.
            embedding_class: Class instantiated with the HF config to build
                ``self.embeddings``.
            add_pooling_layer: When true, a ``BertPooler`` is attached as
                ``self.pooler``; otherwise ``self.pooler`` is ``None``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(hf_config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   prefix=f"{prefix}.encoder")
        if add_pooling_layer:
            self.pooler = BertPooler(hf_config)
        else:
            self.pooler = None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed the inputs (unless already embedded) and run the encoder.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is None.
            position_ids: Position indices for the embedding layer.
            intermediate_tensors: Accepted but not used by this method.
            inputs_embeds: Pre-computed embeddings; when given, the
                embedding layer is skipped entirely.
            token_type_ids: Optional token-type indices for the embedding
                layer.

        Returns:
            The encoder output.
        """
        if inputs_embeds is None:
            # Sequence lengths come from the attention metadata of the
            # current forward context.
            attn_metadata = get_forward_context().attn_metadata
            assert hasattr(attn_metadata, "seq_lens_tensor")
            inputs_embeds = self.embeddings(
                input_ids=input_ids,
                seq_lens=attn_metadata.seq_lens_tensor,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
            )
        return self.encoder(inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Names containing ``query``/``key``/``value`` are rewritten to
        ``qkv_proj`` and loaded as the ``q``/``k``/``v`` shard of the fused
        parameter. Names containing ``pooler`` are skipped when this model
        has no pooler.

        Args:
            weights: ``(name, tensor)`` pairs.

        Returns:
            The set of (possibly rewritten) parameter names that were
            processed.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        fused_shards = (
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        )

        params = dict(self.named_parameters())
        skip_pooler = self.pooler is None
        loaded: set[str] = set()

        for name, tensor in weights:
            if skip_pooler and "pooler" in name:
                continue
            for fused_name, shard_name, shard_id in fused_shards:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                target = params[name]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                # Reached when no fused-shard entry completed a load.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                target = params[name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, tensor)
            loaded.add(name)
        return loaded
388
389


390
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    # Strips a leading "model." from checkpoint weight names.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and the pooler from ``vllm_config``."""
        super().__init__()
        pooler_config = vllm_config.model_config.pooler_config
        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self._pooler = self._build_pooler(pooler_config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the backbone, passing ``positions`` as
        ``position_ids``.
        """
        return self.model(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the configured pooler to ``hidden_states``."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Rename weights via ``hf_to_vllm_mapper``, drop ``lm_head.*``
        entries and hand the rest to the backbone.
        """
        renamed = self.hf_to_vllm_mapper.apply(weights)
        # Lazily filter so the weights are still streamed to the backbone.
        kept = filter(lambda item: not item[0].startswith("lm_head."),
                      renamed)
        self.model.load_weights(kept)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Create the ``BertModel`` backbone with ``BertEmbedding``."""
        return BertModel(
            vllm_config=vllm_config,
            prefix=prefix,
            embedding_class=BertEmbedding,
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Create the pooler, defaulting to CLS pooling with normalization
        and without softmax.
        """
        return Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.CLS,
            normalize=True,
            softmax=False,
        )
446
447


448
449
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
                                    SupportsQuant):
    """A Bert backbone with a linear classification head.

    This class wraps a BertModel (built with its pooling layer) and a
    linear classifier, and combines them in a ClassifierPooler.

    Attributes:
        bert: An instance of BertModel used for forward operations.
        classifier: Linear layer mapping hidden_size to num_labels.
        _pooler: A ClassifierPooler built from the classifier and the
            BertModel's pooler.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the classifier head and the pooler."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.default_activation_function = (
            get_cross_encoder_activation_function(hf_config))

        self.num_labels = hf_config.num_labels
        self.bert = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=BertEmbedding,
            add_pooling_layer=True,
        )
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels)
        self._pooler = ClassifierPooler(vllm_config.model_config,
                                        self.classifier, self.bert.pooler)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone weights, then the classifier weights.

        Entries named ``bert.*`` are streamed to the backbone with that
        prefix removed. Every other entry is set aside, and of those only
        names starting with ``classifier`` are loaded; the rest are ignored.
        """
        bert_prefix = "bert."
        leftovers = []

        def backbone_weights():
            # Single pass over ``weights``: yield backbone entries and stash
            # everything else for the loop below.
            for full_name, tensor in weights:
                if full_name.startswith(bert_prefix):
                    yield (full_name.removeprefix(bert_prefix), tensor)
                else:
                    leftovers.append((full_name, tensor))

        self.bert.load_weights(backbone_weights())

        own_params = dict(self.named_parameters())
        for full_name, tensor in leftovers:
            if not full_name.startswith("classifier"):
                continue
            target = own_params[full_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Apply the classifier pooler to ``hidden_states``."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the backbone, passing ``positions`` as
        ``position_ids``.
        """
        return self.bert(
            input_ids=input_ids,
            position_ids=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
            token_type_ids=token_type_ids,
        )