# roberta.py 8.8 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
6
7
8
9

import torch
from torch import nn
from transformers import RobertaConfig

10
from vllm.config import ModelConfig, VllmConfig
11
from vllm.model_executor.layers.pooler import DispatchPooler
12
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
25
from vllm.sequence import IntermediateTensors
26

27
from .bert_with_rope import BertWithRope, JinaRobertaModel
28
29
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
30

31
32
33
34
35

class RobertaEmbedding(nn.Module):
    """Input embedding layer used for RoBERTa models.

    Sums word, token-type and position embeddings and applies LayerNorm
    to the result.
    """

    def __init__(self, config: RobertaConfig):
        """Build the embedding tables and LayerNorm from ``config``.

        Args:
            config: Hugging Face RoBERTa config; reads ``hidden_size``,
                ``vocab_size``, ``pad_token_id``, ``max_position_embeddings``,
                ``type_vocab_size`` and ``layer_norm_eps``.
        """
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        # The position table reuses the pad token id as its padding index.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        # Attribute name is kept in HF style (`LayerNorm`); weight mappers in
        # this file target "embeddings.LayerNorm".
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Buffer of shape (1, max_position_embeddings). It is not read by
        # `forward` below, which uses the `position_ids` argument instead.
        # NOTE(review): presumably kept for checkpoint compatibility - confirm.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the normalized input embeddings.

        Args:
            input_ids: Token ids; also the source of the token type ids via
                `_decode_token_type_ids`.
            position_ids: Indices into the position embedding table.
            inputs_embeds: Optional precomputed word embeddings; when given,
                the word-embedding lookup is skipped.

        Returns:
            LayerNorm of ``inputs_embeds + token_type + position`` embeddings.
        """
        # Decoding runs before the word-embedding lookup, and that order is
        # kept on purpose.
        # NOTE(review): presumably this recovers token type ids packed into
        # `input_ids` by `_encode_token_type_ids` and may strip them in place
        # - confirm against bert.py.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


74
75
76
77
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies ``dense -> tanh -> out_proj`` to an already pooled hidden state.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Create the two linear layers of the head.

        Args:
            model_config: Provides ``hf_config`` (for ``hidden_size`` and
                ``num_labels``) and ``head_dtype``, the dtype both linear
                layers are created with.
        """
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            config.hidden_size, config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to ``num_labels`` outputs.

        Args:
            x: Pooled features whose last dimension is ``hidden_size``.

        Returns:
            Tensor whose last dimension is ``num_labels``.
        """
        # Token extraction has already been applied in `pooler.pooling`
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


95
@default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialize the base embedding model and cache the pad token id.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``
                supplies ``pad_token_id``.
            prefix: Module name prefix forwarded to the base class.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Shift positions to RoBERTa's convention and run the wrapped model.

        Note that ``positions`` is modified in place by
        `replace_roberta_positions` before being passed on.

        Returns:
            Whatever ``self.model`` returns for the given inputs.
        """
        # Fix Roberta positions here outside of the CUDA graph.
        # Because we need to extract the sequences from
        # input_ids the control flow is data dependent.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Pick the backbone based on the configured position embedding type.

        Returns:
            A `BertModel` using `RobertaEmbedding` when
            ``position_embedding_type`` is ``"absolute"`` (also the default
            when the attribute is missing), otherwise a `JinaRobertaModel`.
        """
        hf_config = vllm_config.model_config.hf_config
        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
        else:
            return JinaRobertaModel(**kwargs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping names under the ``model.`` prefix.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The result of `AutoWeightsLoader.load_weights`.
        """
        # Materialize once: the weights are scanned for the prefix check and
        # then iterated again by the loader.
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        # Weights whose (mapped) name starts with `lm_head.` are skipped.
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
150

151

152
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A model that uses Roberta for sequence classification.

    This class wraps a BertModel built with RobertaEmbedding and attaches a
    classification head that is applied through a sequence-classification
    pooler.

    Attributes:
        roberta: An instance of BertModel used for forward operations.
        classifier: The RobertaClassificationHead handed to the pooler.
        pooler: The pooler created by `DispatchPooler.for_seq_cls`.
    """

    is_pooling_model = True
    # Substring renames applied to checkpoint weight names on load.
    # NOTE(review): going by the attribute name these are Jina-style
    # checkpoint names - confirm.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder, the classification head and the pooler.

        Args:
            vllm_config: Engine configuration; ``model_config`` supplies the
                HF config, the head dtype and the pooler config.
            prefix: Module name prefix for the encoder.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = config.pad_token_id

        self.num_labels = config.num_labels
        # NOTE(review): the sub-prefix is "bert" although the attribute is
        # named `roberta` - kept as in the original; confirm it is intended.
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler.for_seq_cls(
            pooler_config,
            classifier=self.classifier,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, renaming them via `jina_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped encoder."""
        return self.roberta.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Shift positions, fold in token type ids and run the encoder.

        Note that ``positions`` is modified in place by
        `replace_roberta_positions`.

        Args:
            input_ids: Token ids; required when ``token_type_ids`` is given.
            positions: Position ids, shifted in place before use.
            intermediate_tensors: Forwarded unchanged to the encoder.
            inputs_embeds: Optional precomputed embeddings, forwarded
                unchanged to the encoder.
            token_type_ids: Optional token type ids to encode into
                ``input_ids``.

        Returns:
            Whatever ``self.roberta`` returns for the given inputs.
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # The vocabulary must fit below bit TOKEN_TYPE_SHIFT.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            # The return value is ignored.
            # NOTE(review): presumably this packs the token type ids into
            # `input_ids` in place - confirm against bert.py.
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
227
228


229
230
231
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
    """Shift ``position_ids`` in place by ``padding_idx + 1``.

    Args:
        input_ids: Not read by this function; kept in the signature so
            existing callers that pass it by keyword keep working.
        position_ids: Tensor of position ids; modified in place.
        padding_idx: Pad token id of the model.

    Returns:
        None. The result is written into ``position_ids``.
    """
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1