Commit 1c3198b6 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Consolidate pooler implementations (#20927)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 260127ea
This diff is collapsed.
......@@ -58,16 +58,21 @@ def _create_pooling_model_cls(
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
......@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
......@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.vllm_config = vllm_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)
self.score = RowParallelLinear(config.hidden_size,
self.score = RowParallelLinear(
config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(
prefix, "score"))
params_dtype=torch.float32,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "score"),
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True,
)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
act_fn=pooler.head.activation,
)
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x
def forward(
self,
......@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
def get_logits(hidden_states):
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
return logits
if self.pooling_type == PoolingType.ALL:
logits = get_logits(hidden_states)
return self._pooler(logits, pooling_metadata)
else:
hidden_states = self._pooler.extract_states(
hidden_states, pooling_metadata)
logits = get_logits(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingType)
PoolingMethod, PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -84,14 +84,18 @@ class BertPooler(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[0, :]
pooled_output = self.dense(first_token_tensor)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
......@@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier, self.bert.pooler)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
......
......@@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead
from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
......@@ -49,7 +49,7 @@ class GritLMPooler(nn.Module):
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
self.head = PoolerHead(normalize=True, softmax=False)
self.head = PoolerHead(PoolerNormalize())
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
"""
......
......@@ -659,7 +659,7 @@ def supports_cross_encoding(
def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler."""
return is_pooling_model(model) and any(
type(module).__name__ == "StepPool" for module in model.modules())
type(module).__name__ == "StepPooler" for module in model.modules())
class SupportsQuant:
......
......@@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
SimplePooler)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
......@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
num_labels: int = config.num_labels
score_bias: bool = getattr(config, 'score_bias', False)
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
# TODO: The original reward weights have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
# Currently weight_loader passes the weight which is already in bf16
self.score = nn.Linear(
config.hidden_size,
num_labels,
bias=score_bias,
dtype=torch.float32,
)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=False)
softmax=False,
)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self.score,
act_fn=pooler.head.activation,
)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
hidden_states = hidden_states.float()
logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super().load_weights(weights)
self.score = self.score.float()
return self._pooler(hidden_states, pooling_metadata)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
......@@ -13,7 +13,8 @@ from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.pooler import (BasePooler, ClassifierPooler,
PoolingMethod, PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -252,10 +253,13 @@ class ModernBertModel(nn.Module):
return norm_outputs
class ModernBertPooler(nn.Module):
class ModernBertPooler(BasePooler):
def __init__(self, config: ModernBertConfig):
super().__init__()
pooling_type = PoolingType[config.classifier_pooling.upper()]
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias)
self.pooling_type = config.classifier_pooling
......@@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module):
eps=config.norm_eps,
bias=config.norm_bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
pooled_output = hidden_states
if self.pooling_type == "mean":
pooled_output = pooled_output.mean(dim=0, keepdim=False)
elif self.pooling_type == "cls":
pooled_output = pooled_output[0, :]
else:
raise ValueError("Pooling type should be either `cls` or `mean`, "
f"but got {self.pooling_type}")
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.norm(self.act(self.dense(pooled_output)))
return pooled_output
......@@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self.model = ModernBertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "modernbert"))
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier,
ModernBertPooler(config))
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=ModernBertPooler(config),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
......@@ -9,7 +9,7 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
......@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[0, :] # take <s> token (equiv. to [CLS])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling`
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
......@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=CLSPool(),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
......
......@@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError)
from torch import nn
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
......@@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
......@@ -775,28 +773,6 @@ def try_get_generation_config(
return None
def get_classification_activation_function(config: PretrainedConfig):
return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
def get_cross_encoder_activation_function(config: PretrainedConfig):
function_name: Optional[str] = None
if (hasattr(config, "sentence_transformers")
and "activation_fn" in config.sentence_transformers):
function_name = config.sentence_transformers["activation_fn"]
elif (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None):
function_name = config.sbert_ce_default_activation_function
if function_name is not None:
assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons")
return resolve_obj_by_qualname(function_name)()
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
def try_get_safetensors_metadata(
model: str,
*,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment