Unverified Commit a3189a08 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[Model] Consolidate score logic by introducing score_type (#36479)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 409c4e63
...@@ -546,15 +546,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -546,15 +546,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only] # [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"), "BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
"ColBERTModernBertModel": _HfExamplesInfo( "naver/splade-v3",
"lightonai/GTE-ModernColBERT-v1", hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
hf_overrides={"architectures": ["ColBERTModernBertModel"]},
),
"ColBERTJinaRobertaModel": _HfExamplesInfo(
"jinaai/jina-colbert-v2",
trust_remote_code=True,
hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
), ),
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"), "BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
...@@ -568,10 +562,6 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -568,10 +562,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GteNewModel"]}, hf_overrides={"architectures": ["GteNewModel"]},
), ),
"InternLM2ForRewardModel": _HfExamplesInfo(
"internlm/internlm2-1_8b-reward", trust_remote_code=True
),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"LlamaBidirectionalModel": _HfExamplesInfo( "LlamaBidirectionalModel": _HfExamplesInfo(
"nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True "nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
...@@ -584,35 +574,14 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -584,35 +574,14 @@ _EMBEDDING_EXAMPLE_MODELS = {
"nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True "nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True
), ),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo(
"Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",
transformers_version_reason={
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
},
),
"Qwen2ForProcessRewardModel": _HfExamplesInfo(
"Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason={
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
},
),
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
"VoyageQwen3BidirectionalEmbedModel": _HfExamplesInfo( "VoyageQwen3BidirectionalEmbedModel": _HfExamplesInfo(
"voyageai/voyage-4-nano", trust_remote_code=True "voyageai/voyage-4-nano", trust_remote_code=True
), ),
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
"naver/splade-v3",
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
),
# [Multimodal] # [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
"ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged",
),
"LlamaNemotronVLModel": _HfExamplesInfo( "LlamaNemotronVLModel": _HfExamplesInfo(
"nvidia/llama-nemotron-embed-vl-1b-v2", trust_remote_code=True "nvidia/llama-nemotron-embed-vl-1b-v2", trust_remote_code=True
), ),
...@@ -621,15 +590,6 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -621,15 +590,6 @@ _EMBEDDING_EXAMPLE_MODELS = {
"TIGER-Lab/VLM2Vec-Full", trust_remote_code=True "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True
), ),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
"ColQwen3": _HfExamplesInfo(
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
),
"OpsColQwen3Model": _HfExamplesInfo(
"OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True
),
"Qwen3VLNemotronEmbedModel": _HfExamplesInfo(
"nvidia/nemotron-colembed-vl-4b-v2",
),
"SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"), "SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"),
"PrithviGeoSpatialMAE": _HfExamplesInfo( "PrithviGeoSpatialMAE": _HfExamplesInfo(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
...@@ -649,21 +609,74 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -649,21 +609,74 @@ _EMBEDDING_EXAMPLE_MODELS = {
), ),
} }
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { _LATE_INTERACTION_EXAMPLE_MODELS = {
# [Decoder-only] # [Text-only]
"GPT2ForSequenceClassification": _HfExamplesInfo( "HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
"nie3e/sentiment-polish-gpt2-small" "ColBERTModernBertModel": _HfExamplesInfo(
"lightonai/GTE-ModernColBERT-v1",
hf_overrides={"architectures": ["ColBERTModernBertModel"]},
), ),
# [Cross-encoder] "ColBERTJinaRobertaModel": _HfExamplesInfo(
"jinaai/jina-colbert-v2",
trust_remote_code=True,
hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
),
# [Multimodal]
"ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged",
),
"ColQwen3": _HfExamplesInfo(
"TomoroAI/tomoro-colqwen3-embed-4b", trust_remote_code=True
),
"OpsColQwen3Model": _HfExamplesInfo(
"OpenSearch-AI/Ops-Colqwen3-4B", trust_remote_code=True
),
"Qwen3VLNemotronEmbedModel": _HfExamplesInfo(
"nvidia/nemotron-colembed-vl-4b-v2",
),
}
_REWARD_EXAMPLE_MODELS = {
"InternLM2ForRewardModel": _HfExamplesInfo(
"internlm/internlm2-1_8b-reward", trust_remote_code=True
),
"Qwen2ForRewardModel": _HfExamplesInfo(
"Qwen/Qwen2.5-Math-RM-72B",
max_transformers_version="4.53",
transformers_version_reason={
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
},
),
"Qwen2ForProcessRewardModel": _HfExamplesInfo(
"Qwen/Qwen2.5-Math-PRM-7B",
max_transformers_version="4.53",
transformers_version_reason={
"hf": "HF model uses remote code that is not compatible with latest Transformers" # noqa: E501
},
),
}
_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = {
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
"ModernBertForTokenClassification": _HfExamplesInfo(
"disham993/electrical-ner-ModernBERT-base"
),
}
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"BertForSequenceClassification": _HfExamplesInfo( "BertForSequenceClassification": _HfExamplesInfo(
"cross-encoder/ms-marco-MiniLM-L-6-v2" "cross-encoder/ms-marco-MiniLM-L-6-v2"
), ),
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), "GPT2ForSequenceClassification": _HfExamplesInfo(
"nie3e/sentiment-polish-gpt2-small"
),
"GteNewForSequenceClassification": _HfExamplesInfo( "GteNewForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-multilingual-reranker-base", "Alibaba-NLP/gte-multilingual-reranker-base",
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
), ),
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),
"LlamaBidirectionalForSequenceClassification": _HfExamplesInfo( "LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
"nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True "nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True
), ),
...@@ -673,9 +686,6 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { ...@@ -673,9 +686,6 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"ModernBertForSequenceClassification": _HfExamplesInfo( "ModernBertForSequenceClassification": _HfExamplesInfo(
"Alibaba-NLP/gte-reranker-modernbert-base" "Alibaba-NLP/gte-reranker-modernbert-base"
), ),
"ModernBertForTokenClassification": _HfExamplesInfo(
"disham993/electrical-ner-ModernBERT-base"
),
"RobertaForSequenceClassification": _HfExamplesInfo( "RobertaForSequenceClassification": _HfExamplesInfo(
"cross-encoder/quora-roberta-base" "cross-encoder/quora-roberta-base"
), ),
...@@ -1273,6 +1283,9 @@ _TRANSFORMERS_BACKEND_MODELS = { ...@@ -1273,6 +1283,9 @@ _TRANSFORMERS_BACKEND_MODELS = {
_EXAMPLE_MODELS = { _EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS, **_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS,
**_LATE_INTERACTION_EXAMPLE_MODELS,
**_REWARD_EXAMPLE_MODELS,
**_TOKEN_CLASSIFICATION_EXAMPLE_MODELS,
**_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS, **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
......
...@@ -56,21 +56,24 @@ def test_registry_imports(model_arch): ...@@ -56,21 +56,24 @@ def test_registry_imports(model_arch):
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_arch,is_mm,init_cuda,is_ce", "model_arch,is_mm,init_cuda,score_type",
[ [
("LlamaForCausalLM", False, False, False), ("LlamaForCausalLM", False, False, "bi-encoder"),
("LlavaForConditionalGeneration", True, True, False), ("LlavaForConditionalGeneration", True, True, "bi-encoder"),
("BertForSequenceClassification", False, False, True), ("BertForSequenceClassification", False, False, "cross-encoder"),
("RobertaForSequenceClassification", False, False, True), ("RobertaForSequenceClassification", False, False, "cross-encoder"),
("XLMRobertaForSequenceClassification", False, False, True), ("XLMRobertaForSequenceClassification", False, False, "cross-encoder"),
("GteNewModel", False, False, "bi-encoder"),
("GteNewForSequenceClassification", False, False, "cross-encoder"),
("HF_ColBERT", False, False, "late-interaction"),
], ],
) )
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): def test_registry_model_property(model_arch, is_mm, init_cuda, score_type):
model_info = ModelRegistry._try_inspect_model_cls(model_arch) model_info = ModelRegistry._try_inspect_model_cls(model_arch)
assert model_info is not None assert model_info is not None
assert model_info.supports_multimodal is is_mm assert model_info.supports_multimodal is is_mm
assert model_info.supports_cross_encoding is is_ce assert model_info.score_type == score_type
if init_cuda and current_platform.is_cuda_alike(): if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized() assert not torch.cuda.is_initialized()
......
...@@ -20,6 +20,7 @@ from vllm.config.scheduler import RunnerType ...@@ -20,6 +20,7 @@ from vllm.config.scheduler import RunnerType
from vllm.config.utils import config, getattr_iter from vllm.config.utils import config, getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tasks import ScoreType
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat, ConfigFormat,
get_config, get_config,
...@@ -1412,16 +1413,23 @@ class ModelConfig: ...@@ -1412,16 +1413,23 @@ class ModelConfig:
return self._model_info.requires_raw_input_tokens return self._model_info.requires_raw_input_tokens
@property @property
def is_cross_encoder(self) -> bool: def score_type(self) -> ScoreType:
"""
Score API handles score/rerank for:
- "score" task (score_type: cross-encoder models)
- "embed" task (score_type: bi-encoder models)
- "token_embed" task (score_type: late interaction models)
"""
# FIXME: self._model_info.score_type is the score type of the model
# before as_seq_cls_model is applied (i.e. "bi-encoder"), not the
# score type after the conversion (i.e. "cross-encoder").
# Therefore a "classify" conversion is mapped to "cross-encoder" here.
return ( return (
self._model_info.supports_cross_encoding or self.convert_type == "classify" "cross-encoder"
if self.convert_type == "classify"
else self._model_info.score_type
) )
@property
def is_late_interaction(self) -> bool:
"""Check if model uses late interaction (ColBERT-style) scoring."""
return self._model_info.supports_late_interaction
@property @property
def is_pp_supported(self) -> bool: def is_pp_supported(self) -> bool:
return self._model_info.supports_pp return self._model_info.supports_pp
......
...@@ -1584,8 +1584,11 @@ class LLM: ...@@ -1584,8 +1584,11 @@ class LLM:
) )
supported_tasks = self.supported_tasks supported_tasks = self.supported_tasks
score_type = self.model_config.score_type
is_late_interaction = score_type == "late-interaction"
is_cross_encoder = score_type == "cross-encoder"
# Late interaction models (e.g., ColBERT) use token_embed for scoring # Late interaction models (e.g., ColBERT) use token_embed for scoring
is_late_interaction = model_config.is_late_interaction
if not is_late_interaction and all( if not is_late_interaction and all(
t not in supported_tasks for t in ("embed", "classify") t not in supported_tasks for t in ("embed", "classify")
): ):
...@@ -1595,13 +1598,10 @@ class LLM: ...@@ -1595,13 +1598,10 @@ class LLM:
"`--convert embed` or `--convert classify`." "`--convert embed` or `--convert classify`."
) )
if ( if is_cross_encoder and getattr(model_config.hf_config, "num_labels", 0) != 1:
model_config.is_cross_encoder
and getattr(model_config.hf_config, "num_labels", 0) != 1
):
raise ValueError("Score API is only enabled for num_labels == 1.") raise ValueError("Score API is only enabled for num_labels == 1.")
if not model_config.is_cross_encoder and chat_template is not None: if not is_cross_encoder and chat_template is not None:
raise ValueError( raise ValueError(
"chat_template is only supported for cross-encoder models." "chat_template is only supported for cross-encoder models."
) )
...@@ -1622,7 +1622,7 @@ class LLM: ...@@ -1622,7 +1622,7 @@ class LLM:
) )
encode_kwargs = tok_params.get_encode_kwargs() encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder: if is_cross_encoder:
return self._cross_encoding_score( return self._cross_encoding_score(
score_data_1, score_data_1,
score_data_2, score_data_2,
......
...@@ -37,10 +37,10 @@ def register_pooling_api_routers( ...@@ -37,10 +37,10 @@ def register_pooling_api_routers(
app.include_router(embed_router) app.include_router(embed_router)
# Score/rerank endpoints are available for: # Score API handles score/rerank for:
# - "score" task (cross-encoder models) # - "score" task (score_type: cross-encoder models)
# - "embed" task (bi-encoder models) # - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT) # - "token_embed" task (score_type: late interaction models)
if any(t in supported_tasks for t in ("score", "embed", "token_embed")): if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
from vllm.entrypoints.pooling.score.api_router import router as score_router from vllm.entrypoints.pooling.score.api_router import router as score_router
...@@ -101,10 +101,10 @@ def init_pooling_state( ...@@ -101,10 +101,10 @@ def init_pooling_state(
if "classify" in supported_tasks if "classify" in supported_tasks
else None else None
) )
# ServingScores handles score/rerank for: # Score API handles score/rerank for:
# - "score" task (cross-encoder models) # - "score" task (score_type: cross-encoder models)
# - "embed" task (bi-encoder models) # - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT) # - "token_embed" task (score_type: late interaction models)
state.serving_scores = ( state.serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
......
...@@ -69,16 +69,15 @@ class ServingScores(OpenAIServing): ...@@ -69,16 +69,15 @@ class ServingScores(OpenAIServing):
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self.is_cross_encoder = self.model_config.is_cross_encoder self.score_type = self.model_config.score_type
self.is_multimodal_model = self.model_config.is_multimodal_model
self.architecture = self.model_config.architecture self.architecture = self.model_config.architecture
self.is_late_interaction = self.model_config.is_late_interaction self.is_multimodal_model = self.model_config.is_multimodal_model
if self.is_cross_encoder: if self.score_type == "cross-encoder":
self._score_func = self._cross_encoding_score self._score_func = self._cross_encoding_score
elif self.is_late_interaction: elif self.score_type == "late-interaction":
self._score_func = self._late_interaction_score self._score_func = self._late_interaction_score
else: else: # "bi-encoder"
self._score_func = self._embedding_score self._score_func = self._embedding_score
async def _embedding_score( async def _embedding_score(
......
...@@ -30,8 +30,11 @@ from vllm.lora.utils import ( ...@@ -30,8 +30,11 @@ from vllm.lora.utils import (
replace_submodule, replace_submodule,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models import (
from vllm.model_executor.models.interfaces import is_pooling_model SupportsLoRA,
is_pooling_model,
supports_multimodal,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer from vllm.model_executor.models.utils import PPMissingLayer
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
......
...@@ -18,7 +18,6 @@ Reference: https://arxiv.org/abs/2004.12832 ...@@ -18,7 +18,6 @@ Reference: https://arxiv.org/abs/2004.12832
""" """
from collections.abc import Iterable from collections.abc import Iterable
from typing import ClassVar, Literal
import torch import torch
from torch import nn from torch import nn
...@@ -28,16 +27,16 @@ from vllm.model_executor.layers.pooler import Pooler ...@@ -28,16 +27,16 @@ from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from .bert import BertEmbeddingModel, BertModel from .bert import BertEmbeddingModel, BertModel
from .interfaces import SupportsLateInteraction
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
class ColBERTMixin: class ColBERTMixin(nn.Module, SupportsLateInteraction):
"""Mixin that adds ColBERT late interaction support to any embedding model. """Mixin that adds ColBERT late interaction support to any embedding model.
ColBERT (Contextualized Late Interaction over BERT) uses per-token ColBERT (Contextualized Late Interaction over BERT) uses per-token
embeddings with a linear projection layer. This mixin provides: embeddings with a linear projection layer. This mixin provides:
- ``supports_late_interaction`` class-var
- ColBERT linear projection initialisation / lazy creation - ColBERT linear projection initialisation / lazy creation
- Weight loading helpers for the projection layer - Weight loading helpers for the projection layer
- A builder for the token-embedding pooler - A builder for the token-embedding pooler
...@@ -52,8 +51,6 @@ class ColBERTMixin: ...@@ -52,8 +51,6 @@ class ColBERTMixin:
the ColBERT projection weight, then delegate the rest to the backbone. the ColBERT projection weight, then delegate the rest to the backbone.
""" """
supports_late_interaction: ClassVar[Literal[True]] = True
# Set during _init_colbert_components # Set during _init_colbert_components
colbert_dim: int | None colbert_dim: int | None
colbert_linear: nn.Linear | None colbert_linear: nn.Linear | None
......
...@@ -9,7 +9,6 @@ Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged ...@@ -9,7 +9,6 @@ Reference: https://huggingface.co/ModernVBERT/colmodernvbert-merged
""" """
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import ClassVar, Literal
import torch import torch
from torch import nn from torch import nn
...@@ -37,7 +36,11 @@ from vllm.multimodal.processing import ( ...@@ -37,7 +36,11 @@ from vllm.multimodal.processing import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig from vllm.transformers_utils.configs.colmodernvbert import ColModernVBertConfig
from .interfaces import MultiModalEmbeddings, SupportsMultiModal from .interfaces import (
MultiModalEmbeddings,
SupportsLateInteraction,
SupportsMultiModal,
)
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
from .modernbert import ModernBertEmbeddings, ModernBertLayer from .modernbert import ModernBertEmbeddings, ModernBertLayer
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
...@@ -234,7 +237,9 @@ class ColModernVBertMultiModalProcessor( ...@@ -234,7 +237,9 @@ class ColModernVBertMultiModalProcessor(
dummy_inputs=ColModernVBertDummyInputsBuilder, dummy_inputs=ColModernVBertDummyInputsBuilder,
) )
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL") @default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal): class ColModernVBertForRetrieval(
nn.Module, SupportsMultiModal, SupportsLateInteraction
):
"""ColModernVBERT multimodal late-interaction retrieval model. """ColModernVBERT multimodal late-interaction retrieval model.
Architecture: Architecture:
...@@ -248,7 +253,6 @@ class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal): ...@@ -248,7 +253,6 @@ class ColModernVBertForRetrieval(nn.Module, SupportsMultiModal):
""" """
is_pooling_model = True is_pooling_model = True
supports_late_interaction: ClassVar[Literal[True]] = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -20,7 +20,6 @@ Target models: ...@@ -20,7 +20,6 @@ Target models:
""" """
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import ClassVar, Literal
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -31,6 +30,7 @@ from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed ...@@ -31,6 +30,7 @@ from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from .interfaces import SupportsLateInteraction
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
from .qwen2_vl import Qwen2VLMultiModalDataParser from .qwen2_vl import Qwen2VLMultiModalDataParser
from .qwen3_vl import ( from .qwen3_vl import (
...@@ -113,9 +113,7 @@ class ColQwen3ProcessingInfo(Qwen3VLProcessingInfo): ...@@ -113,9 +113,7 @@ class ColQwen3ProcessingInfo(Qwen3VLProcessingInfo):
info=ColQwen3ProcessingInfo, info=ColQwen3ProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder, dummy_inputs=Qwen3VLDummyInputsBuilder,
) )
class ColQwen3Model( class ColQwen3Model(Qwen3VLForConditionalGeneration, SupportsLateInteraction):
Qwen3VLForConditionalGeneration,
):
"""ColQwen3 late interaction model for multi-modal retrieval/reranking. """ColQwen3 late interaction model for multi-modal retrieval/reranking.
This model extends Qwen3VLForConditionalGeneration with a ColBERT-style This model extends Qwen3VLForConditionalGeneration with a ColBERT-style
...@@ -132,16 +130,11 @@ class ColQwen3Model( ...@@ -132,16 +130,11 @@ class ColQwen3Model(
Attributes: Attributes:
custom_text_proj: Linear projection from hidden_size to embed_dim custom_text_proj: Linear projection from hidden_size to embed_dim
supports_late_interaction: Flag indicating this model uses late
interaction scoring
""" """
# Mark this as a pooling model so vLLM routes to pooler path # Mark this as a pooling model so vLLM routes to pooler path
is_pooling_model = True is_pooling_model = True
# Mark this model as supporting late interaction scoring
supports_late_interaction: ClassVar[Literal[True]] = True
# Override hf_to_vllm_mapper to handle ColQwen3 weight naming. # Override hf_to_vllm_mapper to handle ColQwen3 weight naming.
# NOTE: WeightsMapper applies ALL matching prefix rules sequentially # NOTE: WeightsMapper applies ALL matching prefix rules sequentially
# (no early exit), so more-specific prefixes must come first. # (no early exit), so more-specific prefixes must come first.
......
...@@ -34,10 +34,11 @@ from vllm.inputs.data import PromptType ...@@ -34,10 +34,11 @@ from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.tasks import ScoreType
from vllm.utils.collection_utils import common_prefix from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
from .interfaces_base import VllmModel, is_pooling_model from .interfaces_base import VllmModel
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -969,29 +970,7 @@ def supports_mamba_prefix_caching( ...@@ -969,29 +970,7 @@ def supports_mamba_prefix_caching(
class SupportsCrossEncoding(Protocol): class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding.""" """The interface required for all models that support cross encoding."""
supports_cross_encoding: ClassVar[Literal[True]] = True score_type: ClassVar[ScoreType] = "cross-encoder"
@overload
def supports_cross_encoding(
model: type[object],
) -> TypeIs[type[SupportsCrossEncoding]]: ...
@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ...
def _supports_cross_encoding(
model: type[object] | object,
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
return getattr(model, "supports_cross_encoding", False)
def supports_cross_encoding(
model: type[object] | object,
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
return is_pooling_model(model) and _supports_cross_encoding(model)
@runtime_checkable @runtime_checkable
...@@ -1003,29 +982,7 @@ class SupportsLateInteraction(Protocol): ...@@ -1003,29 +982,7 @@ class SupportsLateInteraction(Protocol):
MaxSim (max over document tokens, sum over query tokens). MaxSim (max over document tokens, sum over query tokens).
""" """
supports_late_interaction: ClassVar[Literal[True]] = True score_type: ClassVar[ScoreType] = "late-interaction"
@overload
def supports_late_interaction(
model: type[object],
) -> TypeIs[type[SupportsLateInteraction]]: ...
@overload
def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ...
def _supports_late_interaction(
model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
return getattr(model, "supports_late_interaction", False)
def supports_late_interaction(
model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
return is_pooling_model(model) and _supports_late_interaction(model)
class SupportsQuant: class SupportsQuant:
......
...@@ -15,6 +15,7 @@ import torch.nn as nn ...@@ -15,6 +15,7 @@ import torch.nn as nn
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tasks import ScoreType
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -187,6 +188,26 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): ...@@ -187,6 +188,26 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
decorator to conveniently set this field. decorator to conveniently set this field.
""" """
score_type: ClassVar[ScoreType] = "bi-encoder"
"""
Indicates the
[vllm.config.model.ModelConfig.score_type][]
to use by default.
Score API handles score/rerank for:
- "score" task (score_type: cross-encoder models)
- "embed" task (score_type: bi-encoder models)
- "token_embed" task (score_type: late interaction models)
score_type defaults to bi-encoder, then the Score API uses the "embed" task.
If you set score_type to cross-encoder via
[vllm.model_executor.models.interfaces.SupportsCrossEncoding][],
then the Score API uses the "score" task.
If you set score_type to late-interaction via
[vllm.model_executor.models.interfaces.SupportsLateInteraction][],
then the Score API uses the "token_embed" task.
"""
pooler: Pooler pooler: Pooler
"""The pooler is only called on TP rank 0.""" """The pooler is only called on TP rank 0."""
...@@ -250,3 +271,13 @@ def attn_type(attn_type: AttnTypeStr): ...@@ -250,3 +271,13 @@ def attn_type(attn_type: AttnTypeStr):
def get_attn_type(model: type[object] | object) -> AttnTypeStr: def get_attn_type(model: type[object] | object) -> AttnTypeStr:
return getattr(model, "attn_type", "decoder") return getattr(model, "attn_type", "decoder")
def get_score_type(model: type[object] | object) -> ScoreType:
    """Return the score type declared by a model class or model instance.

    Every class in the model's MRO is inspected for a ``score_type``
    attribute. Values equal to the default ``"bi-encoder"`` are ignored;
    any other value found is the model's score type.

    Args:
        model: A model class, or an instance of one.

    Returns:
        The single non-default score type found in the MRO, or
        ``"bi-encoder"`` when no class declares a different one.

    Raises:
        AssertionError: If the MRO declares more than one distinct
            non-default score type.
    """
    # Instances do not expose ``__mro__``; resolve them to their class so the
    # function honours the ``type[object] | object`` signature.
    model_cls = model if isinstance(model, type) else type(model)

    score_types = {
        score_type
        for klass in model_cls.__mro__
        if (score_type := getattr(klass, "score_type", "bi-encoder"))
        != "bi-encoder"
    }

    # A model may opt into at most one non-default scoring mode.
    assert len(score_types) < 2, (
        f"{model_cls.__name__} declares conflicting score types: "
        f"{sorted(score_types)}"
    )
    return next(iter(score_types), "bi-encoder")
...@@ -30,6 +30,7 @@ from vllm.config import ( ...@@ -30,6 +30,7 @@ from vllm.config import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logging_utils import logtime from vllm.logging_utils import logtime
from vllm.tasks import ScoreType
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
...@@ -48,8 +49,6 @@ from .interfaces import ( ...@@ -48,8 +49,6 @@ from .interfaces import (
is_attention_free, is_attention_free,
is_hybrid, is_hybrid,
requires_raw_input_tokens, requires_raw_input_tokens,
supports_cross_encoding,
supports_late_interaction,
supports_mamba_prefix_caching, supports_mamba_prefix_caching,
supports_multimodal, supports_multimodal,
supports_multimodal_encoder_tp_data, supports_multimodal_encoder_tp_data,
...@@ -61,6 +60,7 @@ from .interfaces_base import ( ...@@ -61,6 +60,7 @@ from .interfaces_base import (
get_attn_type, get_attn_type,
get_default_seq_pooling_type, get_default_seq_pooling_type,
get_default_tok_pooling_type, get_default_tok_pooling_type,
get_score_type,
is_pooling_model, is_pooling_model,
is_text_generation_model, is_text_generation_model,
) )
...@@ -214,19 +214,14 @@ _EMBEDDING_MODELS = { ...@@ -214,19 +214,14 @@ _EMBEDDING_MODELS = {
# [Text-only] # [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"), "BertModel": ("bert", "BertEmbeddingModel"),
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"), "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
"HF_ColBERT": ("colbert", "ColBERTModel"), "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3TextModel": ("gemma3", "Gemma3Model"), "Gemma3TextModel": ("gemma3", "Gemma3Model"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
"GteNewModel": ("bert_with_rope", "GteNewModel"), "GteNewModel": ("bert_with_rope", "GteNewModel"),
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"), "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"), "LlamaModel": ("llama", "LlamaForCausalLM"),
**{ **{
...@@ -241,8 +236,6 @@ _EMBEDDING_MODELS = { ...@@ -241,8 +236,6 @@ _EMBEDDING_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"), "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
...@@ -252,19 +245,14 @@ _EMBEDDING_MODELS = { ...@@ -252,19 +245,14 @@ _EMBEDDING_MODELS = {
"VoyageQwen3BidirectionalEmbedModel", "VoyageQwen3BidirectionalEmbedModel",
), ),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal] # [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"), "CLIPModel": ("clip", "CLIPEmbeddingModel"),
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"LlavaNextForConditionalGeneration": ( "LlavaNextForConditionalGeneration": (
"llava_next", "llava_next",
"LlavaNextForConditionalGeneration", "LlavaNextForConditionalGeneration",
), ),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"ColQwen3": ("colqwen3", "ColQwen3Model"),
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
"SiglipModel": ("siglip", "SiglipEmbeddingModel"), "SiglipModel": ("siglip", "SiglipEmbeddingModel"),
"LlamaNemotronVLModel": ( "LlamaNemotronVLModel": (
"nemotron_vl", "nemotron_vl",
...@@ -277,35 +265,59 @@ _EMBEDDING_MODELS = { ...@@ -277,35 +265,59 @@ _EMBEDDING_MODELS = {
"Terratorch": ("terratorch", "Terratorch"), "Terratorch": ("terratorch", "Terratorch"),
} }
_CROSS_ENCODER_MODELS = { _LATE_INTERACTION_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"), # [Text-only]
"HF_ColBERT": ("colbert", "ColBERTModel"),
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
# [Multimodal]
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"ColQwen3": ("colqwen3", "ColQwen3Model"),
"OpsColQwen3Model": ("colqwen3", "ColQwen3Model"),
"Qwen3VLNemotronEmbedModel": ("colqwen3", "ColQwen3Model"),
}
# Reward-model architectures. Keys are architecture names; each value is a
# 2-tuple that is presumably (model module name, model class name) --
# TODO confirm against the registry loader.
_REWARD_MODELS = {
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
}
_TOKEN_CLASSIFICATION_MODELS = {
"BertForTokenClassification": ("bert", "BertForTokenClassification"), "BertForTokenClassification": ("bert", "BertForTokenClassification"),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
}
_SEQUENCE_CLASSIFICATION_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"GteNewForSequenceClassification": ( "GteNewForSequenceClassification": (
"bert_with_rope", "bert_with_rope",
"GteNewForSequenceClassification", "GteNewForSequenceClassification",
), ),
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
"LlamaBidirectionalForSequenceClassification": ( "LlamaBidirectionalForSequenceClassification": (
"llama", "llama",
"LlamaBidirectionalForSequenceClassification", "LlamaBidirectionalForSequenceClassification",
), ),
"LlamaNemotronVLForSequenceClassification": (
"nemotron_vl",
"LlamaNemotronVLForSequenceClassification",
),
"ModernBertForSequenceClassification": ( "ModernBertForSequenceClassification": (
"modernbert", "modernbert",
"ModernBertForSequenceClassification", "ModernBertForSequenceClassification",
), ),
"ModernBertForTokenClassification": (
"modernbert",
"ModernBertForTokenClassification",
),
"RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
"XLMRobertaForSequenceClassification": ( "XLMRobertaForSequenceClassification": (
"roberta", "roberta",
"RobertaForSequenceClassification", "RobertaForSequenceClassification",
), ),
# [Multimodal]
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
"LlamaNemotronVLForSequenceClassification": (
"nemotron_vl",
"LlamaNemotronVLForSequenceClassification",
),
} }
_MULTIMODAL_MODELS = { _MULTIMODAL_MODELS = {
...@@ -606,7 +618,10 @@ _TRANSFORMERS_BACKEND_MODELS = { ...@@ -606,7 +618,10 @@ _TRANSFORMERS_BACKEND_MODELS = {
_VLLM_MODELS = { _VLLM_MODELS = {
**_TEXT_GENERATION_MODELS, **_TEXT_GENERATION_MODELS,
**_EMBEDDING_MODELS, **_EMBEDDING_MODELS,
**_CROSS_ENCODER_MODELS, **_LATE_INTERACTION_MODELS,
**_REWARD_MODELS,
**_TOKEN_CLASSIFICATION_MODELS,
**_SEQUENCE_CLASSIFICATION_MODELS,
**_MULTIMODAL_MODELS, **_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS, **_SPECULATIVE_DECODING_MODELS,
**_TRANSFORMERS_SUPPORTED_MODELS, **_TRANSFORMERS_SUPPORTED_MODELS,
...@@ -643,8 +658,7 @@ class _ModelInfo: ...@@ -643,8 +658,7 @@ class _ModelInfo:
attn_type: AttnTypeStr attn_type: AttnTypeStr
default_seq_pooling_type: SequencePoolingType default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool score_type: ScoreType
supports_late_interaction: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input_only: bool supports_multimodal_raw_input_only: bool
requires_raw_input_tokens: bool requires_raw_input_tokens: bool
...@@ -667,8 +681,7 @@ class _ModelInfo: ...@@ -667,8 +681,7 @@ class _ModelInfo:
default_seq_pooling_type=get_default_seq_pooling_type(model), default_seq_pooling_type=get_default_seq_pooling_type(model),
default_tok_pooling_type=get_default_tok_pooling_type(model), default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model), attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model), score_type=get_score_type(model),
supports_late_interaction=supports_late_interaction(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
model model
...@@ -1166,14 +1179,6 @@ class _ModelRegistry: ...@@ -1166,14 +1179,6 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures, model_config) model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.is_pooling_model return model_cls.is_pooling_model
def is_cross_encoder_model(
    self,
    architectures: str | list[str],
    model_config: ModelConfig,
) -> bool:
    """Report whether the resolved model supports cross encoding.

    Args:
        architectures: One architecture name or a list of them, forwarded
            to ``inspect_model_cls``.
        model_config: Model configuration, forwarded to
            ``inspect_model_cls``.

    Returns:
        The ``supports_cross_encoding`` flag of the first item returned by
        ``inspect_model_cls``.
    """
    # The second item of the returned pair is not needed here.
    model_cls, _ = self.inspect_model_cls(architectures, model_config)
    return model_cls.supports_cross_encoding
def is_multimodal_model( def is_multimodal_model(
self, self,
architectures: str | list[str], architectures: str | list[str],
......
...@@ -10,6 +10,12 @@ PoolingTask = Literal[ ...@@ -10,6 +10,12 @@ PoolingTask = Literal[
] ]
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask) POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
# Kind of scoring model. The Score API (score/rerank) picks the pooling task
# from the model's score type:
# - "cross-encoder"    -> "score" task
# - "bi-encoder"       -> "embed" task
# - "late-interaction" -> "token_embed" task
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
FrontendTask = Literal["render"] FrontendTask = Literal["render"]
FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask) FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment