Unverified Commit 583a90e0 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Separate sequence and token pooling types (#32026)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 52d42829
...@@ -46,7 +46,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch): ...@@ -46,7 +46,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
assert model_config.encoder_config["do_lower_case"] assert model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files # asserts on the pooling config files
assert model_config.pooler_config.pooling_type == "CLS" assert model_config.pooler_config.seq_pooling_type == "CLS"
assert model_config.pooler_config.tok_pooling_type == "ALL"
assert model_config.pooler_config.normalize assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded # asserts on the tokenizer loaded
...@@ -90,7 +91,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): ...@@ -90,7 +91,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
assert not model_config.encoder_config["do_lower_case"] assert not model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files # asserts on the pooling config files
assert model_config.pooler_config.pooling_type == "MEAN" assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
assert model_config.pooler_config.normalize assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded # asserts on the tokenizer loaded
......
...@@ -54,7 +54,7 @@ def test_models( ...@@ -54,7 +54,7 @@ def test_models(
vllm_extra_kwargs = {} vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base": if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["pooler_config"] = PoolerConfig( vllm_extra_kwargs["pooler_config"] = PoolerConfig(
pooling_type="MEAN", normalize=False seq_pooling_type="MEAN", normalize=False
) )
max_model_len: int | None = 512 max_model_len: int | None = 512
......
...@@ -88,7 +88,7 @@ def test_gemma_multimodal( ...@@ -88,7 +88,7 @@ def test_gemma_multimodal(
convert="classify", convert="classify",
load_format="auto", load_format="auto",
hf_overrides=update_config, hf_overrides=update_config,
pooler_config=PoolerConfig(pooling_type="LAST"), pooler_config=PoolerConfig(seq_pooling_type="LAST"),
max_model_len=512, max_model_len=512,
enforce_eager=True, enforce_eager=True,
tensor_parallel_size=1, tensor_parallel_size=1,
......
...@@ -162,8 +162,11 @@ def mteb_test_embed_models( ...@@ -162,8 +162,11 @@ def mteb_test_embed_models(
assert model_info.architecture in model_config.architectures assert model_info.architecture in model_config.architectures
# Confirm whether the important configs in model_config are correct. # Confirm whether the important configs in model_config are correct.
if model_info.pooling_type is not None: pooler_config = model_config.pooler_config
assert model_config.pooler_config.pooling_type == model_info.pooling_type if model_info.seq_pooling_type is not None:
assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
if model_info.tok_pooling_type is not None:
assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
if model_info.attn_type is not None: if model_info.attn_type is not None:
assert model_config.attn_type == model_info.attn_type assert model_config.attn_type == model_info.attn_type
if model_info.is_prefix_caching_supported is not None: if model_info.is_prefix_caching_supported is not None:
......
...@@ -254,8 +254,11 @@ def mteb_test_rerank_models( ...@@ -254,8 +254,11 @@ def mteb_test_rerank_models(
assert model_config.hf_config.num_labels == 1 assert model_config.hf_config.num_labels == 1
# Confirm whether the important configs in model_config are correct. # Confirm whether the important configs in model_config are correct.
if model_info.pooling_type is not None: pooler_config = model_config.pooler_config
assert model_config.pooler_config.pooling_type == model_info.pooling_type if model_info.seq_pooling_type is not None:
assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
if model_info.tok_pooling_type is not None:
assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
if model_info.attn_type is not None: if model_info.attn_type is not None:
assert model_config.attn_type == model_info.attn_type assert model_config.attn_type == model_info.attn_type
if model_info.is_prefix_caching_supported is not None: if model_info.is_prefix_caching_supported is not None:
......
...@@ -17,7 +17,7 @@ MODELS = [ ...@@ -17,7 +17,7 @@ MODELS = [
"BAAI/bge-base-en", "BAAI/bge-base-en",
architecture="BertModel", architecture="BertModel",
mteb_score=0.779336792, mteb_score=0.779336792,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -54,7 +54,7 @@ MODELS = [ ...@@ -54,7 +54,7 @@ MODELS = [
"BAAI/bge-m3", "BAAI/bge-m3",
architecture="XLMRobertaModel", architecture="XLMRobertaModel",
mteb_score=0.787343078, mteb_score=0.787343078,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -65,7 +65,7 @@ MODELS = [ ...@@ -65,7 +65,7 @@ MODELS = [
"BAAI/bge-code-v1", "BAAI/bge-code-v1",
architecture="Qwen2Model", architecture="Qwen2Model",
mteb_score=0.75724465, mteb_score=0.75724465,
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
...@@ -79,7 +79,7 @@ RERANK_MODELS = [ ...@@ -79,7 +79,7 @@ RERANK_MODELS = [
"BAAI/bge-reranker-base", "BAAI/bge-reranker-base",
architecture="XLMRobertaForSequenceClassification", architecture="XLMRobertaForSequenceClassification",
mteb_score=0.32398, mteb_score=0.32398,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -26,7 +26,7 @@ RERANK_MODELS = [ ...@@ -26,7 +26,7 @@ RERANK_MODELS = [
"method": "no_post_processing", "method": "no_post_processing",
}, },
mteb_score=0.33757, mteb_score=0.33757,
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
......
...@@ -12,7 +12,7 @@ RERANK_MODELS = [ ...@@ -12,7 +12,7 @@ RERANK_MODELS = [
RerankModelInfo( RerankModelInfo(
"cross-encoder/ms-marco-TinyBERT-L-2-v2", "cross-encoder/ms-marco-TinyBERT-L-2-v2",
architecture="BertForSequenceClassification", architecture="BertForSequenceClassification",
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -21,7 +21,7 @@ RERANK_MODELS = [ ...@@ -21,7 +21,7 @@ RERANK_MODELS = [
RerankModelInfo( RerankModelInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
......
...@@ -18,7 +18,7 @@ MODELS = [ ...@@ -18,7 +18,7 @@ MODELS = [
"thenlper/gte-large", "thenlper/gte-large",
mteb_score=0.76807651, mteb_score=0.76807651,
architecture="BertModel", architecture="BertModel",
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -44,7 +44,7 @@ MODELS = [ ...@@ -44,7 +44,7 @@ MODELS = [
architecture="GteNewModel", architecture="GteNewModel",
mteb_score=0.775074696, mteb_score=0.775074696,
hf_overrides={"architectures": ["GteNewModel"]}, hf_overrides={"architectures": ["GteNewModel"]},
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -67,7 +67,7 @@ MODELS = [ ...@@ -67,7 +67,7 @@ MODELS = [
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
mteb_score=0.758473459018872, mteb_score=0.758473459018872,
architecture="Qwen2ForCausalLM", architecture="Qwen2ForCausalLM",
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -78,7 +78,7 @@ MODELS = [ ...@@ -78,7 +78,7 @@ MODELS = [
"Alibaba-NLP/gte-modernbert-base", "Alibaba-NLP/gte-modernbert-base",
mteb_score=0.748193353, mteb_score=0.748193353,
architecture="ModernBertModel", architecture="ModernBertModel",
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -89,7 +89,7 @@ MODELS = [ ...@@ -89,7 +89,7 @@ MODELS = [
"Qwen/Qwen3-Embedding-0.6B", "Qwen/Qwen3-Embedding-0.6B",
mteb_score=0.771163695, mteb_score=0.771163695,
architecture="Qwen3ForCausalLM", architecture="Qwen3ForCausalLM",
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
...@@ -108,7 +108,7 @@ RERANK_MODELS = [ ...@@ -108,7 +108,7 @@ RERANK_MODELS = [
"Alibaba-NLP/gte-reranker-modernbert-base", "Alibaba-NLP/gte-reranker-modernbert-base",
mteb_score=0.33386, mteb_score=0.33386,
architecture="ModernBertForSequenceClassification", architecture="ModernBertForSequenceClassification",
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -119,7 +119,7 @@ RERANK_MODELS = [ ...@@ -119,7 +119,7 @@ RERANK_MODELS = [
mteb_score=0.33062, mteb_score=0.33062,
architecture="GteNewForSequenceClassification", architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -13,7 +13,7 @@ MODELS = [ ...@@ -13,7 +13,7 @@ MODELS = [
"intfloat/e5-small", "intfloat/e5-small",
architecture="BertModel", architecture="BertModel",
mteb_score=0.742285423, mteb_score=0.742285423,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -29,7 +29,7 @@ MODELS = [ ...@@ -29,7 +29,7 @@ MODELS = [
"intfloat/multilingual-e5-base", "intfloat/multilingual-e5-base",
architecture="XLMRobertaModel", architecture="XLMRobertaModel",
mteb_score=0.779325955, mteb_score=0.779325955,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -24,7 +24,7 @@ EMBEDDING_MODELS = [ ...@@ -24,7 +24,7 @@ EMBEDDING_MODELS = [
mteb_score=0.824413164, mteb_score=0.824413164,
architecture="XLMRobertaModel", architecture="XLMRobertaModel",
is_matryoshka=True, is_matryoshka=True,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -36,7 +36,7 @@ RERANK_MODELS = [ ...@@ -36,7 +36,7 @@ RERANK_MODELS = [
"jinaai/jina-reranker-v2-base-multilingual", "jinaai/jina-reranker-v2-base-multilingual",
mteb_score=0.33643, mteb_score=0.33643,
architecture="XLMRobertaForSequenceClassification", architecture="XLMRobertaForSequenceClassification",
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -24,7 +24,7 @@ RERANK_MODELS = [ ...@@ -24,7 +24,7 @@ RERANK_MODELS = [
"mixedbread-ai/mxbai-rerank-base-v2", "mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides, hf_overrides=mxbai_rerank_hf_overrides,
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
......
...@@ -19,7 +19,7 @@ EMBEDDING_MODELS = [ ...@@ -19,7 +19,7 @@ EMBEDDING_MODELS = [
"nvidia/llama-nemotron-embed-1b-v2", "nvidia/llama-nemotron-embed-1b-v2",
architecture="LlamaBidirectionalModel", architecture="LlamaBidirectionalModel",
mteb_score=0.689164662128673, mteb_score=0.689164662128673,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -32,7 +32,7 @@ RERANK_MODELS = [ ...@@ -32,7 +32,7 @@ RERANK_MODELS = [
architecture="LlamaBidirectionalForSequenceClassification", architecture="LlamaBidirectionalForSequenceClassification",
chat_template_name="nemotron-rerank.jinja", chat_template_name="nemotron-rerank.jinja",
mteb_score=0.33994, mteb_score=0.33994,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -14,7 +14,7 @@ MODELS = [ ...@@ -14,7 +14,7 @@ MODELS = [
architecture="NomicBertModel", architecture="NomicBertModel",
mteb_score=0.737568559, mteb_score=0.737568559,
enable_test=True, enable_test=True,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -32,7 +32,7 @@ MODELS = [ ...@@ -32,7 +32,7 @@ MODELS = [
architecture="NomicBertModel", architecture="NomicBertModel",
mteb_score=0.715488912, mteb_score=0.715488912,
enable_test=True, enable_test=True,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -27,7 +27,7 @@ RERANK_MODELS = [ ...@@ -27,7 +27,7 @@ RERANK_MODELS = [
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides, hf_overrides=qwen3_reranker_hf_overrides,
chat_template_name="qwen3_reranker.jinja", chat_template_name="qwen3_reranker.jinja",
pooling_type="LAST", seq_pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
......
...@@ -14,7 +14,7 @@ MODELS = [ ...@@ -14,7 +14,7 @@ MODELS = [
is_matryoshka=False, is_matryoshka=False,
architecture="BertModel", architecture="BertModel",
mteb_score=0.714927797, mteb_score=0.714927797,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -37,7 +37,7 @@ MODELS = [ ...@@ -37,7 +37,7 @@ MODELS = [
is_matryoshka=False, is_matryoshka=False,
architecture="NomicBertModel", architecture="NomicBertModel",
mteb_score=0.681146831, mteb_score=0.681146831,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -54,7 +54,7 @@ MODELS = [ ...@@ -54,7 +54,7 @@ MODELS = [
is_matryoshka=True, is_matryoshka=True,
architecture="BertModel", architecture="BertModel",
mteb_score=0.649088363, mteb_score=0.649088363,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -65,7 +65,7 @@ MODELS = [ ...@@ -65,7 +65,7 @@ MODELS = [
is_matryoshka=True, is_matryoshka=True,
architecture="XLMRobertaModel", architecture="XLMRobertaModel",
mteb_score=0.712258299, mteb_score=0.712258299,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -76,7 +76,7 @@ MODELS = [ ...@@ -76,7 +76,7 @@ MODELS = [
is_matryoshka=True, is_matryoshka=True,
architecture="GteModel", architecture="GteModel",
mteb_score=0.706622444, mteb_score=0.706622444,
pooling_type="CLS", seq_pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -14,7 +14,7 @@ ST_PROJECTOR_MODELS = [ ...@@ -14,7 +14,7 @@ ST_PROJECTOR_MODELS = [
"TencentBAC/Conan-embedding-v1", "TencentBAC/Conan-embedding-v1",
architecture="BertModel", architecture="BertModel",
mteb_score=0.688611955, mteb_score=0.688611955,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
...@@ -24,7 +24,7 @@ ST_PROJECTOR_MODELS = [ ...@@ -24,7 +24,7 @@ ST_PROJECTOR_MODELS = [
"google/embeddinggemma-300m", "google/embeddinggemma-300m",
architecture="Gemma3TextModel", architecture="Gemma3TextModel",
mteb_score=0.7473819294684156, mteb_score=0.7473819294684156,
pooling_type="MEAN", seq_pooling_type="MEAN",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
......
...@@ -11,6 +11,7 @@ import torch.nn.functional as F ...@@ -11,6 +11,7 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config.model import AttnTypeStr, ModelConfig, ModelDType, RunnerOption from vllm.config.model import AttnTypeStr, ModelConfig, ModelDType, RunnerOption
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from vllm.multimodal.processing import InputProcessingContext from vllm.multimodal.processing import InputProcessingContext
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
...@@ -379,7 +380,8 @@ class ModelInfo: ...@@ -379,7 +380,8 @@ class ModelInfo:
max_model_len: int | None = None max_model_len: int | None = None
hf_dtype: str = "float32" hf_dtype: str = "float32"
hf_overrides: dict[str, Any] | None = None hf_overrides: dict[str, Any] | None = None
pooling_type: str | None = None seq_pooling_type: SequencePoolingType | None = None
tok_pooling_type: TokenPoolingType | None = None
attn_type: AttnTypeStr | None = None attn_type: AttnTypeStr | None = None
is_prefix_caching_supported: bool | None = None is_prefix_caching_supported: bool | None = None
is_chunked_prefill_supported: bool | None = None is_chunked_prefill_supported: bool | None = None
......
...@@ -161,7 +161,8 @@ def test_get_pooling_config(): ...@@ -161,7 +161,8 @@ def test_get_pooling_config():
assert model_config.pooler_config is not None assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == "MEAN" assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -169,7 +170,7 @@ def test_get_pooling_config(): ...@@ -169,7 +170,7 @@ def test_get_pooling_config():
) )
def test_get_pooling_config_from_args(): def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2" model_id = "sentence-transformers/all-MiniLM-L12-v2"
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True) pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config) model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config) assert asdict(model_config.pooler_config) == asdict(pooler_config)
...@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args(): ...@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args():
[ [
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
("intfloat/e5-small", "CLS", "MEAN"), # BertModel ("intfloat/e5-small", "CLS", "MEAN"), # BertModel
],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
assert model_config.pooler_config.seq_pooling_type == pooling_type
@pytest.mark.parametrize(
("model_id", "default_pooling_type", "pooling_type"),
[
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
], ],
) )
def test_default_pooling_type(model_id, default_pooling_type, pooling_type): def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
assert model_config._model_info.default_pooling_type == default_pooling_type assert model_config._model_info.default_tok_pooling_type == default_pooling_type
assert model_config.pooler_config.pooling_type == pooling_type assert model_config.pooler_config.tok_pooling_type == pooling_type
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): ...@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
"jason9693/Qwen2.5-1.5B-apeach", "jason9693/Qwen2.5-1.5B-apeach",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Embedding-0.6B", "Qwen/Qwen3-Embedding-0.6B",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"Qwen/Qwen2.5-Math-PRM-7B", "Qwen/Qwen2.5-Math-PRM-7B",
"decoder", "decoder",
False, False,
"Pooling models with step pooling does not support chunked prefill.", "Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
), ),
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and all pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"boltuix/NeuroBERT-NER", "boltuix/NeuroBERT-NER",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"papluca/xlm-roberta-base-language-detection", "papluca/xlm-roberta-base-language-detection",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"intfloat/e5-small", "intfloat/e5-small",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
# multimodal models # multimodal models
( (
"openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"google/siglip-base-patch16-224", "google/siglip-base-patch16-224",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
# generate models # generate models
( (
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"decoder", "decoder",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid", "hybrid",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
( (
"ibm-granite/granite-4.0-h-small", "ibm-granite/granite-4.0-h-small",
"hybrid", "hybrid",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
( (
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"attention_free", "attention_free",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
# encoder_decoder models # encoder_decoder models
( (
"openai/whisper-small", "openai/whisper-small",
"encoder_decoder", "encoder_decoder",
False, False,
"Encoder decoder models does not support chunked prefill.", "Encoder decoder models do not support chunked prefill.", # noqa: E501
), ),
], ],
) )
...@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported( ...@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported(
"jason9693/Qwen2.5-1.5B-apeach", "jason9693/Qwen2.5-1.5B-apeach",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Embedding-0.6B", "Qwen/Qwen3-Embedding-0.6B",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"Qwen/Qwen2.5-Math-PRM-7B", "Qwen/Qwen2.5-Math-PRM-7B",
"decoder", "decoder",
False, False,
"Pooling models with step pooling does not support prefix caching.", "Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
), ),
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and all pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"boltuix/NeuroBERT-NER", "boltuix/NeuroBERT-NER",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"papluca/xlm-roberta-base-language-detection", "papluca/xlm-roberta-base-language-detection",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"intfloat/e5-small", "intfloat/e5-small",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
# multimodal models # multimodal models
( (
"openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"google/siglip-base-patch16-224", "google/siglip-base-patch16-224",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
# generate models # generate models
( (
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"decoder", "decoder",
True, True,
"Generative models support prefix caching.", "Generative models support prefix caching.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid", "hybrid",
False, False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501 "Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
), ),
( (
"ibm-granite/granite-4.0-h-small", "ibm-granite/granite-4.0-h-small",
"hybrid", "hybrid",
False, False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501 "Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
), ),
( (
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"attention_free", "attention_free",
False, False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501 "Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
), ),
# encoder_decoder models # encoder_decoder models
( (
"openai/whisper-small", "openai/whisper-small",
"encoder_decoder", "encoder_decoder",
False, False,
"Encoder decoder models does not support prefix caching.", "Encoder decoder models do not support prefix caching.", # noqa: E501
), ),
], ],
) )
......
...@@ -40,7 +40,7 @@ def test_task(): ...@@ -40,7 +40,7 @@ def test_task():
def test_embed(): def test_embed():
task = "embed" task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(normalize=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo): ...@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@pytest.mark.parametrize("task", ["score", "classify"]) @pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task): def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -108,7 +108,7 @@ def test_classify(task): ...@@ -108,7 +108,7 @@ def test_classify(task):
def test_token_embed(pooling_type: str): def test_token_embed(pooling_type: str):
task = "token_embed" task = "token_embed"
model_config = MockModelConfig( model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(normalize=None)
...@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str): ...@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
def test_token_classify(pooling_type: str): def test_token_classify(pooling_type: str):
task = "token_classify" task = "token_classify"
model_config = MockModelConfig( model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(use_activation=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment