Unverified commit c8ed39b9, authored by Cyrus Leung and committed by GitHub
Browse files

[Model] Reorganize pooling layers (#31973)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 02073280
......@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
......@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
from .bert import BertPooler
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
......@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.new.pooler,
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
......@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert pooler_config is not None
self.pooler_config = pooler_config
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
# Assumes that self.forward is called after self.embed_input_ids
self._is_text_input = True
......
......@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
......@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from ..layers.pooler import DispatchPooler, Pooler
from .interfaces import SupportsCrossEncoding, SupportsPP
from .utils import (
AutoWeightsLoader,
......@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
......
......@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
DispatchPooler,
Pooler,
PoolerNormalize,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
from vllm.model_executor.layers.pooler.seqwise import (
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolingMethod,
SequencePoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type
......@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
logger = init_logger(__name__)
class GritLMMeanPool(PoolingMethod):
class GritLMMeanPool(SequencePoolingMethod):
"""As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig):
......@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
) -> SequencePoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor(
[
......@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
return pooled_data
class GritLMPooler(Pooler):
class GritLMPooler(SequencePooler):
def __init__(self, model_config: ModelConfig):
super().__init__()
super().__init__(
pooling=GritLMMeanPool(model_config),
head=self.head,
)
self.pooling = GritLMMeanPool(model_config)
self.activation = PoolerNormalize()
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
) -> SequencePoolerHeadOutput:
return self.activation(pooled_data)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
......@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
if pooler_config is not None:
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"token_embed": pooler_for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config),
}
)
......@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = pooler_for_token_classify(pooler_config)
def forward(
self,
......
......@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
......@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
......@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
......@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
self.score = JinaVLScorer(
vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
)
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set
from collections.abc import Iterable
import torch
from torch import nn
......@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler,
Pooler,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import (
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
......@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
return norm_outputs
class ModernBertPooler(Pooler):
class ModernBertPooler(SequencePooler):
def __init__(self, config: ModernBertConfig):
super().__init__()
super().__init__(
pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
head=self.head,
)
pooling_type = config.classifier_pooling.upper()
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(
config.hidden_size, config.hidden_size, config.classifier_bias
)
......@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
)
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
) -> SequencePoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
pooled_data = pooled_data.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_data)))
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
......@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=self.pooling, classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.pooling,
classifier=self.classifier,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
......@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)
self.pooler = pooler_for_token_classify(pooler_config)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
......
......@@ -14,7 +14,8 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
......@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = pooler_for_token_classify(pooler_config)
@default_pooling_type("STEP")
......@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = pooler_for_token_classify(pooler_config)
......@@ -8,12 +8,8 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT,
......@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=CLSPool(),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
......@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
......@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert pooler_config is not None
self.pooler_config = pooler_config
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
self._is_text_input = True
......
......@@ -34,7 +34,7 @@ from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
from vllm.model_executor.layers.pooler import IdentityPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({"plugin": DummyPooler()})
self.pooler = IdentityPooler()
def embed_input_ids(
self,
......
......@@ -22,12 +22,8 @@ import torch
from transformers import AutoModelForSequenceClassification
from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
......@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
......@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self.classifier.__class__ = ClassifierWithReshape
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=CLSPool(),
classifier=self.classifier,
)
......@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None]
@dataclass
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment