# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
5
6
7
8
9

import torch
from torch import nn
from transformers import RobertaConfig

10
from vllm.config import ModelConfig, VllmConfig
11
12
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
13
14
15
16
17
18
19
20
21
22
23
24
25
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbeddingModel,
    BertModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    maybe_prefix,
)
26
from vllm.sequence import IntermediateTensors
27

28
from .bert_with_rope import BertWithRope, JinaRobertaModel
29
30
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
31

32
33
34
35
36

class RobertaEmbedding(nn.Module):
    """Input embedding layer used for RoBERTa-style encoders.

    Sums word, token-type and learned position embeddings and applies
    LayerNorm to the result.
    """

    def __init__(self, config: RobertaConfig):
        """Build the embedding tables from `config`.

        Args:
            config: Supplies `hidden_size`, `vocab_size`, `pad_token_id`,
                `max_position_embeddings`, `type_vocab_size` and
                `layer_norm_eps`.
        """
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Buffer of shape (1, max_position_embeddings). `forward` does not
        # read it (it uses the `position_ids` argument instead).
        # NOTE(review): presumably kept so checkpoints that contain a
        # `position_ids` entry still load -- confirm before removing.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).unsqueeze(0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the normalized input embeddings.

        Args:
            input_ids: Token ids; token-type ids are recovered from this
                tensor via `_decode_token_type_ids`.
            position_ids: Indices into the position-embedding table.
            inputs_embeds: Optional precomputed word embeddings. When given,
                the word-embedding lookup is skipped.

        Returns:
            LayerNorm(word + token-type + position embeddings).
        """
        # Keep this call ahead of the word-embedding lookup.
        # NOTE(review): presumably it also strips the packed token-type bits
        # from `input_ids` in place -- confirm against the bert module.
        token_type_ids = _decode_token_type_ids(input_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


75
76
77
78
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies `dense -> tanh -> out_proj` to an already-pooled representation.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Create the two linear layers of the head.

        Args:
            model_config: Provides `hf_config` (read for `hidden_size` and
                `num_labels`) and `head_dtype` (dtype of both layers).
        """
        super().__init__()
        config = model_config.hf_config
        head_dtype = model_config.head_dtype
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype)
        self.out_proj = nn.Linear(
            config.hidden_size, config.num_labels, dtype=head_dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pooled features to classification logits.

        Args:
            x: Pooled features whose last dimension is `hidden_size`.

        Returns:
            Tensor whose last dimension is `num_labels`.
        """
        # CLSPool has already been applied in `pooling`
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


96
@default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses Roberta to provide embedding functionalities."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialize the base model and cache the pad token id.

        Args:
            vllm_config: Engine configuration; `model_config.hf_config`
                supplies `pad_token_id`.
            prefix: Name prefix forwarded to the base class.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Offset `positions` for RoBERTa and run the wrapped model.

        Note that `positions` is modified in place.

        Args:
            input_ids: Token ids.
            positions: Position ids; shifted in place by `padding_idx + 1`.
            intermediate_tensors: Forwarded unchanged to the wrapped model.
            inputs_embeds: Forwarded unchanged to the wrapped model.

        Returns:
            The output of `self.model`.
        """
        # Fix Roberta positions here outside of the CUDA graph.
        # NOTE(review): the original note said this is needed because
        # extracting the sequences from input_ids makes the control flow
        # data dependent; `replace_roberta_positions` as defined in this
        # file only applies a constant in-place offset.
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )

        return self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

    def _build_model(
        self, vllm_config: VllmConfig, prefix: str = ""
    ) -> BertModel | BertWithRope:
        """Pick the backbone according to `position_embedding_type`.

        Returns a `BertModel` using `RobertaEmbedding` when the HF config's
        `position_embedding_type` is "absolute" (also the default when the
        attribute is missing), otherwise a `JinaRobertaModel`.
        """
        hf_config = vllm_config.model_config.hf_config
        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
        else:
            return JinaRobertaModel(**kwargs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, remapping names onto the `model.` prefix.

        Args:
            weights: Iterable of `(name, tensor)` pairs. It is materialized
                into a list because it is scanned once for the prefix check
                and then handed to the loader.

        Returns:
            Whatever `AutoWeightsLoader.load_weights` returns.
        """
        weights_list = list(weights)
        has_roberta_prefix = any(
            name.startswith("roberta.") for name, _ in weights_list
        )
        if has_roberta_prefix:
            # For models with the `roberta.` prefix e.g.
            # `FacebookAI/roberta-base`
            mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
        else:
            # For models without the `roberta.` prefix e.g.
            # `sentence-transformers/stsb-roberta-base-v2`
            mapper = WeightsMapper(orig_to_new_prefix={"": "model."})

        # Weights under `lm_head.` are skipped by the loader.
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)
151

152

153
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A RoBERTa encoder with a sequence-classification head.

    This class wraps a BertModel built with RobertaEmbedding and pairs it
    with a RobertaClassificationHead that is applied through a pooler.

    Attributes:
        roberta: The BertModel (with RobertaEmbedding) used in `forward`.
        classifier: The RobertaClassificationHead passed to the pooler.
        pooler: DispatchPooler built via `for_seq_cls` with CLSPool pooling.
        num_labels: Number of labels, read from the HF config.
        padding_idx: `pad_token_id` from the HF config; used to offset
            position ids.
    """

    is_pooling_model = True
    # Substring renames applied to checkpoint weight names in `load_weights`.
    jina_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "emb_ln": "embeddings.LayerNorm",
            "layers": "layer",
            "mixer.Wqkv": "attention.self.qkv_proj",
            "mixer.out_proj": "attention.output.dense",
            "norm1": "attention.output.LayerNorm",
            "mlp.fc1": "intermediate.dense",
            "mlp.fc2": "output.dense",
            "norm2": "output.LayerNorm",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder, classification head and pooler.

        Args:
            vllm_config: Engine configuration. `model_config.pooler_config`
                must not be None.
            prefix: Name prefix for the encoder submodule.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        # NOTE: the prefix component is "bert" although the attribute is
        # named `roberta`; left unchanged on purpose.
        self.roberta = BertModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "bert"),
            embedding_class=RobertaEmbedding,
        )
        self.classifier = RobertaClassificationHead(vllm_config.model_config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler.for_seq_cls(
            pooler_config,
            pooling=CLSPool(),
            classifier=self.classifier,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, renaming them with `jina_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the encoder's `embed_input_ids`."""
        return self.roberta.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Offset `positions` for RoBERTa and run the encoder.

        Note that `positions` is modified in place.

        Args:
            input_ids: Token ids; must not be None when `token_type_ids`
                is given.
            positions: Position ids; shifted in place by `padding_idx + 1`.
            intermediate_tensors: Forwarded unchanged to the encoder.
            inputs_embeds: Forwarded unchanged to the encoder.
            token_type_ids: Optional token-type ids, handed to
                `_encode_token_type_ids` together with `input_ids`.

        Returns:
            The output of `self.roberta`.
        """
        replace_roberta_positions(
            input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx
        )
        if token_type_ids is not None:
            # The vocabulary must fit below bit TOKEN_TYPE_SHIFT.
            # NOTE(review): presumably the helper packs the token-type ids
            # into the higher bits of `input_ids` in place (its return value
            # is discarded here) -- confirm against the bert module.
            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)
        return self.roberta(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
229
230


231
232
233
def replace_roberta_positions(
    input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int
) -> None:
234
235
236
237
238
239
240
241
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # vllm does not use padding tokens, let's make things simpler
    position_ids += padding_idx + 1