bert.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import Optional, Union
6
7
8
9
10

import torch
from torch import nn
from transformers import BertConfig

11
from vllm.attention import Attention, AttentionType
12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
14
from vllm.distributed import get_tensor_model_parallel_world_size
15
from vllm.forward_context import get_forward_context
16
from vllm.model_executor.layers.activation import get_act_fn
17
18
19
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
21
22
                                               PoolingMethod,
                                               PoolingParamsUpdate,
23
                                               PoolingType)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
26
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
28
from vllm.pooling_params import PoolingTask
29
from vllm.sequence import IntermediateTensors
30

31
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
32
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
33

34
35
36
37
38
39
40
41
42

class BertEmbedding(nn.Module):
    """Sum of word, token-type and position embeddings, then LayerNorm.

    Construction raises ``ValueError`` unless
    ``config.position_embedding_type`` is ``"absolute"``.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden

        # Lookup tables, registered in the same order as before so that
        # parameter iteration order is unchanged.
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden)
        self.position_embeddings = VocabParallelEmbedding(
            config.max_position_embeddings, hidden)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        # Shape (1, max_position_embeddings). forward() takes explicit
        # position ids and does not read this buffer.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

        embedding_type = config.position_embedding_type
        self.position_embedding_type = embedding_type
        if embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type" +
                             " is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and return the layer-normalised sum.

        Args:
            input_ids: Token ids to embed.
            seq_lens: Accepted for interface compatibility; not read here.
            position_ids: Ids looked up in ``position_embeddings``.
            token_type_ids: Segment ids. When ``None``, a zero tensor with
                the shape of ``input_ids`` is used.
        """
        word_vecs = self.word_embeddings(input_ids)
        pos_vecs = self.position_embeddings(position_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_ids.size(),
                                         dtype=torch.long,
                                         device=word_vecs.device)
        type_vecs = self.token_type_embeddings(token_type_ids)

        # Summation order (word + type + position) matches the original.
        summed = word_vecs + type_vecs + pos_vecs
        return self.LayerNorm(summed)


86
class BertPooler(Pooler):
    """CLS pooling (``PoolingType.CLS``) followed by a dense layer and tanh."""

    def __init__(self, config: BertConfig):
        super().__init__()

        self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def get_pooling_updates(
        self,
        task: PoolingTask,
    ) -> Optional[PoolingParamsUpdate]:
        """Delegate pooling-parameter updates for ``task`` to the CLS method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection followed by the tanh activation."""
        return self.activation(self.dense(pooled))

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Pool ``hidden_states`` and project the result.

        Returns a tensor when the pooling method returns a tensor. Returns a
        list of tensors (one per element) when it returns a list.
        """
        pooled_output = self.pooling(hidden_states, pooling_metadata)

        # The declared return type of the pooling step allows a list of
        # tensors. nn.Linear cannot consume a Python list, so project each
        # element separately in that case.
        if isinstance(pooled_output, list):
            return [self._head(output) for output in pooled_output]
        return self._head(pooled_output)


112
@support_torch_compile
class BertEncoder(nn.Module):
    """Stack of ``num_hidden_layers`` BertLayer modules applied in order."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        stack = []
        for idx in range(hf_config.num_hidden_layers):
            stack.append(
                BertLayer(config=hf_config,
                          cache_config=vllm_config.cache_config,
                          quant_config=vllm_config.quant_config,
                          prefix=f"{prefix}.layer.{idx}"))
        self.layer = nn.ModuleList(stack)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer sequentially."""
        out = hidden_states
        for bert_layer in self.layer:
            out = bert_layer(out)
        return out


class BertLayer(nn.Module):
    """One encoder block: attention, then the intermediate/output feed-forward.

    The feed-forward output module receives the attention output as its
    residual input.
    """

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        hidden = config.hidden_size
        ffn_size = config.intermediate_size
        eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            layer_norm_eps=eps,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention")

        self.intermediate = BertIntermediate(hidden_size=hidden,
                                             intermediate_size=ffn_size,
                                             hidden_act=config.hidden_act,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.intermediate")

        self.output = BertOutput(hidden_size=hidden,
                                 intermediate_size=ffn_size,
                                 layer_norm_eps=eps,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.output")

    def forward(self, hidden_states: torch.Tensor):
        attended = self.attention(hidden_states)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)


class BertAttention(nn.Module):
    """Attention sub-block of a BertLayer.

    ``self`` (BertSelfAttention) computes the attention. ``output``
    (BertSelfOutput) projects the result, adds the sub-block input as a
    residual and applies LayerNorm.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # The prefix must match the attribute name the submodule is stored
        # under (``self``). ``.output`` is the prefix of BertSelfOutput below.
        # Reusing it labelled the qkv_proj/attn layers as if they lived in
        # ``attention.output``, which misleads anything that identifies
        # layers by prefix (presumably including per-layer quantization
        # settings).
        self.self = BertSelfAttention(hidden_size=hidden_size,
                                      num_attention_heads=num_attention_heads,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.self")

        self.output = BertSelfOutput(hidden_size=hidden_size,
                                     layer_norm_eps=layer_norm_eps,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.output")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``LayerNorm(dense(attn(hidden_states)) + hidden_states)``."""
        self_output = self.self(hidden_states)
        return self.output(self_output, hidden_states)


class BertSelfAttention(nn.Module):
    """Encoder-only multi-head self-attention with a fused QKV projection.

    Heads are divided evenly across tensor-parallel ranks. The number of KV
    heads equals the number of query heads.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_attention_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # Plain multi-head attention: one KV head per query head.
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        # Per-rank widths of the q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = Attention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        fused, _bias = self.qkv_proj(hidden_states)
        widths = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(widths, dim=-1)
        return self.attn(query, key, value)


class BertSelfOutput(nn.Module):
    """Row-parallel projection of the attention output, plus residual and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        projected, _bias = self.dense(hidden_states)
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)


class BertIntermediate(nn.Module):
    """Column-parallel projection to ``intermediate_size`` followed by ``hidden_act``."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected, _bias = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)


class BertOutput(nn.Module):
    """Row-parallel projection back to ``hidden_size``, plus residual and LayerNorm."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 layer_norm_eps: float,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        projected, _bias = self.dense(hidden_states)
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)


330
class BertModel(nn.Module, SupportsQuant):
    """BERT backbone: embeddings followed by the encoder layer stack.

    An optional BertPooler is created when ``add_pooling_layer`` is true.
    ``forward`` does not call it; it is exposed as ``self.pooler`` for
    wrappers to use.
    """

    is_pooling_model = True

    # The fused qkv_proj parameter is assembled from the checkpoint's
    # separate query / key / value tensors.
    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 embedding_class: type = BertEmbedding,
                 add_pooling_layer: bool = False):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.embeddings = embedding_class(hf_config)
        self.encoder = BertEncoder(vllm_config=vllm_config,
                                   prefix=f"{prefix}.encoder")
        if add_pooling_layer:
            self.pooler = BertPooler(hf_config)
        else:
            self.pooler = None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode tokens, or precomputed ``inputs_embeds`` when given.

        ``intermediate_tensors`` is accepted but not used.
        """
        hidden_states = inputs_embeds
        if hidden_states is None:
            attn_metadata = get_forward_context().attn_metadata
            assert hasattr(attn_metadata, "seq_lens_tensor")
            hidden_states = self.embeddings(
                input_ids=input_ids,
                seq_lens=attn_metadata.seq_lens_tensor,
                position_ids=position_ids,
                token_type_ids=token_type_ids)
        return self.encoder(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the set of loaded names.

        Query/key/value tensors are written as shards of the fused
        ``qkv_proj`` parameters. Any other weight whose name is a known
        parameter is delegated to AutoWeightsLoader. Names matching neither
        are dropped.
        """
        # (fused param fragment, checkpoint fragment, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_stacked_params = []
        other_weights = []

        for ckpt_name, tensor in weights:
            target = ckpt_name
            handled = False
            for fused, part, shard in stacked_params_mapping:
                if part not in target:
                    continue
                target = target.replace(part, fused)
                if target not in params_dict:
                    continue
                param = params_dict[target]
                param.weight_loader(param, tensor, shard)
                loaded_stacked_params.append(target)
                handled = True
                break
            if not handled and target in params_dict:
                other_weights.append((target, tensor))

        skip = [] if self.pooler is not None else ["pooler."]
        loader = AutoWeightsLoader(self, skip_prefixes=skip)
        loaded_params = loader.load_weights(other_weights)
        loaded_params.update(loaded_stacked_params)
        return loaded_params
405
406


407
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
    """A model that uses Bert to provide embedding functionalities.

    This class encapsulates the BertModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of BertModel used for forward operations.
        pooler: An instance of Pooler used for pooling operations.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        pooler_config = vllm_config.model_config.pooler_config
        self.model = self._build_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
        self.pooler = self._build_pooler(pooler_config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped BertModel; pooling is done separately by ``pooler``."""
        return self.model(input_ids=input_ids,
                          position_ids=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, adding a ``model.`` prefix if the checkpoint lacks it.

        Weights whose names start with ``lm_head.`` are skipped.
        """
        weights_list = list(weights)

        has_model_prefix = any(
            name.startswith("model.") for name, _ in weights_list)
        # Bind ``mapper`` on both paths. It was previously assigned only when
        # the prefix was missing, so checkpoints that already used "model."
        # names hit UnboundLocalError here.
        mapper = (None if has_model_prefix else WeightsMapper(
            orig_to_new_prefix={"": "model."}))

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> BertModel:
        """Construct the backbone; subclasses may override."""
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=BertEmbedding)

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Build the pooler, defaulting to normalized CLS pooling."""
        return Pooler.from_config_with_defaults(pooler_config,
                                                pooling_type=PoolingType.CLS,
                                                normalize=True,
                                                softmax=False)
463
464


465
466
class BertForSequenceClassification(nn.Module, SupportsV0Only,
                                    SupportsCrossEncoding, SupportsQuant):
    """A model that uses Bert for sequence classification.

    This class wraps a BertModel built with its pooling layer, plus a linear
    classifier head. The two are combined through a ClassifierPooler.

    Attributes:
        bert: An instance of BertModel (with ``add_pooling_layer=True``)
            used for forward operations.
        classifier: Linear layer from ``hidden_size`` to ``num_labels``
            outputs.
        pooler: A ClassifierPooler built from ``bert.pooler`` (as
            ``pooling``) and ``classifier``.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.num_labels = config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding,
                              add_pooling_layer=True)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pooler = ClassifierPooler(
            vllm_config.model_config,
            pooling=self.bert.pooler,
            classifier=self.classifier,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load all weights via AutoWeightsLoader; returns the loaded names."""
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(weights)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the BERT backbone, forwarding ``token_type_ids`` for pairs."""
        return self.bert(input_ids=input_ids,
                         position_ids=positions,
                         inputs_embeds=inputs_embeds,
                         intermediate_tensors=intermediate_tensors,
                         token_type_ids=token_type_ids)