Unverified commit 1c3198b6, authored by Cyrus Leung and committed by GitHub
Browse files

[Model] Consolidate pooler implementations (#20927)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 260127ea
This diff is collapsed.
...@@ -58,22 +58,27 @@ def _create_pooling_model_cls( ...@@ -58,22 +58,27 @@ def _create_pooling_model_cls(
) -> None: ) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
# These are not used in pooling models # These are not used in pooling models
for attr in ("lm_head", "logits_processor"): for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr): if hasattr(self, attr):
delattr(self, attr) delattr(self, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
# If the model already defines a pooler instance, don't overwrite it self._pooler = Pooler.from_config_with_defaults(
if not getattr(self, "_pooler", None): pooler_config,
self._pooler = Pooler.from_config_with_defaults( pooling_type=default_pooling_type,
pooler_config, normalize=default_normalize,
pooling_type=default_pooling_type, softmax=default_softmax,
normalize=default_normalize, )
softmax=default_softmax,
)
def pooler( def pooler(
self, self,
...@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(ModelForPooling, class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding): SupportsCrossEncoding):
def __init__( def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.vllm_config = vllm_config self.score = RowParallelLinear(
self.task = vllm_config.model_config.task config.hidden_size,
self.pooling_type = ( config.num_labels,
vllm_config.model_config.pooler_config.pooling_type) input_is_parallel=False,
bias=False,
self.score = RowParallelLinear(config.hidden_size, params_dtype=torch.float32,
config.num_labels, quant_config=quant_config,
quant_config=quant_config, prefix=maybe_prefix(prefix, "score"),
input_is_parallel=False, )
bias=False,
prefix=maybe_prefix( pooler_config = vllm_config.model_config.pooler_config
prefix, "score")) assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True,
)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
act_fn=pooler.head.activation,
)
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x
def forward( def forward(
self, self,
...@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
hidden_states: Union[torch.Tensor, list[torch.Tensor]], hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
return self._pooler(hidden_states, pooling_metadata)
def get_logits(hidden_states):
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
return logits
if self.pooling_type == PoolingType.ALL:
logits = get_logits(hidden_states)
return self._pooler(logits, pooling_metadata)
else:
hidden_states = self._pooler.extract_states(
hidden_states, pooling_metadata)
logits = get_logits(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None) tokens = getattr(self.config, "classifier_from_token", None)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingType) PoolingMethod, PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -84,14 +84,18 @@ class BertPooler(nn.Module): ...@@ -84,14 +84,18 @@ class BertPooler(nn.Module):
def __init__(self, config: BertConfig): def __init__(self, config: BertConfig):
super().__init__() super().__init__()
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(
# We "pool" the model by simply taking the hidden state corresponding self,
# to the first token. hidden_states: Union[torch.Tensor, list[torch.Tensor]],
first_token_tensor = hidden_states[0, :] pooling_metadata: PoolingMetadata,
pooled_output = self.dense(first_token_tensor) ) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output) pooled_output = self.activation(pooled_output)
return pooled_output return pooled_output
...@@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, ...@@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding, embedding_class=BertEmbedding,
add_pooling_layer=True) add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config, self._pooler = ClassifierPooler(
self.classifier, self.bert.pooler) vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -9,7 +9,7 @@ import torch.nn as nn ...@@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata, from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors) PoolingTensors)
...@@ -49,7 +49,7 @@ class GritLMPooler(nn.Module): ...@@ -49,7 +49,7 @@ class GritLMPooler(nn.Module):
self.embed_pattern_ids = tokens_to_ids( self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"]) ["▁<", "|", "embed", "|", ">", "<0x0A>"])
self.head = PoolerHead(normalize=True, softmax=False) self.head = PoolerHead(PoolerNormalize())
def _find_array(self, arr: array, target: array, start_idx: int) -> int: def _find_array(self, arr: array, target: array, start_idx: int) -> int:
""" """
......
...@@ -659,7 +659,7 @@ def supports_cross_encoding( ...@@ -659,7 +659,7 @@ def supports_cross_encoding(
def has_step_pooler(model: Union[type[object], object]) -> bool: def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler.""" """Check if the model uses step pooler."""
return is_pooling_model(model) and any( return is_pooling_model(model) and any(
type(module).__name__ == "StepPool" for module in model.modules()) type(module).__name__ == "StepPooler" for module in model.modules())
class SupportsQuant: class SupportsQuant:
......
...@@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
SimplePooler)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
...@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM): ...@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
num_labels: int = config.num_labels num_labels: int = config.num_labels
score_bias: bool = getattr(config, 'score_bias', False) score_bias: bool = getattr(config, 'score_bias', False)
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
# TODO: The original reward weights have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
# Currently weight_loader passes the weight which is already in bf16
self.score = nn.Linear(
config.hidden_size,
num_labels,
bias=score_bias,
dtype=torch.float32,
)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults( assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
normalize=False, normalize=False,
softmax=False) softmax=False,
)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self.score,
act_fn=pooler.head.activation,
)
def pooler( def pooler(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
hidden_states = hidden_states.float() return self._pooler(hidden_states, pooling_metadata)
logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super().load_weights(weights)
self.score = self.score.float()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -13,7 +13,8 @@ from vllm.config import VllmConfig ...@@ -13,7 +13,8 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -252,10 +253,13 @@ class ModernBertModel(nn.Module): ...@@ -252,10 +253,13 @@ class ModernBertModel(nn.Module):
return norm_outputs return norm_outputs
class ModernBertPooler(nn.Module): class ModernBertPooler(BasePooler):
def __init__(self, config: ModernBertConfig): def __init__(self, config: ModernBertConfig):
super().__init__() super().__init__()
pooling_type = PoolingType[config.classifier_pooling.upper()]
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(config.hidden_size, config.hidden_size, self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias) config.classifier_bias)
self.pooling_type = config.classifier_pooling self.pooling_type = config.classifier_pooling
...@@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module): ...@@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module):
eps=config.norm_eps, eps=config.norm_eps,
bias=config.norm_bias) bias=config.norm_bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(
pooled_output = hidden_states self,
if self.pooling_type == "mean": hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooled_output = pooled_output.mean(dim=0, keepdim=False) pooling_metadata: PoolingMetadata,
elif self.pooling_type == "cls": ) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = pooled_output[0, :] pooled_output = self.pooling(hidden_states, pooling_metadata)
else:
raise ValueError("Pooling type should be either `cls` or `mean`, "
f"but got {self.pooling_type}")
pooled_output = self.norm(self.act(self.dense(pooled_output))) pooled_output = self.norm(self.act(self.dense(pooled_output)))
return pooled_output return pooled_output
...@@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, ...@@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config, self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert")) prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config, self._pooler = ClassifierPooler(
self.classifier, vllm_config.model_config,
ModernBertPooler(config)) pooling=ModernBertPooler(config),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -9,7 +9,7 @@ from torch import nn ...@@ -9,7 +9,7 @@ from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
...@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module): ...@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels) self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs): def forward(self, x: torch.Tensor) -> torch.Tensor:
x = features[0, :] # take <s> token (equiv. to [CLS]) # CLSPool has already been applied in `pooling`
x = self.dense(x) x = self.dense(x)
x = torch.tanh(x) x = torch.tanh(x)
x = self.out_proj(x) x = self.out_proj(x)
...@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer=False) add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config) self.classifier = RobertaClassificationHead(config)
self._pooler = ClassifierPooler(vllm_config.model_config, self._pooler = ClassifierPooler(
self.classifier) vllm_config.model_config,
pooling=CLSPool(),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, ...@@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, LocalEntryNotFoundError, HFValidationError, LocalEntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError) RevisionNotFoundError)
from torch import nn
from transformers import GenerationConfig, PretrainedConfig from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import ( from transformers.models.auto.image_processing_auto import (
get_image_processor_config) get_image_processor_config)
...@@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, ...@@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
# yapf: enable # yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
if envs.VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig from modelscope import AutoConfig
...@@ -775,28 +773,6 @@ def try_get_generation_config( ...@@ -775,28 +773,6 @@ def try_get_generation_config(
return None return None
def get_classification_activation_function(config: PretrainedConfig):
return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
def get_cross_encoder_activation_function(config: PretrainedConfig):
function_name: Optional[str] = None
if (hasattr(config, "sentence_transformers")
and "activation_fn" in config.sentence_transformers):
function_name = config.sentence_transformers["activation_fn"]
elif (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
function_name = config.sbert_ce_default_activation_function
if function_name is not None:
assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons")
return resolve_obj_by_qualname(function_name)()
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
def try_get_safetensors_metadata( def try_get_safetensors_metadata(
model: str, model: str,
*, *,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment