Unverified Commit 6f59beaf authored by antrec's avatar antrec Committed by GitHub
Browse files

[Model] Add support for ModernBertForTokenClassification (#26340)


Signed-off-by: default avatarAntoine Recanati Le Goat <antoine.recanati@sancare.fr>
Signed-off-by: default avatarantrec <antoine.recanati@gmail.com>
Co-authored-by: default avatarAntoine Recanati Le Goat <antoine.recanati@sancare.fr>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 41f1cf38
...@@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) ...@@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| |--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ | | `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |
!!! note !!! note
Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>. Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
......
...@@ -11,7 +11,38 @@ from tests.models.utils import softmax ...@@ -11,7 +11,38 @@ from tests.models.utils import softmax
# The float32 is required for this tiny model to pass the test. # The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode @torch.inference_mode
def test_models( def test_bert_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
) as hf_model:
tokenizer = hf_model.tokenizer
hf_outputs = []
for prompt in example_prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = hf_model.wrap_device(inputs)
output = hf_model.model(**inputs)
hf_outputs.append(softmax(output.logits[0]))
# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output).cpu().float()
vllm_output = torch.tensor(vllm_output).cpu().float()
assert torch.allclose(hf_output, vllm_output, 1e-2)
@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_modernbert_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
......
...@@ -527,6 +527,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { ...@@ -527,6 +527,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"ModernBertForSequenceClassification": _HfExamplesInfo( "ModernBertForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-reranker-modernbert-base" "Alibaba-NLP/gte-reranker-modernbert-base"
), ),
"ModernBertForTokenClassification": _HfExamplesInfo(
"disham993/electrical-ner-ModernBERT-base"
),
"RobertaForSequenceClassification": _HfExamplesInfo( "RobertaForSequenceClassification": _HfExamplesInfo(
"cross-encoder/quora-roberta-base" "cross-encoder/quora-roberta-base"
), ),
......
...@@ -6,6 +6,7 @@ from typing import Optional, Union ...@@ -6,6 +6,7 @@ from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import ModernBertConfig from transformers import ModernBertConfig
from transformers.activations import ACT2FN
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
...@@ -29,7 +30,7 @@ from vllm.v1.pool.metadata import PoolingMetadata ...@@ -29,7 +30,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
from .utils import WeightsMapper, maybe_prefix from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
class ModernBertEmbeddings(nn.Module): class ModernBertEmbeddings(nn.Module):
...@@ -379,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -379,3 +380,73 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
positions=positions, positions=positions,
) )
class ModernBertPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dense = nn.Linear(
config.hidden_size, config.hidden_size, bias=config.classifier_bias
)
self.act = ACT2FN[config.classifier_activation]
self.norm = nn.LayerNorm(
config.hidden_size,
eps=getattr(config, "norm_eps", 1e-5),
bias=getattr(config, "norm_bias", True),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(self.act(self.dense(hidden_states)))
@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype
self.num_labels = config.num_labels
self.model = ModernBertModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
)
self.head = ModernBertPredictionHead(config)
self.classifier = nn.Linear(
config.hidden_size, config.num_labels, dtype=self.head_dtype
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
}
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
loaded_params = loader.load_weights(weights)
return loaded_params
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
hidden_states = self.head(hidden_states)
hidden_states = hidden_states.to(self.head_dtype)
return self.classifier(hidden_states)
...@@ -225,6 +225,10 @@ _CROSS_ENCODER_MODELS = { ...@@ -225,6 +225,10 @@ _CROSS_ENCODER_MODELS = {
"modernbert", "modernbert",
"ModernBertForSequenceClassification", "ModernBertForSequenceClassification",
), ),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ( "XLMRobertaForSequenceClassification": (
"roberta", "roberta",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment