Unverified Commit 583a90e0 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Separate sequence and token pooling types (#32026)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 52d42829
......@@ -46,7 +46,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
assert model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == "CLS"
assert model_config.pooler_config.seq_pooling_type == "CLS"
assert model_config.pooler_config.tok_pooling_type == "ALL"
assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded
......@@ -90,7 +91,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
assert not model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == "MEAN"
assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded
......
......@@ -54,7 +54,7 @@ def test_models(
vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["pooler_config"] = PoolerConfig(
pooling_type="MEAN", normalize=False
seq_pooling_type="MEAN", normalize=False
)
max_model_len: int | None = 512
......
......@@ -88,7 +88,7 @@ def test_gemma_multimodal(
convert="classify",
load_format="auto",
hf_overrides=update_config,
pooler_config=PoolerConfig(pooling_type="LAST"),
pooler_config=PoolerConfig(seq_pooling_type="LAST"),
max_model_len=512,
enforce_eager=True,
tensor_parallel_size=1,
......
......@@ -162,8 +162,11 @@ def mteb_test_embed_models(
assert model_info.architecture in model_config.architectures
# Confirm whether the important configs in model_config are correct.
if model_info.pooling_type is not None:
assert model_config.pooler_config.pooling_type == model_info.pooling_type
pooler_config = model_config.pooler_config
if model_info.seq_pooling_type is not None:
assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
if model_info.tok_pooling_type is not None:
assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
if model_info.attn_type is not None:
assert model_config.attn_type == model_info.attn_type
if model_info.is_prefix_caching_supported is not None:
......
......@@ -254,8 +254,11 @@ def mteb_test_rerank_models(
assert model_config.hf_config.num_labels == 1
# Confirm whether the important configs in model_config are correct.
if model_info.pooling_type is not None:
assert model_config.pooler_config.pooling_type == model_info.pooling_type
pooler_config = model_config.pooler_config
if model_info.seq_pooling_type is not None:
assert pooler_config.seq_pooling_type == model_info.seq_pooling_type
if model_info.tok_pooling_type is not None:
assert pooler_config.tok_pooling_type == model_info.tok_pooling_type
if model_info.attn_type is not None:
assert model_config.attn_type == model_info.attn_type
if model_info.is_prefix_caching_supported is not None:
......
......@@ -17,7 +17,7 @@ MODELS = [
"BAAI/bge-base-en",
architecture="BertModel",
mteb_score=0.779336792,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -54,7 +54,7 @@ MODELS = [
"BAAI/bge-m3",
architecture="XLMRobertaModel",
mteb_score=0.787343078,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -65,7 +65,7 @@ MODELS = [
"BAAI/bge-code-v1",
architecture="Qwen2Model",
mteb_score=0.75724465,
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
......@@ -79,7 +79,7 @@ RERANK_MODELS = [
"BAAI/bge-reranker-base",
architecture="XLMRobertaForSequenceClassification",
mteb_score=0.32398,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -26,7 +26,7 @@ RERANK_MODELS = [
"method": "no_post_processing",
},
mteb_score=0.33757,
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
......
......@@ -12,7 +12,7 @@ RERANK_MODELS = [
RerankModelInfo(
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
architecture="BertForSequenceClassification",
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -21,7 +21,7 @@ RERANK_MODELS = [
RerankModelInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
architecture="Qwen3ForSequenceClassification",
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
......
......@@ -18,7 +18,7 @@ MODELS = [
"thenlper/gte-large",
mteb_score=0.76807651,
architecture="BertModel",
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -44,7 +44,7 @@ MODELS = [
architecture="GteNewModel",
mteb_score=0.775074696,
hf_overrides={"architectures": ["GteNewModel"]},
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -67,7 +67,7 @@ MODELS = [
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
mteb_score=0.758473459018872,
architecture="Qwen2ForCausalLM",
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -78,7 +78,7 @@ MODELS = [
"Alibaba-NLP/gte-modernbert-base",
mteb_score=0.748193353,
architecture="ModernBertModel",
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -89,7 +89,7 @@ MODELS = [
"Qwen/Qwen3-Embedding-0.6B",
mteb_score=0.771163695,
architecture="Qwen3ForCausalLM",
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
......@@ -108,7 +108,7 @@ RERANK_MODELS = [
"Alibaba-NLP/gte-reranker-modernbert-base",
mteb_score=0.33386,
architecture="ModernBertForSequenceClassification",
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -119,7 +119,7 @@ RERANK_MODELS = [
mteb_score=0.33062,
architecture="GteNewForSequenceClassification",
hf_overrides={"architectures": ["GteNewForSequenceClassification"]},
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -13,7 +13,7 @@ MODELS = [
"intfloat/e5-small",
architecture="BertModel",
mteb_score=0.742285423,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -29,7 +29,7 @@ MODELS = [
"intfloat/multilingual-e5-base",
architecture="XLMRobertaModel",
mteb_score=0.779325955,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -24,7 +24,7 @@ EMBEDDING_MODELS = [
mteb_score=0.824413164,
architecture="XLMRobertaModel",
is_matryoshka=True,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -36,7 +36,7 @@ RERANK_MODELS = [
"jinaai/jina-reranker-v2-base-multilingual",
mteb_score=0.33643,
architecture="XLMRobertaForSequenceClassification",
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -24,7 +24,7 @@ RERANK_MODELS = [
"mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
......
......@@ -19,7 +19,7 @@ EMBEDDING_MODELS = [
"nvidia/llama-nemotron-embed-1b-v2",
architecture="LlamaBidirectionalModel",
mteb_score=0.689164662128673,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -32,7 +32,7 @@ RERANK_MODELS = [
architecture="LlamaBidirectionalForSequenceClassification",
chat_template_name="nemotron-rerank.jinja",
mteb_score=0.33994,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -14,7 +14,7 @@ MODELS = [
architecture="NomicBertModel",
mteb_score=0.737568559,
enable_test=True,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -32,7 +32,7 @@ MODELS = [
architecture="NomicBertModel",
mteb_score=0.715488912,
enable_test=True,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -27,7 +27,7 @@ RERANK_MODELS = [
architecture="Qwen3ForSequenceClassification",
hf_overrides=qwen3_reranker_hf_overrides,
chat_template_name="qwen3_reranker.jinja",
pooling_type="LAST",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
......
......@@ -14,7 +14,7 @@ MODELS = [
is_matryoshka=False,
architecture="BertModel",
mteb_score=0.714927797,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -37,7 +37,7 @@ MODELS = [
is_matryoshka=False,
architecture="NomicBertModel",
mteb_score=0.681146831,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -54,7 +54,7 @@ MODELS = [
is_matryoshka=True,
architecture="BertModel",
mteb_score=0.649088363,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -65,7 +65,7 @@ MODELS = [
is_matryoshka=True,
architecture="XLMRobertaModel",
mteb_score=0.712258299,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -76,7 +76,7 @@ MODELS = [
is_matryoshka=True,
architecture="GteModel",
mteb_score=0.706622444,
pooling_type="CLS",
seq_pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -14,7 +14,7 @@ ST_PROJECTOR_MODELS = [
"TencentBAC/Conan-embedding-v1",
architecture="BertModel",
mteb_score=0.688611955,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......@@ -24,7 +24,7 @@ ST_PROJECTOR_MODELS = [
"google/embeddinggemma-300m",
architecture="Gemma3TextModel",
mteb_score=0.7473819294684156,
pooling_type="MEAN",
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
......
......@@ -11,6 +11,7 @@ import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config.model import AttnTypeStr, ModelConfig, ModelDType, RunnerOption
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from vllm.multimodal.processing import InputProcessingContext
from vllm.tokenizers import cached_tokenizer_from_config
......@@ -379,7 +380,8 @@ class ModelInfo:
max_model_len: int | None = None
hf_dtype: str = "float32"
hf_overrides: dict[str, Any] | None = None
pooling_type: str | None = None
seq_pooling_type: SequencePoolingType | None = None
tok_pooling_type: TokenPoolingType | None = None
attn_type: AttnTypeStr | None = None
is_prefix_caching_supported: bool | None = None
is_chunked_prefill_supported: bool | None = None
......
......@@ -161,7 +161,8 @@ def test_get_pooling_config():
assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == "MEAN"
assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
@pytest.mark.skipif(
......@@ -169,7 +170,7 @@ def test_get_pooling_config():
)
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config)
......@@ -180,14 +181,25 @@ def test_get_pooling_config_from_args():
[
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
assert model_config.pooler_config.seq_pooling_type == pooling_type
@pytest.mark.parametrize(
("model_id", "default_pooling_type", "pooling_type"),
[
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
],
)
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_pooling_type == default_pooling_type
assert model_config.pooler_config.pooling_type == pooling_type
assert model_config._model_info.default_tok_pooling_type == default_pooling_type
assert model_config.pooler_config.tok_pooling_type == pooling_type
@pytest.mark.parametrize(
......@@ -554,100 +566,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support chunked prefill.",
"Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
True,
"Pooling models with causal attn and all pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support chunked prefill.",
"Encoder decoder models do not support chunked prefill.", # noqa: E501
),
],
)
......@@ -673,100 +685,100 @@ def test_is_chunked_prefill_supported(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support prefix caching.",
"Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
True,
"Pooling models with causal attn and all pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support prefix caching.",
"Generative models support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support prefix caching.",
"Encoder decoder models do not support prefix caching.", # noqa: E501
),
],
)
......
......@@ -40,7 +40,7 @@ def test_task():
def test_embed():
task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None)
pooling_params.verify(task=task, model_config=model_config)
......@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
......@@ -108,7 +108,7 @@ def test_classify(task):
def test_token_embed(pooling_type: str):
task = "token_embed"
model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
)
pooling_params = PoolingParams(normalize=None)
......@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
def test_token_classify(pooling_type: str):
task = "token_classify"
model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
)
pooling_params = PoolingParams(use_activation=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment