roberta.py 10 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Iterable
from typing import Optional, Union
6
7
8
9
10
11

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.config import VllmConfig
12
13
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                               DispatchPooler, Pooler)
14
15
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
16
17
18
19
from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
                                             BertEmbeddingModel, BertModel,
                                             _decode_token_type_ids,
                                             _encode_token_type_ids)
20
21
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                              maybe_prefix)
22
from vllm.sequence import IntermediateTensors
23

24
from .bert_with_rope import BertWithRope, JinaRobertaModel
25
from .interfaces import SupportsCrossEncoding, default_pooling_type
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

class RobertaEmbedding(nn.Module):
    """Absolute-position embedding layer for RoBERTa-style encoders.

    The output is ``LayerNorm(word + token_type + position)``. Token-type ids
    are not passed in separately; they are unpacked from ``input_ids`` with
    ``_decode_token_type_ids``.

    Raises:
        ValueError: if ``config.position_embedding_type`` is not
            ``"absolute"``.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      hidden)
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            hidden,
            padding_idx=self.padding_idx,
        )
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        # forward() never reads this buffer. NOTE(review): presumably kept
        # so the module's state-dict keys line up with checkpoints that
        # store `position_ids` -- confirm before removing.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError(
                "Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` at ``position_ids`` and normalize.

        Args:
            input_ids: Token ids, possibly carrying packed token-type ids.
            position_ids: Absolute positions (already offset by the caller).

        Returns:
            The LayerNorm-ed sum of word, token-type and position embeddings.
        """
        # Unpack the token types first, before input_ids is used for the
        # word-embedding lookup.
        type_ids = _decode_token_type_ids(input_ids)

        word = self.word_embeddings(input_ids)
        pos = self.position_embeddings(position_ids)
        token_type = self.token_type_embeddings(type_ids)

        summed = word + token_type + pos
        return self.LayerNorm(summed)


71
72
73
74
75
76
77
78
79
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Sentence-level classification head: dense -> tanh -> out_proj.

    Operates on one pooled vector per sequence and maps it from
    ``hidden_size`` to ``num_labels`` logits.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `x` has already been reduced by CLSPool in the pooler, so no
        # token selection is done here.
        activated = torch.tanh(self.dense(x))
        return self.out_proj(activated)


88
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """Embedding model backed by a RoBERTa encoder.

    The encoder is a ``BertModel`` with ``RobertaEmbedding`` for absolute
    position embeddings, or a ``JinaRobertaModel`` when the HF config asks
    for rotary positions (see ``_build_model``).

    Attributes:
        model: Encoder invoked by ``forward``. NOTE(review): assumed to be
            created by the base-class constructor via ``_build_model``.
        padding_idx: ``pad_token_id`` from the HF config; position ids are
            offset by ``padding_idx + 1`` before the encoder runs.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Offset ``positions`` (in place) to RoBERTa numbering, then encode."""
        # NOTE(review): an earlier comment said this must run outside the
        # CUDA graph because sequence extraction from input_ids made the
        # control flow data dependent. replace_roberta_positions now only
        # adds a constant, so confirm whether this placement still matters.
        replace_roberta_positions(input_ids=input_ids,
                                  position_ids=positions,
                                  padding_idx=self.padding_idx)
        return self.model(input_ids=input_ids,
                          positions=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> Union[BertModel, BertWithRope]:
        """Choose the encoder class from ``position_embedding_type``."""
        hf_config = vllm_config.model_config.hf_config
        if hf_config.position_embedding_type != "rotary":
            return BertModel(vllm_config=vllm_config,
                             prefix=prefix,
                             embedding_class=RobertaEmbedding)
        # Rotary-position variants get the dedicated Jina implementation.
        return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load a checkpoint whose encoder keys may or may not be nested.

        Checkpoints such as ``FacebookAI/roberta-base`` prefix encoder keys
        with ``roberta.``; others such as
        ``sentence-transformers/stsb-roberta-base-v2`` have no prefix. In
        both cases the keys are re-rooted under ``model.``. Keys starting
        with ``lm_head.`` are skipped.
        """
        # Materialized once because it is scanned here and then re-read
        # by the loader.
        materialized = list(weights)
        nested = any(
            key.startswith("roberta.") for key, _ in materialized)
        source_prefix = "roberta." if nested else ""
        mapper = WeightsMapper(orig_to_new_prefix={source_prefix: "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(materialized, mapper=mapper)
150

151

152
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """RoBERTa encoder with a sequence-classification / cross-encoder head.

    ``forward`` returns the encoder output. Task-specific post-processing is
    done by ``self.pooler``:

    - ``classify`` and ``score``: CLS pooling, then
      ``RobertaClassificationHead``, then the task's activation function.
    - ``encode``: ``Pooler.for_encode(pooler_config)``.

    Attributes:
        roberta: ``BertModel`` encoder built with ``RobertaEmbedding``.
        classifier: ``RobertaClassificationHead`` shared by the ``classify``
            and ``score`` poolers.
        pooler: ``DispatchPooler`` routing ``encode``/``classify``/``score``.
        padding_idx: ``pad_token_id`` from the HF config, used to offset
            position ids.
        num_labels: Number of labels from the HF config.
    """

    is_pooling_model = True

    # Substring renames applied to every checkpoint key at load time. They
    # rewrite Jina-style names (e.g. `emb_ln`, `mixer.Wqkv`) into the
    # BertModel layout.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            'emb_ln': "embeddings.LayerNorm",
            'layers': "layer",
            'mixer.Wqkv': "attention.self.qkv_proj",
            'mixer.out_proj': "attention.output.dense",
            'norm1': "attention.output.LayerNorm",
            'mlp.fc1': "intermediate.dense",
            'mlp.fc2': "output.dense",
            'norm2': "output.LayerNorm",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = config.pad_token_id

        self.num_labels = config.num_labels
        # NOTE(review): the submodule is named `roberta` but is given the
        # "bert" prefix. If prefixes drive per-layer lookups (e.g.
        # quantization ignore lists), "roberta" may be intended -- confirm.
        self.roberta = BertModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "bert"),
                                 embedding_class=RobertaEmbedding)
        self.classifier = RobertaClassificationHead(config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            ClassifierPooler(
                pooling=CLSPool(),
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_seq_cls(
                    vllm_config.model_config),
            ),
            "score":
            ClassifierPooler(
                pooling=CLSPool(),
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                    vllm_config.model_config),
            ),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights after applying ``jina_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the RoBERTa encoder.

        ``positions`` is shifted in place to RoBERTa numbering. If
        ``token_type_ids`` is given, it is packed into ``input_ids`` (which
        must then be provided), so ``RobertaEmbedding`` can unpack it.

        Returns:
            The encoder output; pooling and classification are left to
            ``self.pooler``.
        """
        replace_roberta_positions(input_ids=input_ids,
                                  position_ids=positions,
                                  padding_idx=self.padding_idx)

        if token_type_ids is not None:
            # Real token ids must fit below bit TOKEN_TYPE_SHIFT so they do
            # not collide with the packed token types.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            # The return value is unused: input_ids is updated in place.
            _encode_token_type_ids(input_ids, token_type_ids)

        return self.roberta(input_ids=input_ids,
                            positions=positions,
                            inputs_embeds=inputs_embeds,
                            intermediate_tensors=intermediate_tensors)
233
234


235
236
237
def replace_roberta_positions(input_ids: torch.Tensor,
                              position_ids: torch.Tensor,
                              padding_idx: int) -> None:
    """Shift ``position_ids`` in place so positions start at ``padding_idx + 1``.

    RoBERTa numbers positions from ``padding_idx + 1`` and keeps padding
    tokens at ``padding_idx``. vLLM does not feed padding tokens, so the
    padding special case is dropped and every position gets the same
    constant offset.

    Args:
        input_ids: Not read; accepted for call-site compatibility.
        position_ids: Position tensor; mutated in place.
        padding_idx: The model's ``pad_token_id``.

    References:
        https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
        https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    """
    offset = padding_idx + 1
    # In-place so callers holding `position_ids` observe the shift.
    position_ids.add_(offset)