roberta.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable
6
7
8
9
10

import torch
from torch import nn
from transformers import RobertaConfig

11
12
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
13
    BgeM3Pooler,
14
15
16
17
18
19
20
21
22
23
24
25
    BOSEOSFilter,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.pooler.seqwise import (
    pooler_for_embed,
)
from vllm.model_executor.layers.pooler.tokwise import (
    AllPool,
    pooler_for_token_classify,
    pooler_for_token_embed,
)
26
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
27
28
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
30
31
32
33
34
35
36
37
38
39
40
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
41
from vllm.sequence import IntermediateTensors
42

43
from .bert_with_rope import BertWithRope, JinaRobertaModel
44
45
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
46

47
48
49
50
51

class RobertaEmbedding(nn.Module):
    """Input embedding layer used for RoBERTa-style models.

    The output is the LayerNorm of the sum of three embeddings: word
    embeddings, token-type embeddings and absolute position embeddings.
    """

    def __init__(self, config: RobertaConfig):
        """Create the embedding tables from the Hugging Face config.

        Args:
            config: Supplies `hidden_size`, `vocab_size`, `pad_token_id`,
                `max_position_embeddings`, `type_vocab_size` and
                `layer_norm_eps`.
        """
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        # NOTE(review): the capitalised attribute name presumably mirrors the
        # checkpoint parameter name ("embeddings.LayerNorm") - do not rename
        # without checking weight loading.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Buffer holding [0, 1, ..., max_position_embeddings - 1] with a
        # leading dimension of size 1.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed the inputs and normalise the result.

        Args:
            input_ids: Token ids; also the source of the token type ids.
            position_ids: Indices into the position embedding table.
            inputs_embeds: Optional precomputed word embeddings. When given,
                the word embedding lookup is skipped.

        Returns:
            The layer-normalised sum of the three embeddings.
        """
        # Token types are recovered from `input_ids` before the word lookup.
        # NOTE(review): presumably the helper unpacks type ids that were
        # packed into `input_ids` upstream and may modify `input_ids` in
        # place - keep this call ahead of the word embedding lookup.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


90
91
92
93
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies `dense -> tanh -> out_proj` to an already pooled hidden state.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Build the two linear layers of the head.

        Args:
            model_config: Provides `hf_config` (read for `hidden_size` and
                `num_labels`) and `head_dtype`, the dtype both linear layers
                are created with.
        """
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            config.hidden_size, config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to `num_labels` outputs.

        Args:
            x: Pooled features whose last dimension is `hidden_size`.

        Returns:
            A tensor whose last dimension is `num_labels`.
        """
        # Token extraction has already been applied in `pooler.pooling`
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


111
@default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialise the parent model and cache the padding token id."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Shift the positions to the RoBERTa convention and run the encoder.

        Note that `positions` is modified in place before being forwarded.
        """
        # Fix Roberta positions here outside of the CUDA graph.
        # Because we need to extract the sequences from
        # input_ids, the control flow is data dependent.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Pick the encoder implementation from `position_embedding_type`.

        A config without that attribute is treated as "absolute".
        """
        hf_config = vllm_config.model_config.hf_config
        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
        else:
            return JinaRobertaModel(**kwargs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping their names under `model.`.

        Returns whatever `AutoWeightsLoader.load_weights` returns.
        """
        # The stream is materialised because it is scanned for the prefix
        # first and then handed to the loader.
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
166

167

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def filter_secondary_weights(
    all_weights: Iterable[tuple[str, torch.Tensor]],
    secondary_weights: list[str],
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
    """Split a stream of named weights into secondary and primary streams.

    Args:
        all_weights: `(name, tensor)` pairs.
        secondary_weights: Name prefixes that mark a weight as secondary.

    Returns:
        A pair of lazy iterables `(secondary, primary)`: the first yields the
        pairs whose name starts with one of the prefixes, the second yields
        all the others. Both are fed from the same underlying iterator via
        `itertools.tee`, so draining one before the other makes `tee` hold
        on to the items the other has not consumed yet.
    """

    def is_secondary(weight_name):
        # `str.startswith` accepts a tuple of candidate prefixes.
        return weight_name.startswith(tuple(secondary_weights))

    secondary_source, primary_source = itertools.tee(all_weights)
    secondary_stream = (
        (name, tensor) for name, tensor in secondary_source if is_secondary(name)
    )
    primary_stream = (
        (name, tensor) for name, tensor in primary_source if not is_secondary(name)
    )
    return secondary_stream, primary_stream


class BgeM3EmbeddingModel(RobertaEmbeddingModel):
    """A model that extends RobertaEmbeddingModel with sparse embeddings.

    This class supports loading the additional `sparse_linear.pt` and
    `colbert_linear.pt` files. The sparse head creates sparse embeddings as
    described in https://arxiv.org/abs/2402.03216; the second head is used
    for the `token_embed` task.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Set up the heads' configuration and the secondary weight sources."""
        # These attributes are read by `_build_pooler`, so they are assigned
        # before the parent constructor runs.
        # NOTE(review): presumably the parent constructor calls
        # `_build_pooler` - confirm before reordering.
        self.hidden_size = vllm_config.model_config.hf_config.hidden_size

        model_config = vllm_config.model_config
        self.head_dtype = model_config.head_dtype
        self.bos_token_id = model_config.hf_config.bos_token_id
        self.eos_token_id = model_config.hf_config.eos_token_id

        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
        # Each prefix ends with ".", so appending "pt" yields the file name,
        # e.g. "sparse_linear." -> "sparse_linear.pt".
        self.secondary_weight_files = [
            weight_prefix + "pt" for weight_prefix in self.secondary_weight_prefixes
        ]

        # One source per extra file; its weights are exposed under the
        # matching prefix (not under the constructor's `prefix` argument).
        # NOTE(review): `revision` is hard-coded to None - confirm whether
        # the model's configured revision should be forwarded instead.
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=model_config.model,
                revision=None,
                prefix=weight_prefix,
                allow_patterns_overrides=[filename],
            )
            for filename, weight_prefix in zip(
                self.secondary_weight_files, self.secondary_weight_prefixes
            )
        ]

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Create both linear heads and the pooler that dispatches per task."""
        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
        self.colbert_linear = nn.Linear(
            self.hidden_size, self.hidden_size, dtype=self.head_dtype
        )
        embed_pooler = pooler_for_embed(pooler_config)
        token_classify_pooler = BOSEOSFilter(
            pooler_for_token_classify(
                pooler_config,
                pooling=AllPool(),
                classifier=self.sparse_linear,
                act_fn=torch.relu,
            ),
            self.bos_token_id,
            self.eos_token_id,
        )

        return DispatchPooler(
            {
                "embed": embed_pooler,
                "token_embed": BOSEOSFilter(
                    pooler_for_token_embed(pooler_config, self.colbert_linear),
                    self.bos_token_id,
                    # for some reason m3 only filters the bos for colbert vectors
                ),
                "token_classify": token_classify_pooler,
                "embed&token_classify": BgeM3Pooler(
                    token_classify_pooler, embed_pooler
                ),
            }
        )

    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
        """Load the encoder weights, then the sparse/ColBERT head weights.

        Raises:
            KeyError: If a secondary weight name matches no parameter of
                this module.
        """
        secondary, weights = filter_secondary_weights(
            all_weights, self.secondary_weight_prefixes
        )

        super().load_weights(weights)

        params_dict = dict(self.named_parameters())

        # `secondary` only yields names that start with one of
        # `self.secondary_weight_prefixes`, so no further filtering is needed.
        for name, loaded_weight in secondary:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


265
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta for sequence classification.

    This class encapsulates a BertModel built with RobertaEmbedding and
    combines it with a classification head applied through the pooler.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: The RobertaClassificationHead handed to the pooler.
        pooler: The DispatchPooler created by `DispatchPooler.for_seq_cls`.
    """

    is_pooling_model = True
    # Substring renames applied to checkpoint weight names in `load_weights`.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder, the classification head and the pooler.

        Raises:
            AssertionError: If the model config carries no pooler config.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler.for_seq_cls(
            pooler_config,
            classifier=self.classifier,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights after renaming them with `jina_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the input embedding lookup to the wrapped encoder."""
        return self.roberta.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder on the given tokens.

        `positions` is shifted in place to the RoBERTa convention before the
        encoder is called.

        Raises:
            AssertionError: If `token_type_ids` is given while `input_ids` is
                None, or while the vocabulary size is not below
                `1 << TOKEN_TYPE_SHIFT`.
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # NOTE(review): the type ids presumably get packed into the bits
            # of `input_ids` at and above TOKEN_TYPE_SHIFT, which is why the
            # vocabulary has to fit below that bit - confirm in the helper.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
340
341


342
343
344
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
    """Shift `position_ids` in place to the RoBERTa position convention.

    Args:
        input_ids: Not used by this function; kept in the signature for its
            callers.
        position_ids: Positions to shift. Modified in place.
        padding_idx: The padding token id; every position is increased by
            `padding_idx + 1`.
    """
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1