"vscode:/vscode.git/clone" did not exist on "6e599eebe8655dab75462a8a165f6d811d0d845f"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set
4
from typing import Optional, Union
xsank's avatar
xsank committed
5
6
7
8
9
10
11
12
13
14
15

import torch
from torch import nn
from transformers import ModernBertConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
16
17
from vllm.model_executor.layers.pooler import (ClassifierPooler,
                                               DispatchPooler, Pooler,
18
19
                                               PoolingMethod,
                                               PoolingParamsUpdate,
20
                                               PoolingType)
xsank's avatar
xsank committed
21
22
23
24
25
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
26
from vllm.pooling_params import PoolingTask
27
from vllm.sequence import IntermediateTensors
xsank's avatar
xsank committed
28

29
from .interfaces import SupportsCrossEncoding, SupportsV0Only
xsank's avatar
xsank committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
from .utils import WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):
    """Token-embedding lookup followed by a LayerNorm.

    Either token ids or already-computed embeddings may be supplied to
    ``forward``; in both cases the result is passed through ``self.norm``.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        # Prefer ``layer_norm_eps`` (the attribute this block has always
        # read); fall back to ``norm_eps`` -- the attribute every other
        # LayerNorm in this file uses -- when the config does not define it.
        eps = getattr(config, "layer_norm_eps", None)
        if eps is None:
            eps = config.norm_eps
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return normalized embeddings.

        Args:
            input_ids: Token ids to look up; ignored when ``inputs_embeds``
                is given.
            inputs_embeds: Optional precomputed embeddings that bypass the
                lookup table.

        Returns:
            ``self.norm`` applied to the (looked-up or supplied) embeddings.
        """
        # Must be an identity check: evaluating the truthiness of a tensor
        # with more than one element raises a RuntimeError.
        if inputs_embeds is not None:
            return self.norm(inputs_embeds)
        return self.norm(self.tok_embeddings(input_ids))


class ModernBertRotaryEmbedding(RotaryEmbedding):
    """``RotaryEmbedding`` preconfigured for ModernBERT.

    The maximum position count is taken from the model config; the
    NeoX-style layout and ``torch.float16`` dtype are fixed here.
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
                 base: float):
        # Collect the parent-constructor arguments in one place so the
        # fixed choices (NeoX style, fp16) are easy to spot.
        rope_kwargs = dict(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            dtype=torch.float16,
        )
        super().__init__(**rope_kwargs)
        self.config = config


class ModernBertAttention(nn.Module):
    """Attention sub-block: fused QKV projection, rotary embedding,
    encoder-only attention and an output projection.

    ``self.local_attention`` is ``(-1, -1)`` for layers whose index is a
    multiple of ``config.global_attn_every_n_layers`` and
    ``(config.local_attention // 2,) * 2`` otherwise. Within this class the
    tuple is only used to pick the RoPE base (``local_rope_theta`` vs
    ``global_rope_theta``); it is not forwarded to ``Attention``.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        world_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn

        # Head geometry.
        self.num_heads = config.num_attention_heads
        assert self.num_heads % world_size == 0
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5

        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        # Every ``global_attn_every_n_layers``-th layer is marked global.
        if layer_id % config.global_attn_every_n_layers == 0:
            self.local_attention = (-1, -1)
        else:
            self.local_attention = (config.local_attention // 2,
                                    config.local_attention // 2)

        # Local layers use their own RoPE base when the config provides one.
        rope_theta = config.global_rope_theta
        is_windowed = self.local_attention != (-1, -1)
        if is_windowed and config.local_rope_theta is not None:
            rope_theta = config.local_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                    head_size=self.head_dim,
                                                    dim=self.head_dim,
                                                    base=rope_theta)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)
        self.Wo = RowParallelLinear(config.hidden_size,
                                    config.hidden_size,
                                    bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states`` and return the projection.

        The fused projection output is split into three equal chunks of
        ``self.all_head_size`` along the last dimension; rotary embedding is
        applied to the query and key chunks only.
        """
        fused_qkv, _ = self.Wqkv(hidden_states)
        chunk_sizes = [self.all_head_size] * 3
        query, key, value = fused_qkv.split(chunk_sizes, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        projected, _ = self.Wo(self.attn(query, key, value))
        return projected


class ModernBertMLP(nn.Module):
    """Gated feed-forward block.

    ``Wi`` produces ``2 * intermediate_size`` features that are split in
    half; GELU is applied to the first half, which is then multiplied by
    the second half (the gate) before the ``Wo`` projection.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        # A single linear layer emits both the value and the gate halves.
        self.Wi = nn.Linear(config.hidden_size,
                            2 * int(config.intermediate_size),
                            bias=config.mlp_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``Wo(act(value) * gate)`` for the given hidden states."""
        value, gate = torch.chunk(self.Wi(hidden_states), 2, dim=-1)
        gated = self.act(value) * gate
        # ``Wo`` returns a tuple; only its first element is the output.
        return self.Wo(gated)[0]


class ModernBertLayer(nn.Module):
    """One pre-norm encoder layer: attention and MLP, each wrapped in a
    residual connection.

    The first layer (``layer_id == 0``) uses ``nn.Identity`` in place of
    the attention LayerNorm.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 prefix: str = "",
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config

        def build_norm() -> nn.LayerNorm:
            # Both norms of the layer share the same hyper-parameters.
            return nn.LayerNorm(config.hidden_size,
                                eps=config.norm_eps,
                                bias=config.norm_bias)

        self.attn_norm = nn.Identity() if layer_id == 0 else build_norm()
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = build_norm()
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Run the attention and MLP sub-blocks with residual additions."""
        normed_for_attn = self.attn_norm(hidden_states)
        hidden_states = hidden_states + self.attn(normed_for_attn,
                                                  position_ids=position_ids)
        normed_for_mlp = self.mlp_norm(hidden_states)
        return hidden_states + self.mlp(normed_for_mlp)


class ModernBertEncoderLayer(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ModernBertLayer`` modules
    applied one after another."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # Each layer receives its index so it can configure itself
        # (the index is forwarded as ``layer_id``).
        stacked = [
            ModernBertLayer(config=config, layer_id=idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(stacked)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer in order."""
        for block in self.layers:
            hidden_states = block(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
class ModernBertModel(nn.Module):
    """ModernBERT backbone: embeddings, the encoder stack and a final
    LayerNorm."""

    # Checkpoint names starting with "layers." live under
    # ``self.encoder_layer.layers`` in this module.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."})

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(config.hidden_size,
                                       eps=config.norm_eps,
                                       bias=config.norm_bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs; names are first rewritten by
                ``hf_to_vllm_mapper``.

        Returns:
            The set of (mapped) parameter names that were loaded.

        Raises:
            KeyError: If a mapped name is not a parameter of this module,
                except for ``*.bias`` entries, which are skipped silently.
        """
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Tolerate bias tensors for layers built without a bias.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        positions: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode the input and return the final-normed hidden states.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is None.
            positions: Position indices; takes precedence over
                ``position_ids`` when both are given.
            intermediate_tensors: Accepted for interface compatibility;
                not used in this method.
            inputs_embeds: Optional precomputed embeddings. When given they
                are used directly as the first hidden states (the embedding
                module, including its LayerNorm, is bypassed).
            position_ids: Fallback position indices when ``positions`` is
                None.
        """
        position_ids = positions if positions is not None else position_ids
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            # ``inputs_embeds`` is known to be None on this path, so only
            # the token ids need to be forwarded.
            hidden_states = self.embeddings(input_ids=input_ids)

        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )
        return self.final_norm(outputs)


260
class ModernBertPooler(Pooler):
    """Pooling followed by the ModernBERT head (dense -> GELU -> LayerNorm).

    The pooling strategy is selected from ``config.classifier_pooling``
    (upper-cased and looked up in ``PoolingType``).
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()

        pooling_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
        self.dense = nn.Linear(config.hidden_size,
                               config.hidden_size,
                               bias=config.classifier_bias)
        # Raw config string, kept alongside the resolved pooling method.
        self.pooling_type = config.classifier_pooling
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        """Apply dense -> activation -> norm to one pooled tensor."""
        return self.norm(self.act(self.dense(pooled_output)))

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Pool ``hidden_states`` and run the head on the result.

        The pooling method may return either a single tensor or a list of
        tensors; the head is applied to the tensor or to each list element,
        and the same container kind is returned.
        """
        pooled_output = self.pooling(hidden_states, pooling_metadata)

        if isinstance(pooled_output, list):
            return [self._head(output) for output in pooled_output]
        return self._head(pooled_output)


299
300
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
                                          SupportsCrossEncoding):
    """ModernBERT backbone with a pooled head and a linear classifier,
    exposed through "encode", "classify" and "score" poolers."""

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "modernbert"))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # A single head shared by the "classify" and "score" poolers, so
        # that the checkpoint's ``head.*`` weights (loaded once in
        # ``load_weights``) are used by both tasks.
        self.pooling = ModernBertPooler(config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            ClassifierPooler(
                pooling=self.pooling,
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_seq_cls(
                    vllm_config.model_config),
            ),
            "score":
            ClassifierPooler(
                pooling=self.pooling,
                classifier=self.classifier,
                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                    vllm_config.model_config),
            ),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone, classifier and head weights from a checkpoint.

        Names starting with ``model.`` are stripped of that prefix and
        streamed to ``self.model.load_weights``. Names starting with
        ``classifier`` are loaded into this module's parameters of the same
        name. Names starting with ``head`` are loaded into the shared
        ``self.pooling`` head. Any other name is ignored.

        Raises:
            KeyError: If a ``classifier*`` or ``head*`` name has no matching
                parameter.
        """
        self_weights = []

        def weight_filter():
            # Stream backbone weights; set everything else aside. The list
            # is complete only after the generator has been fully consumed,
            # which ``self.model.load_weights`` does.
            for name, weight in weights:
                if name.startswith("model."):
                    yield name[len("model."):], weight
                else:
                    self_weights.append((name, weight))

        self.model.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())
        # Resolve head parameters relative to the head module itself, so
        # the lookup does not depend on where the head is nested.
        head_params = dict(self.pooling.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
            elif name.startswith("head"):
                # "head.dense.weight" -> "dense.weight"
                param = head_params[name[len("head") + 1:]]
            else:
                continue
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its hidden states.

        ``intermediate_tensors`` is accepted but not used here; pooling is
        not applied in this method.
        """
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=positions,
        )