Unverified Commit 11a7fafa authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[New Model]: Support GteNewModelForSequenceClassification (#23524)


Signed-off-by: default avatarwang.yuqi <noooop@126.com>
parent 186aced5
...@@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A ...@@ -497,6 +497,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|--------------|--------|-------------------|----------------------|---------------------------|---------------------| |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
...@@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A ...@@ -513,6 +514,9 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}'
``` ```
!!! note
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
!!! note !!! note
Load the official original `mxbai-rerank-v2` by using the following command. Load the official original `mxbai-rerank-v2` by using the following command.
......
...@@ -456,11 +456,10 @@ class HfRunner: ...@@ -456,11 +456,10 @@ class HfRunner:
# output is final logits # output is final logits
all_inputs = self.get_inputs(prompts) all_inputs = self.get_inputs(prompts)
outputs = [] outputs = []
problem_type = getattr(self.config, "problem_type", "")
for inputs in all_inputs: for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs)) output = self.model(**self.wrap_device(inputs))
problem_type = getattr(self.config, "problem_type", "")
if problem_type == "regression": if problem_type == "regression":
logits = output.logits[0].tolist() logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification": elif problem_type == "multi_label_classification":
......
...@@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner, ...@@ -51,6 +51,9 @@ def correctness_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,
......
...@@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner, ...@@ -172,6 +172,9 @@ def mteb_test_embed_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,
...@@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner, ...@@ -284,6 +287,9 @@ def mteb_test_rerank_models(hf_runner,
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = model_info.dtype
if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,
......
...@@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models ...@@ -13,7 +13,14 @@ from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification"), architecture="GemmaForSequenceClassification",
hf_overrides={
"architectures":
["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method":
"no_post_processing",
}),
] ]
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
...@@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder): ...@@ -119,22 +126,9 @@ class GemmaMtebEncoder(VllmMtebEncoder):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
monkeypatch) -> None:
monkeypatch.setenv("VLLM_USE_V1", "0")
assert model_info.architecture == "GemmaForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
}
}
mteb_test_rerank_models(GemmaRerankerHfRunner, mteb_test_rerank_models(GemmaRerankerHfRunner,
vllm_runner, vllm_runner,
model_info, model_info,
vllm_extra_kwargs,
vllm_mteb_encoder=GemmaMtebEncoder) vllm_mteb_encoder=GemmaMtebEncoder)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest import pytest
...@@ -33,12 +32,15 @@ MODELS = [ ...@@ -33,12 +32,15 @@ MODELS = [
########### NewModel ########### NewModel
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
architecture="GteNewModel", architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True), enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
architecture="GteNewModel", architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True), enable_test=True),
CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
architecture="GteNewModel", architecture="GteNewModel",
hf_overrides={"architectures": ["GteNewModel"]},
enable_test=True), enable_test=True),
########### Qwen2ForCausalLM ########### Qwen2ForCausalLM
LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
...@@ -60,11 +62,16 @@ MODELS = [ ...@@ -60,11 +62,16 @@ MODELS = [
] ]
RERANK_MODELS = [ RERANK_MODELS = [
# classifier_pooling: mean
CLSPoolingRerankModelInfo( CLSPoolingRerankModelInfo(
# classifier_pooling: mean
"Alibaba-NLP/gte-reranker-modernbert-base", "Alibaba-NLP/gte-reranker-modernbert-base",
architecture="ModernBertForSequenceClassification", architecture="ModernBertForSequenceClassification",
enable_test=True), enable_test=True),
CLSPoolingRerankModelInfo(
"Alibaba-NLP/gte-multilingual-reranker-base",
architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
enable_test=True),
] ]
...@@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner, ...@@ -75,12 +82,7 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
check_transformers_version(model_info.name, check_transformers_version(model_info.name,
max_transformers_version="4.53.2") max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {} mteb_test_embed_models(hf_runner, vllm_runner, model_info)
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
mteb_test_embed_models(hf_runner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
...@@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner, ...@@ -91,12 +93,8 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
check_transformers_version(model_info.name, check_transformers_version(model_info.name,
max_transformers_version="4.53.2") max_transformers_version="4.53.2")
vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
correctness_test_embed_models(hf_runner, vllm_runner, model_info, correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, vllm_extra_kwargs) example_prompts)
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
......
...@@ -10,12 +10,20 @@ from tests.conftest import HfRunner ...@@ -10,12 +10,20 @@ from tests.conftest import HfRunner
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models from .mteb_utils import mteb_test_rerank_models
mxbai_rerank_hf_overrides = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=True), enable_test=True),
LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
enable_test=False) enable_test=False)
] ]
...@@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner): ...@@ -71,13 +79,4 @@ class MxbaiRerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
vllm_extra_kwargs: dict[str, Any] = {} mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
if model_info.architecture == "Qwen2ForSequenceClassification":
vllm_extra_kwargs["hf_overrides"] = {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
}
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
...@@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test ...@@ -11,12 +11,20 @@ from tests.utils import multi_gpu_test
from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
from .mteb_utils import mteb_test_rerank_models from .mteb_utils import mteb_test_rerank_models
qwen3_reranker_hf_overrides = {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
RERANK_MODELS = [ RERANK_MODELS = [
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=True), enable_test=True),
LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=False) enable_test=False)
] ]
...@@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner): ...@@ -74,18 +82,7 @@ class Qwen3RerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
assert model_info.architecture == "Qwen3ForSequenceClassification" mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
}
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
...@@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner, ...@@ -96,11 +93,6 @@ def test_rerank_models_mteb_tp(vllm_runner,
assert model_info.architecture == "Qwen3ForSequenceClassification" assert model_info.architecture == "Qwen3ForSequenceClassification"
vllm_extra_kwargs: dict[str, Any] = { vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"tensor_parallel_size": 2, "tensor_parallel_size": 2,
} }
......
...@@ -365,6 +365,10 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { ...@@ -365,6 +365,10 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
# [Cross-encoder] # [Cross-encoder]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
trust_remote_code=True,
hf_overrides={
"architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Union from dataclasses import dataclass
from typing import Any, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -339,36 +340,43 @@ def softmax(data): ...@@ -339,36 +340,43 @@ def softmax(data):
return F.softmax(data, dim=-1) return F.softmax(data, dim=-1)
class EmbedModelInfo(NamedTuple): @dataclass
class ModelInfo:
name: str name: str
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = "" architecture: str = ""
dtype: str = "auto" dtype: str = "auto"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = "" default_pooling_type: str = ""
enable_test: bool = True enable_test: bool = True
@dataclass
class EmbedModelInfo(ModelInfo):
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo): class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS" default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo): class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST" default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple): @dataclass
name: str class RerankModelInfo(ModelInfo):
architecture: str = "" pass
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo): class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS" default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo): class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST" default_pooling_type: str = "LAST"
......
...@@ -27,12 +27,15 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -27,12 +27,15 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
from .bert import BertPooler
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
...@@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module): ...@@ -406,9 +409,14 @@ class BertWithRopeEncoder(nn.Module):
class BertWithRope(nn.Module, SupportsQuant): class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
add_pooling_layer: bool = False):
super().__init__() super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.add_pooling_layer = add_pooling_layer
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.embeddings = BertWithRopeEmbedding(self.config) self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder( self.encoder = BertWithRopeEncoder(
...@@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant): ...@@ -416,6 +424,7 @@ class BertWithRope(nn.Module, SupportsQuant):
bias=getattr(self.config, "bias", True), bias=getattr(self.config, "bias", True),
rotary_kwargs=self.config.rotary_kwargs, rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder") prefix=f"{prefix}.encoder")
self.pooler = BertPooler(self.config) if add_pooling_layer else None
def forward( def forward(
self, self,
...@@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant): ...@@ -448,7 +457,7 @@ class BertWithRope(nn.Module, SupportsQuant):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "pooler" in name: if not self.add_pooling_layer and "pooler" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -508,8 +517,8 @@ class GteNewModel(BertWithRope): ...@@ -508,8 +517,8 @@ class GteNewModel(BertWithRope):
"attention.o_proj": "attn.out_proj", "attention.o_proj": "attn.out_proj",
}) })
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# GteNewModel only gate_up_proj does not have bias. # GteNewModel only gate_up_proj does not have bias.
# Hack method learned from vllm/model_executor/models/glm.py # Hack method learned from vllm/model_executor/models/glm.py
...@@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope): ...@@ -614,3 +623,65 @@ class JinaRobertaModel(BertWithRope):
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
weights = self.jina_merge_lora_weights(weights) weights = self.jina_merge_lora_weights(weights)
return super().load_weights(weights) return super().load_weights(weights)
@default_pooling_type("CLS")
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.new = GteNewModel(vllm_config=vllm_config,
prefix=prefix,
add_pooling_layer=True)
self.classifier = RowParallelLinear(config.hidden_size,
config.num_labels,
input_is_parallel=False,
bias=True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "classifier"),
return_bias=False)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
return loaded_params
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.new(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
...@@ -406,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -406,6 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig, "GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig, "GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
"NomicBertModel": NomicBertModelConfig, "NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
......
...@@ -191,12 +191,14 @@ _EMBEDDING_MODELS = { ...@@ -191,12 +191,14 @@ _EMBEDDING_MODELS = {
_CROSS_ENCODER_MODELS = { _CROSS_ENCODER_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"), "BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"GteNewForSequenceClassification": ("bert_with_rope",
"GteNewForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"), "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ("roberta", "XLMRobertaForSequenceClassification": ("roberta",
"RobertaForSequenceClassification"), "RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
# [Auto-converted (see adapters.py)] # [Auto-converted (see adapters.py)]
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment