# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
from collections.abc import Iterable
from typing import Optional, Union
6
7
8
9
10

import torch
from torch import nn
from transformers import RobertaConfig

11
from vllm.config import ModelConfig, VllmConfig
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    CLSPool,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
31
from vllm.sequence import IntermediateTensors
32

33
from .bert_with_rope import BertWithRope, JinaRobertaModel
34
35
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
36

37
38
39
40
41

class RobertaEmbedding(nn.Module):
    """Input embedding layer used by the RoBERTa models in this file.

    Sums word, token-type and absolute position embeddings and applies a
    LayerNorm to the result.
    """

    def __init__(self, config: RobertaConfig):
        """Create the embedding tables described by ``config``.

        Args:
            config: Supplies ``hidden_size``, ``vocab_size``,
                ``pad_token_id``, ``max_position_embeddings``,
                ``type_vocab_size``, ``layer_norm_eps`` and
                ``position_embedding_type``.

        Raises:
            ValueError: If ``config.position_embedding_type`` is not
                ``"absolute"``.
        """
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        # The padding index is also used as the padding row of the position
        # table, which is why callers shift positions by `padding_idx + 1`.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        # Attribute name is capitalised on purpose: it is part of the
        # parameter names of this module and must not be renamed.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Registered as a buffer; `forward` does not read it and uses the
        # `position_ids` argument instead.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the normalised input embeddings.

        Args:
            input_ids: Token ids. Always passed to
                ``_decode_token_type_ids``; also used for the word-embedding
                lookup when ``inputs_embeds`` is not given.
            position_ids: Indices into the position embedding table.
            inputs_embeds: Optional precomputed word embeddings that replace
                the word-embedding lookup.

        Returns:
            ``LayerNorm(inputs_embeds + token_type_embeddings +
            position_embeddings)``.
        """
        # Keep this call ahead of the word-embedding lookup.
        # NOTE(review): the helper lives in the BERT module; presumably it
        # also strips the encoded type bits from `input_ids` in place, so the
        # order matters -- confirm against bert.py.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


86
87
88
89
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies ``dense -> tanh -> out_proj`` to an already pooled hidden state.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Create the two linear layers of the head.

        Args:
            model_config: Provides ``hf_config`` (read for ``hidden_size``
                and ``num_labels``) and ``head_dtype`` (dtype of both
                linear layers).
        """
        super().__init__()
        hf_config = model_config.hf_config
        head_dtype = model_config.head_dtype
        hidden_size = hf_config.hidden_size
        # Creation order (dense, then out_proj) is kept so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.dense = nn.Linear(hidden_size, hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            hidden_size, hf_config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to ``num_labels`` logits.

        Args:
            x: Pooled features whose last dimension is ``hidden_size``.

        Returns:
            Tensor whose last dimension is ``num_labels``.
        """
        # CLSPool has already been applied in `pooling`
        return self.out_proj(torch.tanh(self.dense(x)))


107
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    Extends ``BertEmbeddingModel`` with RoBERTa-specific position handling,
    encoder construction and checkpoint name mapping.

    Attributes:
        model: The encoder invoked by ``forward``. NOTE(review): presumably
            assigned by the base class from ``_build_model`` -- confirm in
            ``BertEmbeddingModel``.
        padding_idx: ``pad_token_id`` of the HF config; determines the
            offset added to the positions in ``forward``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialise the base model and remember the padding index.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``
                must provide ``pad_token_id``.
            prefix: Name prefix forwarded to the base class.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Shift the positions for RoBERTa and run the encoder.

        Args:
            input_ids: Token ids.
            positions: Position ids; modified in place by
                ``replace_roberta_positions``.
            intermediate_tensors: Forwarded unchanged to the encoder.
            inputs_embeds: Optional precomputed input embeddings, forwarded
                unchanged to the encoder.

        Returns:
            Whatever ``self.model`` returns for these inputs.
        """
        # Fix Roberta positions here, outside of the CUDA graph.
        # The original reason was that extracting the sequences from
        # input_ids makes the control flow data dependent; the helper as
        # written in this file only adds a constant offset and ignores
        # input_ids.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> Union[BertModel, BertWithRope]:
        """Create the encoder matching the configured position embedding.

        Returns:
            A ``JinaRobertaModel`` when ``position_embedding_type`` is
            ``"rotary"``; otherwise a ``BertModel`` that uses
            ``RobertaEmbedding`` as its embedding class.
        """
        hf_config = vllm_config.model_config.hf_config
        if hf_config.position_embedding_type == "rotary":
            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
        return BertModel(
            vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, mapping their names under ``model.``.

        Args:
            weights: ``(name, tensor)`` pairs. Materialised into a list
                because the names are scanned once before loading.

        Returns:
            The result of ``AutoWeightsLoader.load_weights``.
        """
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        # Weights under `lm_head.` are skipped by the loader.
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
170

171

172
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A Roberta model with a sequence classification head.

    Wraps a ``BertModel`` built with ``RobertaEmbedding`` and exposes it
    through a ``DispatchPooler`` for the "encode", "classify" and "score"
    tasks.

    Attributes:
        roberta: The ``BertModel`` encoder used for forward operations.
        classifier: The ``RobertaClassificationHead`` shared by the
            "classify" and "score" poolers.
        pooler: The ``DispatchPooler`` used for pooling operations.
        padding_idx: ``pad_token_id`` of the HF config; determines the
            offset added to the positions in ``forward``.
        num_labels: ``num_labels`` of the HF config.
    """

    is_pooling_model = True
    # Substring renames applied to checkpoint weight names in `load_weights`.
    # NOTE(review): judging by the name, the left-hand sides follow the Jina
    # checkpoint layout -- confirm.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder, the classification head and the poolers.

        Args:
            vllm_config: Engine configuration; ``model_config`` must carry a
                non-``None`` ``pooler_config``.
            prefix: Name prefix for the encoder's sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # "classify" and "score" share the CLS pooling and the classifier
        # head; they differ only in the activation function that is chosen.
        self.pooler = DispatchPooler(
            {
                "encode": Pooler.for_encode(pooler_config),
                "classify": ClassifierPooler(
                    pooling=CLSPool(),
                    classifier=self.classifier,
                    act_fn=ClassifierPooler.act_fn_for_seq_cls(
                        vllm_config.model_config
                    ),
                ),
                "score": ClassifierPooler(
                    pooling=CLSPool(),
                    classifier=self.classifier,
                    act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                        vllm_config.model_config
                    ),
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, renaming them via ``jina_to_vllm_mapper``.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The result of ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the encoder's input embeddings for ``input_ids``."""
        return self.roberta.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Shift the positions, encode token types and run the encoder.

        Args:
            input_ids: Token ids; required when ``token_type_ids`` is given.
            positions: Position ids; modified in place by
                ``replace_roberta_positions``.
            intermediate_tensors: Forwarded unchanged to the encoder.
            inputs_embeds: Optional precomputed input embeddings, forwarded
                unchanged to the encoder.
            token_type_ids: Optional token type ids, handed to
                ``_encode_token_type_ids`` together with ``input_ids``.

        Returns:
            Whatever ``self.roberta`` returns for these inputs.
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # The return value of the helper is discarded, so it presumably
            # writes the type ids into `input_ids` in place, using the bits
            # from TOKEN_TYPE_SHIFT upwards -- hence the vocabulary size
            # check. NOTE(review): confirm against bert.py.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
262
263


264
265
266
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
    """Shift ``position_ids`` in place to the RoBERTa position convention.

    Args:
        input_ids: Unused; kept so existing keyword callers keep working.
        position_ids: Position ids to shift. Modified in place.
        padding_idx: Padding token id; every position is increased by
            ``padding_idx + 1``.

    Returns:
        None. The result is visible through ``position_ids``.
    """
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1