"vscode:/vscode.git/clone" did not exist on "1f5674218f968dec625d0995fe5cd5d626db9188"
roberta.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Iterable
6
7
8
9
10

import torch
from torch import nn
from transformers import RobertaConfig

11
12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
    BOSEOSFilter,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.pooler.seqwise import (
    pooler_for_embed,
)
from vllm.model_executor.layers.pooler.tokwise import (
    AllPool,
    pooler_for_token_classify,
    pooler_for_token_embed,
)
25
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
26
27
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
40
from vllm.sequence import IntermediateTensors
41

42
from .bert_with_rope import BertWithRope, JinaRobertaModel
43
44
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
45

46
47
48
49
50

class RobertaEmbedding(nn.Module):
    """Input embedding layer used by the RoBERTa models in this file.

    The output is the sum of word, token-type and learned absolute
    position embeddings, followed by layer normalization.
    """

    def __init__(self, config: RobertaConfig):
        """Build the embedding tables and the layer norm from ``config``.

        Args:
            config: Hugging Face RoBERTa configuration. The fields read here
                are ``hidden_size``, ``vocab_size``, ``pad_token_id``,
                ``max_position_embeddings``, ``type_vocab_size`` and
                ``layer_norm_eps``.
        """
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        # The attribute name is kept as `LayerNorm` (not snake_case);
        # presumably it has to match checkpoint parameter names — confirm
        # before renaming.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Buffer holding 0..max_position_embeddings-1 with a leading axis of
        # size 1. It is not read by `forward`, which receives explicit
        # position ids from the caller.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed a batch of tokens.

        Args:
            input_ids: Token ids. Token type ids are derived from this tensor
                via `_decode_token_type_ids`.
            position_ids: Position ids looked up in the position table.
            inputs_embeds: Optional precomputed word embeddings; when given,
                the word-embedding lookup is skipped.

        Returns:
            The layer-normalized sum of the three embeddings.
        """
        # NOTE(review): `_decode_token_type_ids` is defined in the bert
        # module. It presumably recovers token type ids that were packed into
        # `input_ids` and may strip them in place, so it must stay ahead of
        # the word-embedding lookup below — confirm against bert.py.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


89
90
91
92
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies ``dense -> tanh -> out_proj`` to an already-pooled hidden state.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Create the two linear layers of the head.

        Args:
            model_config: Model configuration. ``hf_config.hidden_size`` and
                ``hf_config.num_labels`` set the layer shapes, and
                ``head_dtype`` sets the parameter dtype of both layers.
        """
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            config.hidden_size, config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to classification logits.

        Args:
            x: Pooled features whose last dimension is ``hidden_size``.

        Returns:
            A tensor whose last dimension is ``num_labels``.
        """
        # Token extraction has already been applied in `pooler.pooling`
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


110
@default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialize the base embedding model and remember the padding index.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``
                supplies ``pad_token_id``.
            prefix: Parameter-name prefix forwarded to the base class.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder after shifting positions to RoBERTa's convention.

        ``positions`` is modified in place by `replace_roberta_positions`
        before being handed to the wrapped model.
        """
        # Fix Roberta positions here outside of the CUDA graph.
        # Because we need to extract the sequences from
        # input_ids the control flow is data dependent.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Select the encoder implementation from the position embedding type.

        A config whose ``position_embedding_type`` is ``"absolute"`` (also
        the default when the field is missing) gets a `BertModel` using
        `RobertaEmbedding`; any other value gets a `JinaRobertaModel`.
        """
        hf_config = vllm_config.model_config.hf_config
        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
        else:
            return JinaRobertaModel(**kwargs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping names under ``model.``.

        The iterable is materialized into a list because it is scanned once
        to detect the ``roberta.`` prefix and then consumed again by the
        loader. Weights under ``lm_head.`` are skipped.
        """
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
165

166

167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def filter_secondary_weights(
    all_weights: Iterable[tuple[str, torch.Tensor]],
    secondary_weights: list[str],
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
    """Split a stream of named weights into secondary and remaining weights.

    Args:
        all_weights: Iterable of ``(name, tensor)`` pairs. It is consumed
            through `itertools.tee`, so both returned iterables are lazy and
            share the same underlying source.
        secondary_weights: Name prefixes that mark a weight as secondary.

    Returns:
        A pair ``(secondary, remaining)``: the first yields the pairs whose
        name starts with any of the prefixes, the second yields all others.
        Relative order within each stream follows the input order.
    """

    def is_secondary(weight_name):
        # `str.startswith` accepts a tuple and matches any of its entries.
        return weight_name.startswith(tuple(secondary_weights))

    secondary_source, remaining_source = itertools.tee(all_weights)
    secondary = (
        (name, tensor) for name, tensor in secondary_source if is_secondary(name)
    )
    remaining = (
        (name, tensor)
        for name, tensor in remaining_source
        if not is_secondary(name)
    )
    return secondary, remaining


class BgeM3EmbeddingModel(RobertaEmbeddingModel):
    """RobertaEmbeddingModel extended with sparse and ColBERT-style heads.

    Besides the regular checkpoint, the additional ``sparse_linear.pt`` and
    ``colbert_linear.pt`` files are registered as secondary weight sources.
    The sparse embeddings follow https://arxiv.org/abs/2402.03216
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        # These attributes are assigned before the base initializer runs;
        # `_build_pooler` reads them and is presumably invoked from the base
        # class constructor — verify against BertEmbeddingModel.
        self.hidden_size = hf_config.hidden_size
        self.head_dtype = model_config.head_dtype
        self.bos_token_id = hf_config.bos_token_id
        self.eos_token_id = hf_config.eos_token_id

        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
        # Each prefix already ends with ".", so appending "pt" gives the
        # file name, e.g. "sparse_linear.pt".
        self.secondary_weight_files = [
            weight_prefix + "pt" for weight_prefix in self.secondary_weight_prefixes
        ]

        # One extra weight source per secondary file, restricted to that
        # file and tagged with the matching parameter prefix.
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=model_config.model,
                revision=None,
                prefix=weight_prefix,
                allow_patterns_overrides=[weight_file],
            )
            for weight_file, weight_prefix in zip(
                self.secondary_weight_files, self.secondary_weight_prefixes
            )
        ]

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Create the extra linear heads and the task-dispatching pooler."""
        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
        self.colbert_linear = nn.Linear(
            self.hidden_size, self.hidden_size, dtype=self.head_dtype
        )

        dense_pooler = pooler_for_embed(pooler_config)

        # for some reason m3 only filters the bos for colbert vectors
        colbert_pooler = BOSEOSFilter(
            pooler_for_token_embed(pooler_config, self.colbert_linear),
            self.bos_token_id,
        )

        # The sparse head filters both the bos and the eos token.
        sparse_pooler = BOSEOSFilter(
            pooler_for_token_classify(
                pooler_config,
                pooling=AllPool(),
                classifier=self.sparse_linear,
                act_fn=torch.relu,
            ),
            self.bos_token_id,
            self.eos_token_id,
        )

        return DispatchPooler(
            {
                "embed": dense_pooler,
                "token_embed": colbert_pooler,
                "token_classify": sparse_pooler,
            }
        )

    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
        """Load the base weights, then the sparse/ColBERT head weights.

        Secondary weights are looked up directly in ``named_parameters()``;
        a secondary name with no matching parameter raises ``KeyError``.
        """
        secondary, primary = filter_secondary_weights(
            all_weights, self.secondary_weight_prefixes
        )

        super().load_weights(primary)

        named_params = dict(self.named_parameters())
        known_prefixes = tuple(self.secondary_weight_prefixes)

        for name, loaded_weight in secondary:
            if not name.startswith(known_prefixes):
                continue
            param = named_params[name]
            # Prefer a parameter-specific loader when one is attached.
            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)


259
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta for sequence classification.

    This class wraps a BertModel built with RobertaEmbedding and attaches a
    RobertaClassificationHead through a sequence-classification pooler.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: The RobertaClassificationHead applied by the pooler.
        pooler: An instance of Pooler used for pooling operations.
    """

    is_pooling_model = True
    # Substring renames applied to checkpoint names in `load_weights`; the
    # left-hand names are Jina-style layer names (per the attribute name).
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder, the classification head and the pooler.

        Args:
            vllm_config: Engine configuration; ``model_config.pooler_config``
                must not be ``None``.
            prefix: Parameter-name prefix for the wrapped encoder.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        # NOTE(review): the encoder is stored as `self.roberta` but its
        # prefix is built with "bert"; kept as-is — confirm this is intended.
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler.for_seq_cls(
            pooler_config,
            classifier=self.classifier,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, applying `jina_to_vllm_mapper` renames."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate input-id embedding to the wrapped encoder."""
        return self.roberta.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder and return its output.

        ``positions`` is shifted in place to RoBERTa's position convention.
        When ``token_type_ids`` is given, ``input_ids`` must not be ``None``;
        the type ids are passed to `_encode_token_type_ids` together with
        ``input_ids`` before the encoder runs.
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # The vocabulary must fit below bit TOKEN_TYPE_SHIFT; presumably
            # the bits above it carry the token type inside `input_ids` —
            # confirm against `_encode_token_type_ids` in bert.py.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
334
335


336
337
338
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
    """Shift ``position_ids`` in place to RoBERTa's position numbering.

    Args:
        input_ids: Accepted for interface compatibility; not read here.
        position_ids: Position ids to adjust. The tensor is modified in
            place, so every holder of this tensor sees the shifted values.
        padding_idx: Padding token index; every position is offset by
            ``padding_idx + 1``.

    Returns:
        None. The result is the in-place update of ``position_ids``.
    """
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1