Unverified commit c8ed39b9, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Reorganize pooling layers (#31973)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 02073280
...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
...@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
from .bert import BertPooler from .bert import BertPooler
from .interfaces import SupportsCrossEncoding, SupportsQuant from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
...@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(
{ pooler_config,
"token_classify": Pooler.for_token_classify( pooling=self.new.pooler,
pooler_config, classifier=self.classifier classifier=self.classifier,
),
"classify": ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
),
}
) )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert pooler_config is not None assert pooler_config is not None
self.pooler_config = pooler_config self.pooler_config = pooler_config
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_embedding(pooler_config)
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
# Assumes that self.forward is called after self.embed_input_ids # Assumes that self.forward is called after self.embed_input_ids
self._is_text_input = True self._is_text_input = True
......
...@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
...@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from ..layers.pooler import DispatchPooler, Pooler
from .interfaces import SupportsCrossEncoding, SupportsPP from .interfaces import SupportsCrossEncoding, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
...@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids) return self.transformer.embed_input_ids(input_ids)
......
...@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig ...@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
DispatchPooler, DispatchPooler,
Pooler,
PoolerNormalize,
PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
) )
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
from vllm.model_executor.layers.pooler.seqwise import (
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolingMethod,
SequencePoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
...@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type ...@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
logger = init_logger(__name__) logger = init_logger(__name__)
class GritLMMeanPool(PoolingMethod): class GritLMMeanPool(SequencePoolingMethod):
"""As `MeanPool`, but only includes non-instruction tokens.""" """As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig): def __init__(self, model_config: ModelConfig):
...@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod): ...@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput: ) -> SequencePoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor( instr_lens = torch.tensor(
[ [
...@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod): ...@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
return pooled_data return pooled_data
class GritLMPooler(Pooler): class GritLMPooler(SequencePooler):
def __init__(self, model_config: ModelConfig): def __init__(self, model_config: ModelConfig):
super().__init__() super().__init__(
pooling=GritLMMeanPool(model_config),
head=self.head,
)
self.pooling = GritLMMeanPool(model_config)
self.activation = PoolerNormalize() self.activation = PoolerNormalize()
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head( def head(
self, self,
pooled_data: TokenPoolingMethodOutput, pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput: ) -> SequencePoolerHeadOutput:
return self.activation(pooled_data) return self.activation(pooled_data)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("MEAN") @default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM): class GritLM(LlamaForCausalLM):
...@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM): ...@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
if pooler_config is not None: if pooler_config is not None:
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"token_embed": Pooler.for_token_embed(pooler_config), "token_embed": pooler_for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config), "embed": GritLMPooler(vllm_config.model_config),
} }
) )
...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ...@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = pooler_for_token_classify(pooler_config)
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
def forward( def forward(
self, self,
......
...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateDtypeCalculator,
MambaStateShapeCalculator, MambaStateShapeCalculator,
) )
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
...@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM): ...@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
...@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig ...@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification( ...@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
self.score = JinaVLScorer( self.score = JinaVLScorer(
vllm_config.model_config, prefix=maybe_prefix(prefix, "score") vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
) )
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set from collections.abc import Iterable
import torch import torch
from torch import nn from torch import nn
...@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import DispatchPooler
ClassifierPooler, from vllm.model_executor.layers.pooler.seqwise import (
DispatchPooler, SequencePooler,
Pooler, SequencePoolerHeadOutput,
PoolingMethod, SequencePoolingMethodOutput,
PoolingParamsUpdate, get_seq_pooling_method,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
) )
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding from .interfaces import SupportsCrossEncoding
...@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module): ...@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
return norm_outputs return norm_outputs
class ModernBertPooler(Pooler): class ModernBertPooler(SequencePooler):
def __init__(self, config: ModernBertConfig): def __init__(self, config: ModernBertConfig):
super().__init__() super().__init__(
pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
head=self.head,
)
pooling_type = config.classifier_pooling.upper()
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear( self.dense = nn.Linear(
config.hidden_size, config.hidden_size, config.classifier_bias config.hidden_size, config.hidden_size, config.classifier_bias
) )
...@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler): ...@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
) )
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head( def head(
self, self,
pooled_data: TokenPoolingMethodOutput, pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput: ) -> SequencePoolerHeadOutput:
if isinstance(pooled_data, list): if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data) pooled_data = torch.stack(pooled_data)
pooled_data = pooled_data.to(self.dense.weight.dtype) pooled_data = pooled_data.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_data))) return self.norm(self.act(self.dense(pooled_data)))
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("CLS") @default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(
{ pooler_config,
"token_classify": Pooler.for_token_classify( pooling=self.pooling,
pooler_config, classifier=self.classifier classifier=self.classifier,
),
"classify": ClassifierPooler(
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=self.pooling, classifier=self.classifier, act_fn="score"
),
}
) )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
...@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module): ...@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = pooler_for_token_classify(pooler_config)
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)
......
...@@ -14,7 +14,8 @@ from torch import nn ...@@ -14,7 +14,8 @@ from torch import nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): ...@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = pooler_for_token_classify(pooler_config)
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
@default_pooling_type("STEP") @default_pooling_type("STEP")
...@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): ...@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = pooler_for_token_classify(pooler_config)
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
...@@ -8,12 +8,8 @@ from torch import nn ...@@ -8,12 +8,8 @@ from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import DispatchPooler
ClassifierPooler, from vllm.model_executor.layers.pooler.seqwise import CLSPool
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import ( from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT, TOKEN_TYPE_SHIFT,
...@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(
{ pooler_config,
"token_classify": Pooler.for_token_classify( pooling=CLSPool(),
pooler_config=pooler_config, classifier=self.classifier classifier=self.classifier,
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
) )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert pooler_config is not None assert pooler_config is not None
self.pooler_config = pooler_config self.pooler_config = pooler_config
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_embedding(pooler_config)
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self._is_text_input = True self._is_text_input = True
......
...@@ -34,7 +34,7 @@ from transformers import BatchFeature ...@@ -34,7 +34,7 @@ from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler from vllm.model_executor.layers.pooler import IdentityPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler({"plugin": DummyPooler()}) self.pooler = IdentityPooler()
def embed_input_ids( def embed_input_ids(
self, self,
......
...@@ -22,12 +22,8 @@ import torch ...@@ -22,12 +22,8 @@ import torch
from transformers import AutoModelForSequenceClassification from transformers import AutoModelForSequenceClassification
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import DispatchPooler
ClassifierPooler, from vllm.model_executor.layers.pooler.seqwise import CLSPool
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.model_executor.models.interfaces_base import VllmModelForPooling
...@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling): ...@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_embedding(pooler_config)
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
...@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): ...@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self.classifier.__class__ = ClassifierWithReshape self.classifier.__class__ = ClassifierWithReshape
self.pooler = DispatchPooler( self.pooler = DispatchPooler.for_seq_cls(
{ pooler_config,
"token_classify": Pooler.for_token_classify( pooling=CLSPool(),
pooler_config, classifier=self.classifier classifier=self.classifier,
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
) )
...@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple): ...@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>] # [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used # The shape of each element depends on the pooler used
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None]
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment