Unverified commit 9101dc75, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Avoid hardcoding pooling type (#32119)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 025a32f9
...@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import ( ...@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
PoolingParamsUpdate, PoolingParamsUpdate,
) )
from vllm.model_executor.layers.pooler.seqwise import ( from vllm.model_executor.layers.pooler.seqwise import (
CLSPool,
SequencePooler, SequencePooler,
SequencePoolerHeadOutput, SequencePoolerHeadOutput,
SequencePoolerOutput, SequencePoolerOutput,
SequencePoolingMethodOutput, SequencePoolingMethodOutput,
get_seq_pooling_method,
) )
from vllm.model_executor.layers.pooler.tokwise import ( from vllm.model_executor.layers.pooler.tokwise import (
pooler_for_token_classify, pooler_for_token_classify,
...@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module): ...@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
class BertPooler(SequencePooler): class BertPooler(SequencePooler):
def __init__(self, config: BertConfig): def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
super().__init__( super().__init__(
pooling=CLSPool(), pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
head=self.head, head=self.head,
) )
...@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel): ...@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
) )
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.pooler = BertPooler(config)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = BertPooler(config, pooler_config)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
other_weights, loaded_stacked_params = self._load_weights(weights) other_weights, loaded_stacked_params = self._load_weights(weights)
...@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): ...@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
) )
# None of vLLM's built-in sequence pooling types are
# applicable so it is overwritten by SPLADESparsePooler
pooling_mode = getattr(self, "_splade_pooling", "max") pooling_mode = getattr(self, "_splade_pooling", "max")
cls_id = getattr(cfg, "cls_token_id", None) cls_id = getattr(cfg, "cls_token_id", None)
......
...@@ -453,6 +453,7 @@ class BertWithRope(nn.Module, SupportsQuant): ...@@ -453,6 +453,7 @@ class BertWithRope(nn.Module, SupportsQuant):
add_pooling_layer: bool = False, add_pooling_layer: bool = False,
): ):
super().__init__() super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.add_pooling_layer = add_pooling_layer self.add_pooling_layer = add_pooling_layer
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
...@@ -463,7 +464,14 @@ class BertWithRope(nn.Module, SupportsQuant): ...@@ -463,7 +464,14 @@ class BertWithRope(nn.Module, SupportsQuant):
rotary_kwargs=self.config.rotary_kwargs, rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder", prefix=f"{prefix}.encoder",
) )
self.pooler = BertPooler(self.config) if add_pooling_layer else None
if add_pooling_layer:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = BertPooler(self.config, pooler_config)
else:
self.pooler = None
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids) return self.embeddings(input_ids)
......
...@@ -5,7 +5,7 @@ from collections.abc import Set ...@@ -5,7 +5,7 @@ from collections.abc import Set
import numpy as np import numpy as np
import torch import torch
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
DispatchPooler, DispatchPooler,
...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.pooler.seqwise import ( ...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.pooler.seqwise import (
SequencePoolerHeadOutput, SequencePoolerHeadOutput,
SequencePoolingMethod, SequencePoolingMethod,
SequencePoolingMethodOutput, SequencePoolingMethodOutput,
get_seq_pooling_method,
) )
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
...@@ -177,9 +178,13 @@ class GritLMMeanPool(SequencePoolingMethod): ...@@ -177,9 +178,13 @@ class GritLMMeanPool(SequencePoolingMethod):
class GritLMPooler(SequencePooler): class GritLMPooler(SequencePooler):
def __init__(self, model_config: ModelConfig): def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
super().__init__( super().__init__(
pooling=GritLMMeanPool(model_config), pooling=(
GritLMMeanPool(model_config)
if pooler_config.seq_pooling_type == "MEAN"
else get_seq_pooling_method(pooler_config.seq_pooling_type)
),
head=self.head, head=self.head,
) )
...@@ -235,6 +240,6 @@ class GritLM(LlamaForCausalLM): ...@@ -235,6 +240,6 @@ class GritLM(LlamaForCausalLM):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"token_embed": pooler_for_token_embed(pooler_config), "token_embed": pooler_for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config), "embed": GritLMPooler(vllm_config.model_config, pooler_config),
} }
) )
...@@ -8,7 +8,7 @@ from transformers import ModernBertConfig ...@@ -8,7 +8,7 @@ from transformers import ModernBertConfig
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention.encoder_only_attention import ( from vllm.model_executor.layers.attention.encoder_only_attention import (
EncoderOnlyAttention, EncoderOnlyAttention,
...@@ -282,9 +282,14 @@ class ModernBertModel(nn.Module): ...@@ -282,9 +282,14 @@ class ModernBertModel(nn.Module):
class ModernBertPooler(SequencePooler): class ModernBertPooler(SequencePooler):
def __init__(self, config: ModernBertConfig): def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
hf_pooling_type = config.classifier_pooling.upper()
# vllm_pooling_type = pooler_config.seq_pooling_type
# Currently we don't have a way to see if the user set the pooling type
# explicitly or not, so we always use the HF pooling type for now.
super().__init__( super().__init__(
pooling=get_seq_pooling_method(config.classifier_pooling.upper()), pooling=get_seq_pooling_method(hf_pooling_type),
head=self.head, head=self.head,
) )
...@@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.config = config self.config = config
self.model = ModernBertModel( self.model = ModernBertModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
...@@ -324,11 +331,12 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -324,11 +331,12 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
config.num_labels, config.num_labels,
dtype=vllm_config.model_config.head_dtype, dtype=vllm_config.model_config.head_dtype,
) )
self.pooling = ModernBertPooler(config)
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooling = ModernBertPooler(config, pooler_config)
self.pooler = DispatchPooler.for_seq_cls( self.pooler = DispatchPooler.for_seq_cls(
pooler_config, pooler_config,
pooling=self.pooling, pooling=self.pooling,
......
...@@ -9,7 +9,6 @@ from transformers import RobertaConfig ...@@ -9,7 +9,6 @@ from transformers import RobertaConfig
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import ( from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT, TOKEN_TYPE_SHIFT,
...@@ -86,7 +85,7 @@ class RobertaClassificationHead(nn.Module): ...@@ -86,7 +85,7 @@ class RobertaClassificationHead(nn.Module):
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling` # Token extraction has already been applied in `pooler.pooling`
x = self.dense(x) x = self.dense(x)
x = torch.tanh(x) x = torch.tanh(x)
x = self.out_proj(x) x = self.out_proj(x)
...@@ -194,7 +193,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -194,7 +193,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler.for_seq_cls( self.pooler = DispatchPooler.for_seq_cls(
pooler_config, pooler_config,
pooling=CLSPool(),
classifier=self.classifier, classifier=self.classifier,
) )
......
...@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification ...@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.model_executor.models.interfaces_base import VllmModelForPooling
...@@ -32,7 +31,7 @@ if TYPE_CHECKING: ...@@ -32,7 +31,7 @@ if TYPE_CHECKING:
class EmbeddingMixin(VllmModelForPooling): class EmbeddingMixin(VllmModelForPooling):
default_pooling_type = "CLS" default_seq_pooling_type = "CLS"
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForPooling.__init__ and call the next class in MRO # Skip VllmModelForPooling.__init__ and call the next class in MRO
...@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling): ...@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling):
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
default_pooling_type = "CLS" default_seq_pooling_type = "CLS"
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForPooling.__init__ and call the next class in MRO # Skip VllmModelForPooling.__init__ and call the next class in MRO
...@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): ...@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)
class ClassifierWithReshape(self.classifier.__class__): class ClassifierWithReshape(self.classifier.__class__):
"""CLSPool has already been applied in `pooling`. """
Add dim to match expected input shape of `classifier.forward`.""" Token extraction has already been applied in `pooler.pooling`.
Add dim to match expected input shape of `classifier.forward`.
"""
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
if len(args) > 0: if len(args) > 0:
...@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): ...@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self.pooler = DispatchPooler.for_seq_cls( self.pooler = DispatchPooler.for_seq_cls(
pooler_config, pooler_config,
pooling=CLSPool(),
classifier=self.classifier, classifier=self.classifier,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment