# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set

import torch
from torch import nn
from transformers import ModernBertConfig
from transformers.activations import ACT2FN

from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import (
    ClassifierPooler,
    DispatchPooler,
    Pooler,
    PoolingMethod,
    PoolingParamsUpdate,
    PoolingType,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):
    """Token-embedding lookup followed by a LayerNorm."""

    def __init__(self, config: ModernBertConfig):
        """Build the embedding table and its normalisation layer.

        Args:
            config: Hugging Face ModernBERT configuration. Reads
                ``vocab_size``, ``hidden_size``, ``norm_bias`` and, when
                present, ``norm_eps`` / ``layer_norm_eps``.
        """
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        # Prefer `norm_eps`, then `layer_norm_eps`; fall back to 1e-5 when
        # neither attribute is set (or both are falsy).
        eps = (
            getattr(config, "norm_eps", None)
            or getattr(config, "layer_norm_eps", None)
            or 1e-5
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids`` without the LayerNorm."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return normalised embeddings.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is ``None``.
            inputs_embeds: Optional precomputed embeddings. When given they
                are normalised directly and ``input_ids`` is ignored.

        Returns:
            The LayerNorm output.
        """
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        return self.norm(inputs_embeds)


class ModernBertRotaryEmbedding(RotaryEmbedding):
    """``RotaryEmbedding`` configured from a ModernBERT config.

    Always uses NeoX-style rotation and takes the maximum number of positions
    from ``config.max_position_embeddings``.
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
        """
        Args:
            config: Hugging Face ModernBERT configuration.
            head_size: Forwarded to the base class as ``head_size``.
            dim: Forwarded to the base class as ``rotary_dim``.
            base: Forwarded to the base class as ``base`` (the RoPE theta).
        """
        super().__init__(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            # NOTE(review): dtype is pinned to float16 regardless of the
            # model dtype -- confirm this is intended.
            dtype=torch.float16,
        )
        self.config = config


class ModernBertAttention(nn.Module):
    """Self-attention block for one ModernBERT layer.

    Consists of a fused QKV projection, rotary position embedding,
    encoder-only attention and an output projection. Layers whose index is a
    multiple of ``config.global_attn_every_n_layers`` use no sliding window
    and the global RoPE base; all other layers use a sliding window of
    ``config.local_attention // 2`` and the local RoPE base (falling back to
    the global one when ``config.local_rope_theta`` is ``None``).
    """

    def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
        """
        Args:
            config: Hugging Face ModernBERT configuration.
            layer_id: Index of the layer this block belongs to. It decides
                between local and global attention, so it must be provided.

        Raises:
            ValueError: If ``layer_id`` is ``None``.
        """
        super().__init__()
        # Fail early with a clear message instead of an opaque TypeError from
        # `None % int` further down.
        if layer_id is None:
            raise ValueError(
                "ModernBertAttention requires an integer layer_id to choose "
                "between local and global attention."
            )
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        self.head_dim = config.hidden_size // config.num_attention_heads
        # NOTE(review): num_heads / all_head_size are the full (unsharded)
        # sizes; confirm this is intended when tp_size > 1.
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5
        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            # Local-attention layer.
            sliding_window = config.local_attention // 2
            rope_theta = (
                config.local_rope_theta
                if config.local_rope_theta is not None
                else config.global_rope_theta
            )
        else:
            # Global-attention layer.
            rope_theta = config.global_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(
            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
        )
        self.attn = EncoderOnlyAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            prefix=f"{layer_id}.attn",
            per_layer_sliding_window=sliding_window,
        )
        self.Wo = RowParallelLinear(
            config.hidden_size, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            hidden_states: Input to the fused QKV projection.
            position_ids: Positions handed to the rotary embedding.

        Returns:
            The output of the ``Wo`` projection.
        """
        qkv, _ = self.Wqkv(hidden_states)
        # Split the fused projection into three equally sized chunks.
        q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_outputs = self.attn(q, k, v)
        # `Wo` returns a tuple; only its first element is used here.
        hidden_states, _ = self.Wo(attn_outputs)
        return hidden_states


class ModernBertMLP(nn.Module):
    """Gated feed-forward block: ``Wo(GELU(a) * b)`` with ``a, b = Wi(x)``."""

    def __init__(self, config: ModernBertConfig):
        """
        Args:
            config: Hugging Face ModernBERT configuration. Reads
                ``hidden_size``, ``intermediate_size`` and ``mlp_bias``.
        """
        super().__init__()
        self.config = config
        # A single projection yields both halves of the gated unit, hence
        # twice the intermediate width.
        self.Wi = nn.Linear(
            config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias
        )
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(
            config.intermediate_size, config.hidden_size, bias=config.mlp_bias
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the gated feed-forward output for ``hidden_states``."""
        # Renamed from `input` so the builtin is not shadowed.
        activated, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        # `Wo` returns a tuple; only its first element is the output.
        return self.Wo(self.act(activated) * gate)[0]


class ModernBertLayer(nn.Module):
    """One pre-norm encoder layer: attention and MLP, each with a residual."""

    def __init__(
        self, config: ModernBertConfig, prefix: str = "", layer_id: int | None = None
    ):
        """
        Args:
            config: Hugging Face ModernBERT configuration.
            prefix: Accepted for interface compatibility; not used here.
            layer_id: Index of this layer. Layer 0 uses an identity in place
                of the attention LayerNorm.
        """
        super().__init__()
        self.config = config
        if layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
            )
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Layer input.
            position_ids: Positions forwarded to the attention block.

        Returns:
            The layer output after both residual additions.
        """
        attn_outputs = self.attn(
            hidden_states=self.attn_norm(hidden_states), position_ids=position_ids
        )
        hidden_states = hidden_states + attn_outputs
        mlp_output = self.mlp(self.mlp_norm(hidden_states))
        return hidden_states + mlp_output


class ModernBertEncoderLayer(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ModernBertLayer`` modules."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """
        Args:
            vllm_config: vLLM configuration; the Hugging Face config is read
                from ``vllm_config.model_config.hf_config``.
            prefix: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList(
            [
                ModernBertLayer(config=config, layer_id=layer_id)
                for layer_id in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Pass ``hidden_states`` through every layer in order."""
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
    """ModernBERT backbone: embeddings, encoder stack and a final LayerNorm."""

    # Checkpoint names starting with "layers." are re-rooted under the
    # `encoder_layer` submodule.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."}
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """
        Args:
            vllm_config: vLLM configuration; the Hugging Face config is read
                from ``vllm_config.model_config.hf_config``.
            prefix: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the un-normalised token embeddings for ``input_ids``."""
        return self.embeddings.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs using checkpoint naming.

        Returns:
            The (mapped) names of all parameters that were loaded.

        Raises:
            KeyError: If a non-bias weight has no matching parameter.
        """
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Bias tensors without a matching parameter are ignored.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode a batch of tokens.

        Args:
            input_ids: Token ids; used when ``inputs_embeds`` is ``None``.
            positions: Position ids forwarded to every encoder layer.
            intermediate_tensors: Accepted for interface compatibility; not
                used here.
            inputs_embeds: Optional precomputed, un-normalised embeddings
                (what ``embed_input_ids`` returns).

        Returns:
            The encoder output after the final LayerNorm.
        """
        # Always go through the embedding module so caller-supplied
        # `inputs_embeds` receive the same LayerNorm as embeddings looked up
        # from `input_ids`; `embed_input_ids` returns them un-normalised.
        hidden_states = self.embeddings(
            input_ids=input_ids, inputs_embeds=inputs_embeds
        )

        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=positions,
        )
        return self.final_norm(outputs)


class ModernBertPooler(Pooler):
    """Pools hidden states, then applies dense -> GELU -> LayerNorm."""

    def __init__(self, config: ModernBertConfig):
        """
        Args:
            config: Hugging Face ModernBERT configuration.
                ``classifier_pooling`` (upper-cased) selects the
                ``PoolingType`` member to use.
        """
        super().__init__()

        pooling_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(pooling_type)
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, bias=config.classifier_bias
        )
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        """Apply dense -> activation -> norm to one pooled tensor."""
        # Cast to the dense layer's dtype before the projection.
        pooled_output = pooled_output.to(self.dense.weight.dtype)
        return self.norm(self.act(self.dense(pooled_output)))

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool ``hidden_states`` and run the head on the result.

        The pooling method may return a single tensor or a list of tensors;
        the head is applied to each element and the container type is kept.
        """
        pooled_output = self.pooling(hidden_states, pooling_metadata)

        if isinstance(pooled_output, list):
            return [self._head(output) for output in pooled_output]
        return self._head(pooled_output)


@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """ModernBERT backbone with a pooled classification head."""

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Args:
            vllm_config: vLLM configuration; must carry a pooler config.
            prefix: Parameter-name prefix for the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=vllm_config.model_config.head_dtype,
        )
        self.pooling = ModernBertPooler(config)

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config, classifier=self.classifier
                ),
                "classify": ClassifierPooler(
                    pooling=self.pooling, classifier=self.classifier, act_fn="classify"
                ),
                "score": ClassifierPooler(
                    pooling=self.pooling, classifier=self.classifier, act_fn="score"
                ),
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone, classifier and pooling-head weights.

        Names starting with ``model.`` go to the backbone with that prefix
        stripped. ``classifier*`` names are loaded as-is, ``head*`` names are
        remapped onto the ``pooling`` submodule, and anything else is ignored.
        """
        self_weights = []

        def weight_filter():
            # Lazily stream backbone weights; everything else is collected in
            # `self_weights` as a side effect while the generator is consumed.
            for name, weight in weights:
                if name.startswith("model."):
                    yield name[len("model.") :], weight
                else:
                    self_weights.append((name, weight))

        # Must run first: consuming the generator is what fills `self_weights`.
        self.model.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
            elif name.startswith("head"):
                # Drop "head" plus the following separator character.
                param = params_dict["pooling." + name[len("head") + 1 :]]
            else:
                continue
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

    def forward(
        self,
        input_ids: torch.LongTensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the backbone and return its output.

        ``intermediate_tensors`` is accepted for interface compatibility and
        is not forwarded to the backbone.
        """
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            positions=positions,
        )


class ModernBertPredictionHead(nn.Module):
    """Per-token head: dense projection, configured activation, LayerNorm."""

    def __init__(self, config):
        """
        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``classifier_bias`` and ``classifier_activation``;
                ``norm_eps`` defaults to 1e-5 and ``norm_bias`` to ``True``
                when the attribute is missing.
        """
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, bias=config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        norm_eps = getattr(config, "norm_eps", 1e-5)
        norm_bias = getattr(config, "norm_bias", True)
        self.norm = nn.LayerNorm(width, eps=norm_eps, bias=norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``norm(act(dense(hidden_states)))``."""
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)


@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
    """ModernBERT backbone with a per-token prediction head and classifier."""

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Args:
            vllm_config: vLLM configuration; must carry a pooler config.
            prefix: Parameter-name prefix for the backbone.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.model = ModernBertModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
        )
        self.head = ModernBertPredictionHead(config)
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {
                "token_classify": Pooler.for_token_classify(
                    pooler_config=pooler_config
                ),
            }
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights via ``AutoWeightsLoader``, skipping the ``drop`` prefix.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
        return loader.load_weights(weights)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the classifier output for every token.

        The backbone output goes through the prediction head, is cast to
        ``head_dtype`` and then fed to the linear classifier.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )
        hidden_states = self.head(hidden_states)
        # The classifier was created with `head_dtype`, so match it.
        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)