# roberta.py (9.88 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
6
7
8
9

import torch
from torch import nn
from transformers import RobertaConfig

10
from vllm.config import ModelConfig, VllmConfig
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    CLSPool,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
30
from vllm.sequence import IntermediateTensors
31

32
from .bert_with_rope import BertWithRope, JinaRobertaModel
33
34
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
35

36
37
38
39
40

class RobertaEmbedding(nn.Module):
    """Input embedding layer used for RoBERTa models.

    Sums the word, token-type and absolute position embeddings and applies
    LayerNorm to the result.

    Args:
        config: Hugging Face RoBERTa configuration. Only
            ``position_embedding_type == "absolute"`` is accepted.

    Raises:
        ValueError: If ``config.position_embedding_type`` is not
            ``"absolute"``.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Registered under the name "position_ids"; not read by `forward`,
        # which receives its position ids from the caller.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the layer-normalized input embeddings.

        Args:
            input_ids: Token ids; also the source of the token type ids.
            position_ids: Position ids looked up in ``position_embeddings``.
            inputs_embeds: Optional precomputed word embeddings. When given,
                the word-embedding lookup is skipped.

        Returns:
            ``LayerNorm(inputs_embeds + token_type_emb + position_emb)``.
        """
        # Must run before the word-embedding lookup below.
        # NOTE(review): presumably this recovers the token type ids packed
        # into `input_ids` by `_encode_token_type_ids` (and may modify
        # `input_ids` while doing so) -- confirm in bert.py.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


85
86
87
88
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies ``dense -> tanh -> out_proj`` to an already pooled hidden state.

    Args:
        model_config: Model configuration. ``model_config.hf_config`` supplies
            ``hidden_size`` and ``num_labels``; ``model_config.head_dtype`` is
            the dtype both linear layers are created with.
    """

    def __init__(self, model_config: "ModelConfig"):
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            config.hidden_size, config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to ``num_labels`` outputs.

        Args:
            x: Pooled features whose last dimension is ``hidden_size``.

        Returns:
            Tensor whose last dimension is ``num_labels``.
        """
        # CLSPool has already been applied in `pooling`
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


106
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    This class extends ``BertEmbeddingModel``. The inner encoder is chosen by
    ``_build_model``: ``JinaRobertaModel`` for rotary position embeddings,
    otherwise ``BertModel`` using ``RobertaEmbedding``.

    Attributes:
        padding_idx: ``pad_token_id`` from the HF config; used to offset
            the position ids before the encoder runs.
        model: The inner encoder used for forward operations.
            NOTE(review): presumably created by the base class through
            ``_build_model`` -- confirm in bert.py.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Offset the positions for RoBERTa and run the inner encoder.

        Args:
            input_ids: Token ids.
            positions: Position ids; modified in place.
            intermediate_tensors: Forwarded unchanged to the inner model.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            The output of ``self.model``.
        """
        # Fix Roberta positions here, outside of the CUDA graph.
        # The offset is applied in place on `positions`.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Create the inner encoder matching the configured position type."""
        if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
        else:
            return BertModel(
                vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
            )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping names under ``model.``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        # Materialize once: the iterable is scanned for the prefix and then
        # handed to the loader.
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
169

170

171
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta for sequence classification and scoring.

    This class wraps a ``BertModel`` built with ``RobertaEmbedding`` and adds
    a ``RobertaClassificationHead`` that the "classify" and "score" poolers
    apply after CLS pooling.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: The classification head shared by the "classify" and
            "score" poolers.
        pooler: A DispatchPooler handling "encode", "classify" and "score".
        num_labels: ``num_labels`` from the HF config.
        padding_idx: ``pad_token_id`` from the HF config; used to offset
            the position ids before the encoder runs.
    """

    is_pooling_model = True
    # Substring renames applied to checkpoint parameter names when loading
    # (Jina-style names on the left, names used by this module on the right).
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # "classify" and "score" share the same head and CLS pooling; they
        # differ only in the activation function selected for the task.
        self.pooler = DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
                "classify": ClassifierPooler(
                    pooling=CLSPool(),
                    classifier=self.classifier,
                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
                        vllm_config.model_config
                    ),
                ),
                "score": ClassifierPooler(
                    pooling=CLSPool(),
                    classifier=self.classifier,
                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                        vllm_config.model_config
                    ),
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, applying ``jina_to_vllm_mapper`` renames.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the input-embedding lookup to the inner encoder."""
        return self.roberta.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Offset the positions, encode token types and run the encoder.

        Args:
            input_ids: Token ids. Must not be ``None`` when
                ``token_type_ids`` is given.
            positions: Position ids; modified in place.
            intermediate_tensors: Forwarded unchanged to the inner model.
            inputs_embeds: Optional precomputed input embeddings.
            token_type_ids: Optional token type ids, passed to
                ``_encode_token_type_ids`` together with ``input_ids``.

        Returns:
            The output of ``self.roberta``.
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # NOTE(review): presumably the token type is packed into the bits
            # of `input_ids` at and above TOKEN_TYPE_SHIFT, which is why the
            # vocabulary must fit below it -- confirm in bert.py.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
261
262


263
264
265
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
    """Shift position ids in place so they start at ``padding_idx + 1``.

    Args:
        input_ids: Not used by this implementation; kept in the signature
            for callers that pass it.
        position_ids: Position ids to shift. Modified in place.
        padding_idx: Padding token index; every position is increased by
            ``padding_idx + 1``.

    Returns:
        None. The result is written into ``position_ids``.
    """
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1