Unverified Commit 583a90e0 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Separate sequence and token pooling types (#32026)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 52d42829
......@@ -539,9 +539,12 @@ class ModelConfig:
if getattr(self.pooler_config, k) is None:
setattr(self.pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type
if self.pooler_config.pooling_type is None:
self.pooler_config.pooling_type = default_pooling_type
default_seq_pooling_type = self._model_info.default_seq_pooling_type
if self.pooler_config.seq_pooling_type is None:
self.pooler_config.seq_pooling_type = default_seq_pooling_type
default_tok_pooling_type = self._model_info.default_tok_pooling_type
if self.pooler_config.tok_pooling_type is None:
self.pooler_config.tok_pooling_type = default_tok_pooling_type
self.dtype: torch.dtype = _get_and_verify_dtype(
self.model,
......@@ -1543,8 +1546,8 @@ class ModelConfig:
@property
def attn_type(self) -> AttnTypeStr:
if self.pooler_config is not None:
pooling_type = self._model_info.default_pooling_type.lower()
if pooling_type == "cls":
seq_pooling_type = self._model_info.default_seq_pooling_type
if seq_pooling_type == "CLS":
return "encoder_only"
else:
is_causal = getattr(self.hf_config, "is_causal", True)
......@@ -1561,89 +1564,102 @@ class ModelConfig:
@property
def is_chunked_prefill_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
if pooler_config := self.pooler_config:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not support "
"chunked prefill."
"Pooling models with bidirectional attn "
"do not support chunked prefill."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["mean", "step", "cls"]:
if attn_type == "decoder":
if (
pooler_config.seq_pooling_type in ("MEAN", "CLS")
or pooler_config.tok_pooling_type == "STEP"
):
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"do not support chunked prefill.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return False
elif pooling_type in ["all", "last"]:
else:
logger.debug(
"Pooling models with causal attn and %s pooling support "
"chunked prefill.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"support chunked prefill.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
else:
# for generative models
if attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support chunked prefill.")
logger.debug("Encoder decoder models do not support chunked prefill.")
return False
logger.debug("Generative models support chunked prefill.")
return True
@property
def is_prefix_caching_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
if pooler_config := self.pooler_config:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not "
"support prefix caching."
"Pooling models with bidirectional attn "
"do not support prefix caching."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["mean", "step", "cls"]:
if attn_type == "decoder":
if (
pooler_config.seq_pooling_type in ("MEAN", "CLS")
or pooler_config.tok_pooling_type == "STEP"
):
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"do not support prefix caching.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return False
elif pooling_type in ["all", "last"]:
else:
logger.debug(
"Pooling models with causal attn and %s pooling support "
"prefix caching.",
pooling_type,
"Pooling models with causal attn and %s/%s pooling "
"support prefix caching.",
pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
else:
# for generative models
if attn_type == "hybrid":
logger.debug(
"Hybrid models does not support prefix caching since the feature "
"Hybrid models do not support prefix caching since the feature "
"is still experimental."
)
return False
elif attn_type == "attention_free":
logger.debug(
"Attention free models does not support prefix caching since the "
"Attention free models do not support prefix caching since the "
"feature is still experimental."
)
return False
elif attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support prefix caching.")
logger.debug("Encoder decoder models do not support prefix caching.")
return False
else: # attn_type == "decoder"
logger.debug("Generative models support prefix caching.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
......@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
TokenPoolingType = Literal["ALL", "STEP"]
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
@config
......@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: PoolingTypeStr | None = None
pooling_type: SequencePoolingType | TokenPoolingType | None = None
"""
The pooling method used for pooling.
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
with this field. Alternatively, users can set `seq_pooling_type` and
`tok_pooling_type` explicitly.
This field is mainly for user convenience. Internal code should always use
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
"""
seq_pooling_type: SequencePoolingType | None = None
"""
The pooling method used for sequence pooling.
"""
tok_pooling_type: TokenPoolingType | None = None
"""
The pooling method of the pooling model.
The pooling method used for tokenwise pooling.
"""
## for embeddings models
......@@ -88,9 +109,40 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation
self.use_activation = get_use_activation(self)
def get_pooling_type(self) -> PoolingTypeStr:
assert self.pooling_type is not None, "Should be resolved by ModelConfig"
return self.pooling_type
if pooling_type := self.pooling_type:
if self.seq_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `seq_pooling_type`"
)
if self.tok_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `tok_pooling_type`"
)
if pooling_type in SEQ_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.seq_pooling_type = pooling_type
elif pooling_type in TOK_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.tok_pooling_type = pooling_type
else:
raise NotImplementedError(pooling_type)
def get_seq_pooling_type(self) -> SequencePoolingType:
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
return self.seq_pooling_type
def get_tok_pooling_type(self) -> TokenPoolingType:
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
return self.tok_pooling_type
def compute_hash(self) -> str:
"""
......
......@@ -172,7 +172,7 @@ class LLM:
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision: `{"num_crops": 4}`.
pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
model. e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
......
......@@ -7,7 +7,7 @@ from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
......@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod):
) / prompt_lens.unsqueeze(1)
def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
if pooling_type == "LAST":
return LastPool()
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
if pooling_type == "CLS":
return CLSPool()
if pooling_type == "LAST":
return LastPool()
if pooling_type == "MEAN":
return MeanPool()
......
......@@ -85,7 +85,7 @@ class SequencePooler(Pooler):
def pooler_for_embed(pooler_config: PoolerConfig):
pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = EmbeddingPoolerHead()
return SequencePooler(pooling=pooling, head=head)
......@@ -99,7 +99,7 @@ def pooler_for_classify(
act_fn: PoolerActivation | str | None = None,
):
if pooling is None:
pooling = get_seq_pooling_method(pooler_config.get_pooling_type())
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
......
......@@ -8,7 +8,7 @@ import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import TokenPoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
......@@ -113,12 +113,10 @@ class StepPool(AllPool):
return pooled_data
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
if pooling_type == "ALL":
return AllPool()
if pooling_type == "STEP":
return StepPool()
# TODO: Separate seq and tok pooling types so we don't need this fallback
return AllPool()
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
......@@ -85,7 +85,7 @@ class TokenPooler(Pooler):
def pooler_for_token_embed(pooler_config: PoolerConfig):
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenEmbeddingPoolerHead()
return TokenPooler(pooling=pooling, head=head)
......@@ -99,7 +99,7 @@ def pooler_for_token_classify(
act_fn: PoolerActivation | str | None = None,
):
if pooling is None:
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
......
......@@ -357,7 +357,7 @@ class BertOutput(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
......@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
return loaded_params
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
......@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
return torch.stack(pooled_list, dim=0).contiguous()
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
"""
BertEmbeddingModel + SPLADE sparse embedding.
......@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return loaded
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
......@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
@attn_type("encoder_only")
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class BertForTokenClassification(nn.Module):
is_pooling_model = True
......
......@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
......@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
return super().load_weights(weights)
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
......
......@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
image_width=image_width,
image_height=image_height,
),
_get_vision_feature_select_strategy(pooler_config.pooling_type),
_get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
)
def get_image_size_with_most_features(self) -> ImageSize:
......@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):
# Assume EOS token corresponds to LAST token in text model
@default_pooling_type("LAST")
@default_pooling_type(seq_pooling_type="LAST")
@MULTIMODAL_REGISTRY.register_processor(
CLIPMultiModalProcessor,
info=CLIPProcessingInfo,
......@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) -> torch.Tensor:
if feature_select_strategy is None:
feature_select_strategy = _get_vision_feature_select_strategy(
self.pooler_config.pooling_type
self.pooler_config.seq_pooling_type
)
pooled_output = self.vision_model(
......
......@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType
hf_config = model_config.hf_config
hf_config.is_causal = False
pooling_type_map: dict[str, PoolingTypeStr] = {
pooling_type_map: dict[str, SequencePoolingType] = {
"avg": "MEAN",
"cls": "CLS",
"last": "LAST",
......@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
pooling_type = pooling_type_map.get(hf_config.pooling, None)
if pooling_type is None:
raise ValueError(f"pool_type {hf_config.pooling} not supported")
model_config.pooler_config.pooling_type = pooling_type
raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
model_config.pooler_config.seq_pooling_type = pooling_type
class NomicBertModelConfig(VerifyAndUpdateConfig):
......
......@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
return self.activation(pooled_data)
@default_pooling_type("MEAN")
@default_pooling_type(seq_pooling_type="MEAN")
class GritLM(LlamaForCausalLM):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
......
......@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
from vllm.model_executor.layers.pooler import Pooler
else:
VllmConfig = Any
Pooler = Any
PoolingTypeStr = Any
SequencePoolingType = Any
TokenPoolingType = Any
AttnTypeStr = Any
logger = init_logger(__name__)
......@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
"""
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
"""
Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
to use by default.
You can use the
......@@ -200,18 +211,31 @@ def is_pooling_model(
_T = TypeVar("_T", bound=type[nn.Module])
def default_pooling_type(pooling_type: PoolingTypeStr):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def default_pooling_type(
*,
seq_pooling_type: SequencePoolingType = "LAST",
tok_pooling_type: TokenPoolingType = "ALL",
):
"""Decorator to set `VllmModelForPooling.default_*_pooling_type`."""
def func(model: _T) -> _T:
model.default_pooling_type = pooling_type # type: ignore
model.default_seq_pooling_type = seq_pooling_type # type: ignore
model.default_tok_pooling_type = tok_pooling_type # type: ignore
return model
return func
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
return getattr(model, "default_pooling_type", "LAST")
def get_default_seq_pooling_type(
model: type[object] | object,
) -> SequencePoolingType:
return getattr(model, "default_seq_pooling_type", "LAST")
def get_default_tok_pooling_type(
model: type[object] | object,
) -> TokenPoolingType:
return getattr(model, "default_tok_pooling_type", "ALL")
def attn_type(attn_type: AttnTypeStr):
......
......@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loaded_params
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True
......
......@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):
@support_torch_compile
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."}
......@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
return self.norm(self.act(self.dense(pooled_data)))
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True
......@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):
@attn_type("encoder_only")
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class ModernBertForTokenClassification(nn.Module):
is_pooling_model = True
......
......@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights)
@default_pooling_type("ALL")
@default_pooling_type(tok_pooling_type="ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1
......@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
self.pooler = pooler_for_token_classify(pooler_config)
@default_pooling_type("STEP")
@default_pooling_type(tok_pooling_type="STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2
......
......@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else:
AttnTypeStr = Any
PoolingTypeStr = Any
SequencePoolingType = Any
TokenPoolingType = Any
from .interfaces import (
......@@ -57,7 +58,8 @@ from .interfaces import (
)
from .interfaces_base import (
get_attn_type,
get_default_pooling_type,
get_default_seq_pooling_type,
get_default_tok_pooling_type,
is_pooling_model,
is_text_generation_model,
)
......@@ -548,7 +550,8 @@ class _ModelInfo:
is_text_generation_model: bool
is_pooling_model: bool
attn_type: AttnTypeStr
default_pooling_type: PoolingTypeStr
default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
......@@ -569,7 +572,8 @@ class _ModelInfo:
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model),
default_pooling_type=get_default_pooling_type(model),
default_seq_pooling_type=get_default_seq_pooling_type(model),
default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
......
......@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
return x
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities."""
......@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper)
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
......
......@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
image_width=image_width,
image_height=image_height,
),
_get_vision_feature_select_strategy(pooler_config.pooling_type),
_get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
)
def get_image_size_with_most_features(self) -> ImageSize:
......@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):
# Assume EOS token corresponds to CLS token in text model
@default_pooling_type("CLS")
@default_pooling_type(seq_pooling_type="CLS")
@MULTIMODAL_REGISTRY.register_processor(
SiglipMultiModalProcessor,
info=SiglipProcessingInfo,
......@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) -> torch.Tensor:
if feature_select_strategy is None:
feature_select_strategy = _get_vision_feature_select_strategy(
self.pooler_config.pooling_type
self.pooler_config.seq_pooling_type
)
pooled_output = self.vision_model(
......
......@@ -140,7 +140,7 @@ class PoolingParams(
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
):
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
if pooler_config.pooling_type != "STEP":
if pooler_config.tok_pooling_type != "STEP":
invalid_parameters = []
for k in step_pooling_parameters:
if getattr(self, k, None) is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment