roberta.py 9.19 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
6
7
8
9

import torch
from torch import nn
from transformers import RobertaConfig

10
from vllm.config import ModelConfig, VllmConfig
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    CLSPool,
    DispatchPooler,
    Pooler,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
30
from vllm.sequence import IntermediateTensors
31

32
from .bert_with_rope import BertWithRope, JinaRobertaModel
33
34
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
35

36
37
38
39
40

class RobertaEmbedding(nn.Module):
    """Input embedding layer for RoBERTa with absolute position embeddings.

    The output is ``LayerNorm(word + token_type + position)``. Token type ids
    are not passed in separately; they are recovered from ``input_ids`` with
    ``_decode_token_type_ids``.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        hidden = config.hidden_size
        self.size = hidden

        self.word_embeddings = VocabParallelEmbedding(config.vocab_size, hidden)

        # The position table uses pad_token_id as its padding row.
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, hidden, padding_idx=self.padding_idx
        )
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        # NOTE(review): forward() never reads this buffer; presumably it is
        # kept so checkpoints that contain `embeddings.position_ids` have a
        # matching target during weight loading — confirm before removing.
        every_position = torch.arange(config.max_position_embeddings)
        self.register_buffer("position_ids", every_position.unsqueeze(0))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Token types are decoded before the word lookup. NOTE(review):
        # _decode_token_type_ids presumably also clears the packed type bits
        # from input_ids in place, which is why it must run first — confirm
        # in bert.py.
        type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is not None:
            word_vecs = inputs_embeds
        else:
            word_vecs = self.word_embeddings(input_ids)

        type_vecs = self.token_type_embeddings(type_ids)
        pos_vecs = self.position_embeddings(position_ids)

        return self.LayerNorm(word_vecs + type_vecs + pos_vecs)


79
80
81
82
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Sentence-level classification head: ``out_proj(tanh(dense(x)))``.

    Both projections are created in ``model_config.head_dtype``. The input
    is expected to be an already pooled hidden state, one vector per
    sequence.
    """

    def __init__(self, model_config: "ModelConfig"):
        super().__init__()
        hf_config = model_config.hf_config
        width = hf_config.hidden_size
        dtype = model_config.head_dtype

        self.dense = nn.Linear(width, width, dtype=dtype)
        self.out_proj = nn.Linear(width, hf_config.num_labels, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The CLS vector is selected by the pooler (CLSPool) before this runs.
        activated = torch.tanh(self.dense(x))
        return self.out_proj(activated)


100
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities.

    ``_build_model`` picks the encoder from the HF config's
    ``position_embedding_type``:

    - ``"absolute"`` (also the default when the attribute is missing) gives a
      ``BertModel`` with ``RobertaEmbedding``;
    - anything else gives a ``JinaRobertaModel``.

    Positions are shifted to RoBERTa's numbering before every forward pass.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Offset source for replace_roberta_positions in forward().
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # RoBERTa position ids start at padding_idx + 1 instead of 0.
        # replace_roberta_positions shifts `positions` in place by that
        # constant, so the fix-up does not depend on the contents of
        # input_ids. It is applied here, before delegating to self.model.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Build the encoder backbone according to ``position_embedding_type``."""
        hf_config = vllm_config.model_config.hf_config
        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
        else:
            return JinaRobertaModel(**kwargs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights into ``self.model``.

        Checkpoints either nest the encoder under ``roberta.`` or store it at
        the top level; both layouts are remapped to ``model.``. Weights under
        ``lm_head.`` are skipped.
        """
        # Materialize once: the weights are scanned for the prefix here and
        # then consumed a second time by the loader.
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )

        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
155

156

157
@default_pooling_type("CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """RoBERTa encoder with a sequence-classification head.

    This class wraps a ``BertModel`` built with ``RobertaEmbedding``. It
    attaches a ``RobertaClassificationHead`` through a ``DispatchPooler``
    that serves the ``token_classify``, ``classify`` and ``score`` tasks.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: The classification head shared by every pooling task.
        pooler: A DispatchPooler routing each pooling task to its pooler.
    """

    is_pooling_model = True

    # Substring renames from Jina-style checkpoint names to BertModel names.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # Offset source for replace_roberta_positions in forward().
        self.padding_idx: int = config.pad_token_id

        self.num_labels = config.num_labels

        # The prefix mirrors the attribute path (`roberta.`) so that prefixed
        # module names agree with the parameter names. It was previously
        # "bert", which matched neither.
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "roberta"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        # All three tasks share one classifier; "classify" and "score" pool
        # the CLS token first and differ only in act_fn.
        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config=pooler_config, classifier=self.classifier
                ),
                "classify": ClassifierPooler(
                    pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
                ),
                "score": ClassifierPooler(
                    pooling=CLSPool(), classifier=self.classifier, act_fn="score"
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, renaming Jina-style parameter names when present."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token embedding lookup to the encoder."""
        return self.roberta.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Shift positions in place to RoBERTa's numbering (padding_idx + 1).
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        if token_type_ids is not None:
            # The assert keeps every token id below 1 << TOKEN_TYPE_SHIFT.
            # This suggests token types are packed into the higher bits of
            # input_ids by _encode_token_type_ids and recovered inside
            # RobertaEmbedding — see bert.py.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
241
242


243
244
245
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
246
247
248
249
250
251
252
253
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1