# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from collections.abc import Iterable
5
6
7
8
9

import torch
from torch import nn
from transformers import RobertaConfig

10
from vllm.config import ModelConfig, VllmConfig
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    CLSPool,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
30
from vllm.sequence import IntermediateTensors
31

32
from .bert_with_rope import BertWithRope, JinaRobertaModel
33
34
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
35

36
37
38
39
40

class RobertaEmbedding(nn.Module):
    """Input embedding layer used for RoBERTa-style encoders.

    The output is the sum of word, token-type and absolute position
    embeddings, followed by LayerNorm.
    """

    def __init__(self, config: RobertaConfig):
        """Create the embedding tables described by ``config``.

        Args:
            config: Hugging Face RoBERTa configuration providing the vocab,
                hidden, position and token-type sizes.

        Raises:
            ValueError: If ``config.position_embedding_type`` is anything
                other than ``"absolute"``.
        """
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        # NOTE(review): attribute names such as `LayerNorm` presumably mirror
        # the checkpoint parameter names; do not rename without checking the
        # weight loaders.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Buffer of shape (1, max_position_embeddings) holding 0..max-1.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed a batch of tokens.

        Args:
            input_ids: Token ids. Token-type ids are recovered from this
                tensor via ``_decode_token_type_ids`` (presumably they were
                packed in by ``_encode_token_type_ids`` upstream; the helper
                may modify ``input_ids`` -- confirm against its definition).
            position_ids: Indices into the position embedding table.
            inputs_embeds: Optional precomputed word embeddings; when given,
                the word embedding lookup is skipped.

        Returns:
            The layer-normalised sum of the three embeddings.
        """
        # Must run before the word-embedding lookup, matching the original
        # statement order.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


85
86
87
88
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies ``dense -> tanh -> out_proj`` to an already pooled hidden state.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Create the two linear layers of the head.

        Args:
            model_config: Supplies ``hf_config`` (read for ``hidden_size``
                and ``num_labels``) and ``head_dtype`` (dtype of both
                linear layers).
        """
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            config.hidden_size, config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to ``num_labels`` logits.

        Args:
            x: Pooled features whose last dimension is ``hidden_size``.

        Returns:
            Tensor whose last dimension is ``num_labels``.
        """
        # CLSPool has already been applied in `pooling`
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


106
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model and remember the padding token id.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``
                supplies ``pad_token_id``.
            prefix: Parameter-name prefix forwarded to the base class.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder after shifting positions to RoBERTa's convention.

        Note:
            ``positions`` is modified in place by
            ``replace_roberta_positions``.
        """
        # Fix Roberta positions here outside of the CUDA graph.
        # Because we need to extract the sequences from
        # input_ids the control flow is data dependent.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Pick the backbone according to ``position_embedding_type``.

        Returns:
            ``JinaRobertaModel`` when the HF config requests ``"rotary"``
            position embeddings, otherwise a ``BertModel`` that uses
            ``RobertaEmbedding``.
        """
        if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
        return BertModel(
            vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping names under ``model.``.

        Args:
            weights: ``(name, tensor)`` pairs. The iterable is materialised
                into a list because it is scanned once for the prefix check
                and then passed to the loader.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        # Weights named `lm_head.*` are skipped.
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
161

162

163
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta for sequence classification and scoring.

    This class wraps a BertModel built with RobertaEmbedding, adds a
    RobertaClassificationHead, and exposes both through a DispatchPooler.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: The RobertaClassificationHead shared by all pooling tasks.
        pooler: A DispatchPooler with "token_classify", "classify" and
            "score" entries.
    """

    is_pooling_model = True

    # Substring renames applied to checkpoint weight names in `load_weights`.
    # NOTE(review): judging by the name, these target Jina-style checkpoints
    # -- confirm.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder, the classification head and the pooler.

        Args:
            vllm_config: Engine configuration; ``model_config`` must carry a
                non-``None`` ``pooler_config``.
            prefix: Parameter-name prefix for the encoder.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # The same classifier module is shared by every pooling task.
        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config=pooler_config, classifier=self.classifier
                ),
                "classify": ClassifierPooler(
                    pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
                ),
                "score": ClassifierPooler(
                    pooling=CLSPool(), classifier=self.classifier, act_fn="score"
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, renaming them via ``jina_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the wrapped encoder."""
        return self.roberta.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder after fixing positions and packing token types.

        Note:
            ``positions`` is modified in place. When ``token_type_ids`` is
            given, ``input_ids`` must not be ``None`` and is passed to
            ``_encode_token_type_ids`` (presumably updated in place, since
            the return value is unused -- confirm against the helper).
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # Token ids must stay below 1 << TOKEN_TYPE_SHIFT before encoding.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
247
248


249
250
251
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
    """Shift ``position_ids`` in place to RoBERTa's position convention.

    Args:
        input_ids: Unused; kept so existing keyword callers keep working.
        position_ids: Position tensor to update; every element is
            increased by ``padding_idx + 1``.
        padding_idx: Padding token id from the model config.

    Returns:
        None. The result is written into ``position_ids``.
    """
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1