Unverified Commit 583a90e0 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Separate sequence and token pooling types (#32026)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 52d42829
...@@ -539,9 +539,12 @@ class ModelConfig: ...@@ -539,9 +539,12 @@ class ModelConfig:
if getattr(self.pooler_config, k) is None: if getattr(self.pooler_config, k) is None:
setattr(self.pooler_config, k, v) setattr(self.pooler_config, k, v)
default_pooling_type = self._model_info.default_pooling_type default_seq_pooling_type = self._model_info.default_seq_pooling_type
if self.pooler_config.pooling_type is None: if self.pooler_config.seq_pooling_type is None:
self.pooler_config.pooling_type = default_pooling_type self.pooler_config.seq_pooling_type = default_seq_pooling_type
default_tok_pooling_type = self._model_info.default_tok_pooling_type
if self.pooler_config.tok_pooling_type is None:
self.pooler_config.tok_pooling_type = default_tok_pooling_type
self.dtype: torch.dtype = _get_and_verify_dtype( self.dtype: torch.dtype = _get_and_verify_dtype(
self.model, self.model,
...@@ -1543,8 +1546,8 @@ class ModelConfig: ...@@ -1543,8 +1546,8 @@ class ModelConfig:
@property @property
def attn_type(self) -> AttnTypeStr: def attn_type(self) -> AttnTypeStr:
if self.pooler_config is not None: if self.pooler_config is not None:
pooling_type = self._model_info.default_pooling_type.lower() seq_pooling_type = self._model_info.default_seq_pooling_type
if pooling_type == "cls": if seq_pooling_type == "CLS":
return "encoder_only" return "encoder_only"
else: else:
is_causal = getattr(self.hf_config, "is_causal", True) is_causal = getattr(self.hf_config, "is_causal", True)
...@@ -1561,89 +1564,102 @@ class ModelConfig: ...@@ -1561,89 +1564,102 @@ class ModelConfig:
@property @property
def is_chunked_prefill_supported(self) -> bool: def is_chunked_prefill_supported(self) -> bool:
attn_type = self.attn_type attn_type = self.attn_type
if self.pooler_config is not None:
if pooler_config := self.pooler_config:
# for pooling models # for pooling models
if attn_type == "encoder_only": if attn_type == "encoder_only":
logger.debug( logger.debug(
"Pooling models with bidirectional attn does not support " "Pooling models with bidirectional attn "
"chunked prefill." "do not support chunked prefill."
) )
return False return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower() if attn_type == "decoder":
if pooling_type in ["mean", "step", "cls"]: if (
pooler_config.seq_pooling_type in ("MEAN", "CLS")
or pooler_config.tok_pooling_type == "STEP"
):
logger.debug( logger.debug(
"Pooling models with %s pooling does not " "Pooling models with causal attn and %s/%s pooling "
"support chunked prefill.", "do not support chunked prefill.",
pooling_type, pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
) )
return False return False
elif pooling_type in ["all", "last"]: else:
logger.debug( logger.debug(
"Pooling models with causal attn and %s pooling support " "Pooling models with causal attn and %s/%s pooling "
"chunked prefill.", "support chunked prefill.",
pooling_type, pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
) )
return True return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid, # vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types. # attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder" return attn_type != "encoder_decoder"
else: else:
# for generative models
if attn_type == "encoder_decoder": if attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support chunked prefill.") logger.debug("Encoder decoder models do not support chunked prefill.")
return False return False
logger.debug("Generative models support chunked prefill.") logger.debug("Generative models support chunked prefill.")
return True return True
@property @property
def is_prefix_caching_supported(self) -> bool: def is_prefix_caching_supported(self) -> bool:
attn_type = self.attn_type attn_type = self.attn_type
if self.pooler_config is not None:
if pooler_config := self.pooler_config:
# for pooling models # for pooling models
if attn_type == "encoder_only": if attn_type == "encoder_only":
logger.debug( logger.debug(
"Pooling models with bidirectional attn does not " "Pooling models with bidirectional attn "
"support prefix caching." "do not support prefix caching."
) )
return False return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower() if attn_type == "decoder":
if pooling_type in ["mean", "step", "cls"]: if (
pooler_config.seq_pooling_type in ("MEAN", "CLS")
or pooler_config.tok_pooling_type == "STEP"
):
logger.debug( logger.debug(
"Pooling models with %s pooling does not " "Pooling models with causal attn and %s/%s pooling "
"support prefix caching.", "do not support prefix caching.",
pooling_type, pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
) )
return False return False
elif pooling_type in ["all", "last"]: else:
logger.debug( logger.debug(
"Pooling models with causal attn and %s pooling support " "Pooling models with causal attn and %s/%s pooling "
"prefix caching.", "support prefix caching.",
pooling_type, pooler_config.seq_pooling_type,
pooler_config.tok_pooling_type,
) )
return True return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid, # vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types. # attention_free or encoder_decoder attn types.
return False return False
else: else:
# for generative models
if attn_type == "hybrid": if attn_type == "hybrid":
logger.debug( logger.debug(
"Hybrid models does not support prefix caching since the feature " "Hybrid models do not support prefix caching since the feature "
"is still experimental." "is still experimental."
) )
return False return False
elif attn_type == "attention_free": elif attn_type == "attention_free":
logger.debug( logger.debug(
"Attention free models does not support prefix caching since the " "Attention free models do not support prefix caching since the "
"feature is still experimental." "feature is still experimental."
) )
return False return False
elif attn_type == "encoder_decoder": elif attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support prefix caching.") logger.debug("Encoder decoder models do not support prefix caching.")
return False return False
else: # attn_type == "decoder" else: # attn_type == "decoder"
logger.debug("Generative models support prefix caching.") logger.debug("Generative models support prefix caching.")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
...@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash ...@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash
logger = init_logger(__name__) logger = init_logger(__name__)
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"] SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
TokenPoolingType = Literal["ALL", "STEP"]
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
@config @config
...@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"] ...@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
class PoolerConfig: class PoolerConfig:
"""Controls the behavior of output pooling in pooling models.""" """Controls the behavior of output pooling in pooling models."""
pooling_type: PoolingTypeStr | None = None pooling_type: SequencePoolingType | TokenPoolingType | None = None
"""
The pooling method used for pooling.
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
with this field. Alternatively, users can set `seq_pooling_type` and
`tok_pooling_type` explicitly.
This field is mainly for user convenience. Internal code should always use
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
"""
seq_pooling_type: SequencePoolingType | None = None
"""
The pooling method used for sequence pooling.
"""
tok_pooling_type: TokenPoolingType | None = None
""" """
The pooling method of the pooling model. The pooling method used for tokenwise pooling.
""" """
## for embeddings models ## for embeddings models
...@@ -88,9 +109,40 @@ class PoolerConfig: ...@@ -88,9 +109,40 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation # raise deprecated warning for softmax and activation
self.use_activation = get_use_activation(self) self.use_activation = get_use_activation(self)
def get_pooling_type(self) -> PoolingTypeStr: if pooling_type := self.pooling_type:
assert self.pooling_type is not None, "Should be resolved by ModelConfig" if self.seq_pooling_type is not None:
return self.pooling_type raise ValueError(
"Cannot set both `pooling_type` and `seq_pooling_type`"
)
if self.tok_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `tok_pooling_type`"
)
if pooling_type in SEQ_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.seq_pooling_type = pooling_type
elif pooling_type in TOK_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.tok_pooling_type = pooling_type
else:
raise NotImplementedError(pooling_type)
def get_seq_pooling_type(self) -> SequencePoolingType:
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
return self.seq_pooling_type
def get_tok_pooling_type(self) -> TokenPoolingType:
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
return self.tok_pooling_type
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
......
...@@ -172,7 +172,7 @@ class LLM: ...@@ -172,7 +172,7 @@ class LLM:
The available overrides depend on the model that is being run. The available overrides depend on the model that is being run.
For example, for Phi-3-Vision: `{"num_crops": 4}`. For example, for Phi-3-Vision: `{"num_crops": 4}`.
pooler_config: Initialize non-default pooling config for the pooling pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. model. e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
compilation_config: Either an integer or a dictionary. If it is an compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration. is a dictionary, it can specify the full compilation configuration.
......
...@@ -7,7 +7,7 @@ from typing import TypeAlias ...@@ -7,7 +7,7 @@ from typing import TypeAlias
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config.pooler import PoolingTypeStr from vllm.config.pooler import SequencePoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
...@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod): ...@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod):
) / prompt_lens.unsqueeze(1) ) / prompt_lens.unsqueeze(1)
def get_seq_pooling_method(pooling_type: PoolingTypeStr | str): def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
if pooling_type == "LAST":
return LastPool()
if pooling_type == "CLS": if pooling_type == "CLS":
return CLSPool() return CLSPool()
if pooling_type == "LAST":
return LastPool()
if pooling_type == "MEAN": if pooling_type == "MEAN":
return MeanPool() return MeanPool()
......
...@@ -85,7 +85,7 @@ class SequencePooler(Pooler): ...@@ -85,7 +85,7 @@ class SequencePooler(Pooler):
def pooler_for_embed(pooler_config: PoolerConfig): def pooler_for_embed(pooler_config: PoolerConfig):
pooling = get_seq_pooling_method(pooler_config.get_pooling_type()) pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = EmbeddingPoolerHead() head = EmbeddingPoolerHead()
return SequencePooler(pooling=pooling, head=head) return SequencePooler(pooling=pooling, head=head)
...@@ -99,7 +99,7 @@ def pooler_for_classify( ...@@ -99,7 +99,7 @@ def pooler_for_classify(
act_fn: PoolerActivation | str | None = None, act_fn: PoolerActivation | str | None = None,
): ):
if pooling is None: if pooling is None:
pooling = get_seq_pooling_method(pooler_config.get_pooling_type()) pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn) head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.config.pooler import PoolingTypeStr from vllm.config.pooler import TokenPoolingType
from vllm.model_executor.layers.pooler import PoolingParamsUpdate from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
...@@ -113,12 +113,10 @@ class StepPool(AllPool): ...@@ -113,12 +113,10 @@ class StepPool(AllPool):
return pooled_data return pooled_data
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str): def get_tok_pooling_method(pooling_type: TokenPoolingType | str):
if pooling_type == "ALL": if pooling_type == "ALL":
return AllPool() return AllPool()
if pooling_type == "STEP": if pooling_type == "STEP":
return StepPool() return StepPool()
# TODO: Separate seq and tok pooling types so we don't need this fallback
return AllPool()
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}") raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
...@@ -85,7 +85,7 @@ class TokenPooler(Pooler): ...@@ -85,7 +85,7 @@ class TokenPooler(Pooler):
def pooler_for_token_embed(pooler_config: PoolerConfig): def pooler_for_token_embed(pooler_config: PoolerConfig):
pooling = get_tok_pooling_method(pooler_config.get_pooling_type()) pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenEmbeddingPoolerHead() head = TokenEmbeddingPoolerHead()
return TokenPooler(pooling=pooling, head=head) return TokenPooler(pooling=pooling, head=head)
...@@ -99,7 +99,7 @@ def pooler_for_token_classify( ...@@ -99,7 +99,7 @@ def pooler_for_token_classify(
act_fn: PoolerActivation | str | None = None, act_fn: PoolerActivation | str | None = None,
): ):
if pooling is None: if pooling is None:
pooling = get_tok_pooling_method(pooler_config.get_pooling_type()) pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn) head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
......
...@@ -357,7 +357,7 @@ class BertOutput(nn.Module): ...@@ -357,7 +357,7 @@ class BertOutput(nn.Module):
@support_torch_compile @support_torch_compile
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class BertModel(nn.Module, SupportsQuant): class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True is_pooling_model = True
...@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel): ...@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
return loaded_params return loaded_params
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class BertEmbeddingModel(nn.Module, SupportsQuant): class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
...@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler): ...@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
return torch.stack(pooled_list, dim=0).contiguous() return torch.stack(pooled_list, dim=0).contiguous()
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
""" """
BertEmbeddingModel + SPLADE sparse embedding. BertEmbeddingModel + SPLADE sparse embedding.
...@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): ...@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return loaded return loaded
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant): class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities. """A model that uses Bert to provide embedding functionalities.
...@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu ...@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
@attn_type("encoder_only") @attn_type("encoder_only")
@default_pooling_type("ALL") @default_pooling_type(tok_pooling_type="ALL")
class BertForTokenClassification(nn.Module): class BertForTokenClassification(nn.Module):
is_pooling_model = True is_pooling_model = True
......
...@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module): ...@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):
@support_torch_compile @support_torch_compile
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class BertWithRope(nn.Module, SupportsQuant): class BertWithRope(nn.Module, SupportsQuant):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
...@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope): ...@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
return super().load_weights(weights) return super().load_weights(weights)
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True is_pooling_model = True
......
...@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo): ...@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
), ),
_get_vision_feature_select_strategy(pooler_config.pooling_type), _get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
) )
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
...@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module): ...@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):
# Assume EOS token corresponds to LAST token in text model # Assume EOS token corresponds to LAST token in text model
@default_pooling_type("LAST") @default_pooling_type(seq_pooling_type="LAST")
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
CLIPMultiModalProcessor, CLIPMultiModalProcessor,
info=CLIPProcessingInfo, info=CLIPProcessingInfo,
...@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) -> torch.Tensor: ) -> torch.Tensor:
if feature_select_strategy is None: if feature_select_strategy is None:
feature_select_strategy = _get_vision_feature_select_strategy( feature_select_strategy = _get_vision_feature_select_strategy(
self.pooler_config.pooling_type self.pooler_config.seq_pooling_type
) )
pooled_output = self.vision_model( pooled_output = self.vision_model(
......
...@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig): ...@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
class LlamaBidirectionalConfig(VerifyAndUpdateConfig): class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None: def verify_and_update_model_config(model_config: "ModelConfig") -> None:
from vllm.config.pooler import PoolingTypeStr from vllm.config.pooler import SequencePoolingType
hf_config = model_config.hf_config hf_config = model_config.hf_config
hf_config.is_causal = False hf_config.is_causal = False
pooling_type_map: dict[str, PoolingTypeStr] = { pooling_type_map: dict[str, SequencePoolingType] = {
"avg": "MEAN", "avg": "MEAN",
"cls": "CLS", "cls": "CLS",
"last": "LAST", "last": "LAST",
...@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig): ...@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
pooling_type = pooling_type_map.get(hf_config.pooling, None) pooling_type = pooling_type_map.get(hf_config.pooling, None)
if pooling_type is None: if pooling_type is None:
raise ValueError(f"pool_type {hf_config.pooling} not supported") raise ValueError(f"pool_type {hf_config.pooling!r} not supported")
model_config.pooler_config.pooling_type = pooling_type
model_config.pooler_config.seq_pooling_type = pooling_type
class NomicBertModelConfig(VerifyAndUpdateConfig): class NomicBertModelConfig(VerifyAndUpdateConfig):
......
...@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler): ...@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
return self.activation(pooled_data) return self.activation(pooled_data)
@default_pooling_type("MEAN") @default_pooling_type(seq_pooling_type="MEAN")
class GritLM(LlamaForCausalLM): class GritLM(LlamaForCausalLM):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm. """This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
......
...@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw ...@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.model import AttnTypeStr from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr from vllm.config.pooler import SequencePoolingType, TokenPoolingType
from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.layers.pooler import Pooler
else: else:
VllmConfig = Any VllmConfig = Any
Pooler = Any Pooler = Any
PoolingTypeStr = Any SequencePoolingType = Any
TokenPoolingType = Any
AttnTypeStr = Any AttnTypeStr = Any
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): ...@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class. MRO of your model class.
""" """
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST" default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
""" """
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][] Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
"""
Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
to use by default. to use by default.
You can use the You can use the
...@@ -200,18 +211,31 @@ def is_pooling_model( ...@@ -200,18 +211,31 @@ def is_pooling_model(
_T = TypeVar("_T", bound=type[nn.Module]) _T = TypeVar("_T", bound=type[nn.Module])
def default_pooling_type(pooling_type: PoolingTypeStr): def default_pooling_type(
"""Decorator to set `VllmModelForPooling.default_pooling_type`.""" *,
seq_pooling_type: SequencePoolingType = "LAST",
tok_pooling_type: TokenPoolingType = "ALL",
):
"""Decorator to set `VllmModelForPooling.default_*_pooling_type`."""
def func(model: _T) -> _T: def func(model: _T) -> _T:
model.default_pooling_type = pooling_type # type: ignore model.default_seq_pooling_type = seq_pooling_type # type: ignore
model.default_tok_pooling_type = tok_pooling_type # type: ignore
return model return model
return func return func
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr: def get_default_seq_pooling_type(
return getattr(model, "default_pooling_type", "LAST") model: type[object] | object,
) -> SequencePoolingType:
return getattr(model, "default_seq_pooling_type", "LAST")
def get_default_tok_pooling_type(
model: type[object] | object,
) -> TokenPoolingType:
return getattr(model, "default_tok_pooling_type", "ALL")
def attn_type(attn_type: AttnTypeStr): def attn_type(attn_type: AttnTypeStr):
......
...@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loaded_params return loaded_params
@default_pooling_type("ALL") @default_pooling_type(tok_pooling_type="ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM): class InternLM2ForRewardModel(InternLM2ForCausalLM):
is_pooling_model = True is_pooling_model = True
......
...@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module): ...@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class ModernBertModel(nn.Module): class ModernBertModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"layers.": "encoder_layer.layers."} orig_to_new_prefix={"layers.": "encoder_layer.layers."}
...@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler): ...@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
return self.norm(self.act(self.dense(pooled_data))) return self.norm(self.act(self.dense(pooled_data)))
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
is_pooling_model = True is_pooling_model = True
...@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module): ...@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):
@attn_type("encoder_only") @attn_type("encoder_only")
@default_pooling_type("ALL") @default_pooling_type(tok_pooling_type="ALL")
class ModernBertForTokenClassification(nn.Module): class ModernBertForTokenClassification(nn.Module):
is_pooling_model = True is_pooling_model = True
......
...@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights) return loader.load_weights(weights)
@default_pooling_type("ALL") @default_pooling_type(tok_pooling_type="ALL")
class Qwen2ForRewardModel(Qwen2RewardBaseModel): class Qwen2ForRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 1 vllm_config.model_config.hf_config.num_labels = 1
...@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): ...@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
self.pooler = pooler_for_token_classify(pooler_config) self.pooler = pooler_for_token_classify(pooler_config)
@default_pooling_type("STEP") @default_pooling_type(tok_pooling_type="STEP")
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.model_config.hf_config.num_labels = 2 vllm_config.model_config.hf_config.num_labels = 2
......
...@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash ...@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config.model import AttnTypeStr from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr from vllm.config.pooler import SequencePoolingType, TokenPoolingType
else: else:
AttnTypeStr = Any AttnTypeStr = Any
PoolingTypeStr = Any SequencePoolingType = Any
TokenPoolingType = Any
from .interfaces import ( from .interfaces import (
...@@ -57,7 +58,8 @@ from .interfaces import ( ...@@ -57,7 +58,8 @@ from .interfaces import (
) )
from .interfaces_base import ( from .interfaces_base import (
get_attn_type, get_attn_type,
get_default_pooling_type, get_default_seq_pooling_type,
get_default_tok_pooling_type,
is_pooling_model, is_pooling_model,
is_text_generation_model, is_text_generation_model,
) )
...@@ -548,7 +550,8 @@ class _ModelInfo: ...@@ -548,7 +550,8 @@ class _ModelInfo:
is_text_generation_model: bool is_text_generation_model: bool
is_pooling_model: bool is_pooling_model: bool
attn_type: AttnTypeStr attn_type: AttnTypeStr
default_pooling_type: PoolingTypeStr default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool supports_cross_encoding: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input_only: bool supports_multimodal_raw_input_only: bool
...@@ -569,7 +572,8 @@ class _ModelInfo: ...@@ -569,7 +572,8 @@ class _ModelInfo:
architecture=model.__name__, architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model), is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model), is_pooling_model=is_pooling_model(model),
default_pooling_type=get_default_pooling_type(model), default_seq_pooling_type=get_default_seq_pooling_type(model),
default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model), attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
......
...@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module): ...@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
return x return x
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class RobertaEmbeddingModel(BertEmbeddingModel): class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.""" """A model that uses Roberta to provide embedding functionalities."""
...@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper) return loader.load_weights(weights_list, mapper=mapper)
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities. """A model that uses Roberta to provide embedding functionalities.
......
...@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo): ...@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
), ),
_get_vision_feature_select_strategy(pooler_config.pooling_type), _get_vision_feature_select_strategy(pooler_config.seq_pooling_type),
) )
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
...@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module): ...@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):
# Assume EOS token corresponds to CLS token in text model # Assume EOS token corresponds to CLS token in text model
@default_pooling_type("CLS") @default_pooling_type(seq_pooling_type="CLS")
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
SiglipMultiModalProcessor, SiglipMultiModalProcessor,
info=SiglipProcessingInfo, info=SiglipProcessingInfo,
...@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) -> torch.Tensor: ) -> torch.Tensor:
if feature_select_strategy is None: if feature_select_strategy is None:
feature_select_strategy = _get_vision_feature_select_strategy( feature_select_strategy = _get_vision_feature_select_strategy(
self.pooler_config.pooling_type self.pooler_config.seq_pooling_type
) )
pooled_output = self.vision_model( pooled_output = self.vision_model(
......
...@@ -140,7 +140,7 @@ class PoolingParams( ...@@ -140,7 +140,7 @@ class PoolingParams(
self, pooler_config: "PoolerConfig", valid_parameters: list[str] self, pooler_config: "PoolerConfig", valid_parameters: list[str]
): ):
step_pooling_parameters = ["step_tag_id", "returned_token_ids"] step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
if pooler_config.pooling_type != "STEP": if pooler_config.tok_pooling_type != "STEP":
invalid_parameters = [] invalid_parameters = []
for k in step_pooling_parameters: for k in step_pooling_parameters:
if getattr(self, k, None) is not None: if getattr(self, k, None) is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment