Unverified Commit d1b6fe00 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Further cleanup pooler (#31951)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 04a49669
...@@ -5,12 +5,7 @@ import os ...@@ -5,12 +5,7 @@ import os
import pytest import pytest
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool
CLSPool,
DispatchPooler,
MeanPool,
PoolingType,
)
from vllm.model_executor.models.bert import BertEmbeddingModel from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -50,7 +45,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch): ...@@ -50,7 +45,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
assert model_config.encoder_config["do_lower_case"] assert model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files # asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name assert model_config.pooler_config.pooling_type == "CLS"
assert model_config.pooler_config.normalize assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded # asserts on the tokenizer loaded
...@@ -94,7 +89,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): ...@@ -94,7 +89,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
assert not model_config.encoder_config["do_lower_case"] assert not model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files # asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name assert model_config.pooler_config.pooling_type == "MEAN"
assert model_config.pooler_config.normalize assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded # asserts on the tokenizer loaded
......
...@@ -25,7 +25,6 @@ from vllm.config.vllm import ( ...@@ -25,7 +25,6 @@ from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG, OPTIMIZATION_LEVEL_TO_CONFIG,
OptimizationLevel, OptimizationLevel,
) )
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -162,7 +161,7 @@ def test_get_pooling_config(): ...@@ -162,7 +161,7 @@ def test_get_pooling_config():
assert model_config.pooler_config is not None assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name assert model_config.pooler_config.pooling_type == "MEAN"
@pytest.mark.skipif( @pytest.mark.skipif(
......
...@@ -21,8 +21,7 @@ class PoolerConfig: ...@@ -21,8 +21,7 @@ class PoolerConfig:
pooling_type: PoolingTypeStr | None = None pooling_type: PoolingTypeStr | None = None
""" """
The pooling method of the pooling model. This should be a key in The pooling method of the pooling model.
[`vllm.model_executor.layers.pooler.PoolingType`][].
""" """
## for embeddings models ## for embeddings models
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Set from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby from itertools import groupby
from typing import TypeAlias, TypeVar from typing import TypeAlias, TypeVar
...@@ -12,13 +11,14 @@ import torch.nn as nn ...@@ -12,13 +11,14 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config from vllm.config import ModelConfig, get_current_vllm_config
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,27 +31,17 @@ ClassifierFn = Callable[[torch.Tensor], torch.Tensor] ...@@ -31,27 +31,17 @@ ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor] TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None] TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor] TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
CLS = 2
STEP = 3
MEAN = 4
@dataclass(frozen=True) @dataclass(frozen=True)
class ResolvedPoolingConfig: class ResolvedPoolingConfig:
pooling_type: PoolingType pooling_type: PoolingTypeStr
task: PoolingTask task: PoolingTask
@classmethod @classmethod
...@@ -61,7 +51,7 @@ class ResolvedPoolingConfig: ...@@ -61,7 +51,7 @@ class ResolvedPoolingConfig:
pooler_config: PoolerConfig, pooler_config: PoolerConfig,
) -> "ResolvedPoolingConfig": ) -> "ResolvedPoolingConfig":
assert pooler_config.pooling_type is not None assert pooler_config.pooling_type is not None
return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type]) return cls(task=task, pooling_type=pooler_config.pooling_type)
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -112,17 +102,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): ...@@ -112,17 +102,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
class PoolingMethod(nn.Module, ABC): class PoolingMethod(nn.Module, ABC):
@staticmethod @staticmethod
def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
if pooling_type == PoolingType.LAST: if pooling_type == "LAST":
return LastPool() return LastPool()
if pooling_type == PoolingType.ALL: if pooling_type == "ALL":
return AllPool() return AllPool()
if pooling_type == PoolingType.CLS: if pooling_type == "CLS":
return CLSPool() return CLSPool()
if pooling_type == PoolingType.MEAN: if pooling_type == "MEAN":
return MeanPool() return MeanPool()
if pooling_type == "STEP":
raise ValueError(
"'STEP' pooling is handled by StepPooler "
"and is not a standalone PoolingMethod."
)
raise NotImplementedError(f"Unsupported method: {pooling_type}") raise NotImplementedError(f"Unsupported method: {pooling_type!r}")
@abstractmethod @abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
...@@ -186,13 +181,12 @@ class AllPool(PoolingMethod): ...@@ -186,13 +181,12 @@ class AllPool(PoolingMethod):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokensPoolingMethodOutput: ) -> TokenwisePoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor() pooling_cursor = pooling_metadata.get_pooling_cursor()
is_finished = pooling_cursor.is_finished() hidden_states_all = hidden_states.split(
hidden_states_lst = list( pooling_cursor.num_scheduled_tokens_cpu.tolist()
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
) )
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index] hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]
if not self.enable_chunked_prefill: if not self.enable_chunked_prefill:
return hidden_states_lst return hidden_states_lst
...@@ -206,7 +200,7 @@ class AllPool(PoolingMethod): ...@@ -206,7 +200,7 @@ class AllPool(PoolingMethod):
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list = list[torch.Tensor | None]() output_list = list[torch.Tensor | None]()
for p, finished in zip(pooling_states, is_finished): for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
if finished: if finished:
hidden_states_cache = p.hidden_states_cache hidden_states_cache = p.hidden_states_cache
if len(hidden_states_cache) == 1: if len(hidden_states_cache) == 1:
...@@ -620,19 +614,19 @@ class ClassifierPooler(Pooler): ...@@ -620,19 +614,19 @@ class ClassifierPooler(Pooler):
return scores return scores
class TokensPoolerHead(nn.Module, ABC): class TokenwisePoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output multiple tokens.""" """Applicable to pooling strategies that output multiple tokens."""
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
pooled_data: TokensPoolingMethodOutputItem, pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams, pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput: ) -> TokenwisePoolerHeadOutput:
raise NotImplementedError raise NotImplementedError
class TokenEmbeddingPoolerHead(TokensPoolerHead): class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -647,9 +641,9 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead): ...@@ -647,9 +641,9 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
def forward( def forward(
self, self,
pooled_data: TokensPoolingMethodOutputItem, pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams, pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput: ) -> TokenwisePoolerHeadOutput:
# for unfinished chunked prefill # for unfinished chunked prefill
if pooled_data is None: if pooled_data is None:
return None return None
...@@ -673,7 +667,7 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead): ...@@ -673,7 +667,7 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
return pooled_data return pooled_data
class TokenClassifierPoolerHead(TokensPoolerHead): class TokenClassifierPoolerHead(TokenwisePoolerHead):
def __init__( def __init__(
self, self,
classifier: ClassifierFn | None, classifier: ClassifierFn | None,
...@@ -695,9 +689,9 @@ class TokenClassifierPoolerHead(TokensPoolerHead): ...@@ -695,9 +689,9 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
def forward( def forward(
self, self,
pooled_data: TokensPoolingMethodOutputItem, pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams, pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput: ) -> TokenwisePoolerHeadOutput:
# for unfinished chunked prefill # for unfinished chunked prefill
if pooled_data is None: if pooled_data is None:
return None return None
...@@ -722,7 +716,7 @@ class TokenClassifierPoolerHead(TokensPoolerHead): ...@@ -722,7 +716,7 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
class AllPooler(Pooler): class AllPooler(Pooler):
def __init__(self, head: TokensPoolerHead) -> None: def __init__(self, head: TokenwisePoolerHead) -> None:
super().__init__() super().__init__()
self.pooling = AllPool() self.pooling = AllPool()
...@@ -735,7 +729,7 @@ class AllPooler(Pooler): ...@@ -735,7 +729,7 @@ class AllPooler(Pooler):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokensPoolerOutput: ) -> TokenwisePoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params) assert len(pooled_data) == len(pooling_params)
...@@ -744,7 +738,7 @@ class AllPooler(Pooler): ...@@ -744,7 +738,7 @@ class AllPooler(Pooler):
class StepPooler(Pooler): class StepPooler(Pooler):
def __init__(self, head: TokensPoolerHead) -> None: def __init__(self, head: TokenwisePoolerHead) -> None:
super().__init__() super().__init__()
self.pooling = AllPool() self.pooling = AllPool()
...@@ -790,7 +784,7 @@ class StepPooler(Pooler): ...@@ -790,7 +784,7 @@ class StepPooler(Pooler):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokensPoolerOutput: ) -> TokenwisePoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata) pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params) assert len(pooled_data) == len(pooling_params)
......
...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.pooler import ( ...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.pooler import (
Pooler, Pooler,
PoolingMethod, PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput, TokenPoolerHeadOutput,
TokenPoolingMethodOutput, TokenPoolingMethodOutput,
) )
...@@ -90,7 +89,7 @@ class BertPooler(Pooler): ...@@ -90,7 +89,7 @@ class BertPooler(Pooler):
def __init__(self, config: BertConfig): def __init__(self, config: BertConfig):
super().__init__() super().__init__()
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS) self.pooling = PoolingMethod.from_pooling_type("CLS")
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
......
...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.pooler import ( ...@@ -18,7 +18,6 @@ from vllm.model_executor.layers.pooler import (
Pooler, Pooler,
PoolingMethod, PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput, TokenPoolerHeadOutput,
TokenPoolingMethodOutput, TokenPoolingMethodOutput,
) )
...@@ -287,7 +286,7 @@ class ModernBertPooler(Pooler): ...@@ -287,7 +286,7 @@ class ModernBertPooler(Pooler):
def __init__(self, config: ModernBertConfig): def __init__(self, config: ModernBertConfig):
super().__init__() super().__init__()
pooling_type = PoolingType[config.classifier_pooling.upper()] pooling_type = config.classifier_pooling.upper()
self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear( self.dense = nn.Linear(
config.hidden_size, config.hidden_size, config.classifier_bias config.hidden_size, config.hidden_size, config.classifier_bias
......
...@@ -92,8 +92,8 @@ class LogprobsTensors(NamedTuple): ...@@ -92,8 +92,8 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>] # [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used # The shape of each element depends on the pooler used
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None] TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment