roberta.py 10.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Iterable
from typing import Optional, Union
6
7
8
9
10

import torch
from torch import nn
from transformers import RobertaConfig

11
from vllm.config import ModelConfig, VllmConfig
12
13
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                               DispatchPooler, Pooler)
14
15
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
16
17
18
19
from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
                                             BertEmbeddingModel, BertModel,
                                             _decode_token_type_ids,
                                             _encode_token_type_ids)
20
21
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                              maybe_prefix)
22
from vllm.sequence import IntermediateTensors
23

24
from .bert_with_rope import BertWithRope, JinaRobertaModel
25
26
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

class RobertaEmbedding(nn.Module):
    """Word + token-type + absolute-position embeddings for RoBERTa.

    Token-type ids are not a separate ``forward`` argument: they are
    recovered from ``input_ids`` with ``_decode_token_type_ids`` (the
    sequence-classification wrapper packs them in with
    ``_encode_token_type_ids``, which requires
    ``vocab_size < 1 << TOKEN_TYPE_SHIFT``).
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        # Validate first so an unsupported config fails before the
        # (vocab_size x hidden_size) word embedding is allocated.
        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type"
                             " is supported")

        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size,
                                                padding_idx=self.padding_idx)

        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

        # NOTE(review): ``forward`` never reads this buffer (positions are
        # passed in explicitly). Presumably it is kept so checkpoints that
        # contain ``embeddings.position_ids`` still load — confirm before
        # removing it.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` at ``position_ids``.

        Args:
            input_ids: Token ids, possibly carrying encoded token-type ids.
            position_ids: Position ids. The RoBERTa wrappers in this module
                shift them by ``padding_idx + 1`` (via
                ``replace_roberta_positions``) before the encoder runs.

        Returns:
            ``LayerNorm(word + token_type + position)`` embeddings.
        """
        # Decode token types before the word lookup. Presumably this also
        # strips the encoded bits from ``input_ids`` in place, since the
        # ids are embedded right after — confirm in bert.py.
        token_type_ids = _decode_token_type_ids(input_ids)

        inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


72
73
74
75
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Computes ``out_proj(tanh(dense(x)))``. Both linear layers are created
    in ``model_config.head_dtype``; ``out_proj`` maps ``hidden_size`` to
    ``num_labels``.
    """

    def __init__(self, model_config: "ModelConfig"):
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size,
                               config.hidden_size,
                               dtype=head_dtype)
        self.out_proj = nn.Linear(config.hidden_size,
                                  config.num_labels,
                                  dtype=head_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project pooled hidden states to per-label logits.

        Args:
            x: Already-pooled hidden states; presumably shaped
                ``(num_seqs, hidden_size)`` — TODO confirm against
                ``ClassifierPooler``.

        Returns:
            Logits with a trailing dimension of size ``num_labels``.
        """
        # CLSPool has already been applied in `pooling`, so no token
        # selection happens here.
        hidden = torch.tanh(self.dense(x))
        return self.out_proj(hidden)


95
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses RoBERTa to provide embedding functionalities.

    Builds the encoder via ``_build_model`` and shifts position ids to
    RoBERTa's convention before every forward pass.

    Attributes:
        model: A ``JinaRobertaModel`` when the config's
            ``position_embedding_type`` is ``"rotary"``, otherwise a
            ``BertModel`` built with ``RobertaEmbedding``.
        padding_idx: The config's ``pad_token_id``, used to offset
            positions.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the encoder after shifting ``positions`` for RoBERTa.

        Note that ``positions`` is modified in place.
        """
        # Shift positions by ``padding_idx + 1`` here, before (and outside
        # of) the encoder call, so the encoder sees RoBERTa-style positions.
        replace_roberta_positions(input_ids=input_ids,
                                  position_ids=positions,
                                  padding_idx=self.padding_idx)

        return self.model(input_ids=input_ids,
                          positions=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> Union[BertModel, BertWithRope]:
        """Choose the encoder implementation from the HF config.

        Rotary position embeddings select ``JinaRobertaModel``; any other
        type uses ``BertModel`` with ``RobertaEmbedding``, which itself
        rejects anything but ``"absolute"``.
        """
        hf_config = vllm_config.model_config.hf_config
        if hf_config.position_embedding_type == "rotary":
            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         embedding_class=RobertaEmbedding)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights under the ``model.`` prefix.

        Names starting with ``roberta.`` have that prefix rewritten to
        ``model.``; otherwise ``model.`` is prepended. Weights under
        ``lm_head.`` are skipped.
        """
        # Materialize: the prefix check must scan every name before the
        # same weights are passed to the loader.
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list)
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
157

158

159
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """RoBERTa encoder with a sequence-classification head.

    ``forward`` returns the output of ``self.roberta``. The classification
    head is applied by ``self.pooler`` (CLS pooling followed by
    ``self.classifier``) for the "classify" and "score" tasks.

    Attributes:
        roberta: A ``BertModel`` built with ``RobertaEmbedding``.
        classifier: The ``RobertaClassificationHead`` applied to the CLS
            state.
        pooler: A ``DispatchPooler`` routing "encode", "classify" and
            "score".
    """

    is_pooling_model = True

    # Renames Jina-style checkpoint parameter names (``emb_ln``,
    # ``mixer.Wqkv``, ``mlp.fc1``, ...) to BertModel's names. Matching is
    # by substring and it is applied to every checkpoint; presumably a
    # no-op for standard RoBERTa names — confirm when adding new entries.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            'emb_ln': "embeddings.LayerNorm",
            'layers': "layer",
            'mixer.Wqkv': "attention.self.qkv_proj",
            'mixer.out_proj': "attention.output.dense",
            'norm1': "attention.output.LayerNorm",
            'mlp.fc1': "intermediate.dense",
            'mlp.fc2': "output.dense",
            'norm2': "output.LayerNorm",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = config.pad_token_id

        self.num_labels = config.num_labels
        self.roberta = BertModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "bert"),
                                 embedding_class=RobertaEmbedding)
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # "classify" and "score" share the same CLS pooling and classifier
        # and differ only in the activation applied to the logits.
        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            ClassifierPooler(
                pooling=CLSPool(),
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_seq_cls(
                    vllm_config.model_config),
            ),
            "score":
            ClassifierPooler(
                pooling=CLSPool(),
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                    vllm_config.model_config),
            ),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, renaming Jina-style names via the class mapper."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the word-embedding lookup to the encoder."""
        return self.roberta.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the encoder; the head is applied later by ``self.pooler``.

        ``positions`` is shifted in place to RoBERTa's convention. When
        ``token_type_ids`` is given, it is encoded into ``input_ids`` in
        place so that ``RobertaEmbedding`` can recover it.
        """
        replace_roberta_positions(input_ids=input_ids,
                                  position_ids=positions,
                                  padding_idx=self.padding_idx)
        if token_type_ids is not None:
            # Token types share storage with the ids, so the vocabulary
            # must fit below the bits used for the type.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(input_ids=input_ids,
                            positions=positions,
                            inputs_embeds=inputs_embeds,
                            intermediate_tensors=intermediate_tensors)
243
244


245
246
247
def replace_roberta_positions(input_ids: torch.Tensor,
                              position_ids: torch.Tensor,
                              padding_idx: int) -> None:
    """Shift ``position_ids`` in place to RoBERTa's position convention.

    RoBERTa position ids start at ``padding_idx + 1`` and skip padding
    tokens. vLLM does not use padding tokens, so every position is simply
    offset by ``padding_idx + 1``.

    Args:
        input_ids: Unused; kept so existing call sites keep working.
        position_ids: Positions to shift. Modified in place.
        padding_idx: The model's ``pad_token_id``.

    References:
        - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
        - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    """
    position_ids += padding_idx + 1