Unverified Commit b7036c87 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Clean up pooler modules (#31897)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent cc6dafae
...@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set ...@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from itertools import groupby from itertools import groupby
from typing import TypeVar from typing import TypeAlias, TypeVar
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector ...@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -30,6 +30,15 @@ PoolingFn = Callable[ ...@@ -30,6 +30,15 @@ PoolingFn = Callable[
ClassifierFn = Callable[[torch.Tensor], torch.Tensor] ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
class PoolingType(IntEnum): class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods.""" """Enumeration for different types of pooling methods."""
...@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC): ...@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
return PoolingParamsUpdate() return PoolingParamsUpdate()
@abstractmethod @abstractmethod
def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
raise NotImplementedError
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolingMethodOutput:
pooling_cursor = pooling_metadata.pooling_cursor raise NotImplementedError
return self.forward_all(hidden_states, pooling_cursor)
class CLSPool(PoolingMethod): class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
assert not pooling_cursor.is_partial_prefill(), ( assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling" "partial prefill not supported with CLS pooling"
) )
...@@ -159,11 +161,12 @@ class LastPool(PoolingMethod): ...@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
return hidden_states[pooling_cursor.last_token_indices_gpu] return hidden_states[pooling_cursor.last_token_indices_gpu]
...@@ -179,19 +182,12 @@ class AllPool(PoolingMethod): ...@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"} return {"token_embed", "token_classify"}
def forward_all(
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
) -> PoolerOutput:
raise NotImplementedError(
"forward_all is not implemented for AllPool. Use forward instead."
)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokensPoolingMethodOutput:
pooling_cursor = pooling_metadata.pooling_cursor pooling_cursor = pooling_metadata.get_pooling_cursor()
is_finished = pooling_cursor.is_finished() is_finished = pooling_cursor.is_finished()
hidden_states_lst = list( hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
...@@ -209,7 +205,7 @@ class AllPool(PoolingMethod): ...@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
p.hidden_states_cache.append(hs_chunk) p.hidden_states_cache.append(hs_chunk)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list: PoolerOutput = [] output_list = list[torch.Tensor | None]()
for p, finished in zip(pooling_states, is_finished): for p, finished in zip(pooling_states, is_finished):
if finished: if finished:
hidden_states_cache = p.hidden_states_cache hidden_states_cache = p.hidden_states_cache
...@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod): ...@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
assert not pooling_cursor.is_partial_prefill(), ( assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling" "partial prefill not supported with MEAN pooling"
) )
...@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC): ...@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
hidden_states: list[torch.Tensor] | torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
raise NotImplementedError raise NotImplementedError
...@@ -422,41 +419,42 @@ class DummyPooler(Pooler): ...@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
def forward( def forward(
self, self,
hidden_states: list[torch.Tensor] | torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
return hidden_states return hidden_states
class PoolerHead(nn.Module): class TokenPoolerHead(nn.Module, ABC):
def __init__(self, activation: PoolerActivation) -> None: """Applicable to pooling strategies that output one token."""
super().__init__()
self.activation = activation
@abstractmethod
def forward( def forward(
self, self,
pooled_data: list[torch.Tensor] | torch.Tensor, pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolerHeadOutput:
return self.activation(pooled_data) raise NotImplementedError
class EmbeddingPoolerHead(PoolerHead): class EmbeddingPoolerHead(TokenPoolerHead):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(activation=PoolerNormalize()) super().__init__()
# Load ST projector if available # Load ST projector if available
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = ( self.projector = (
_load_st_projector(vllm_config.model_config) if vllm_config else None _load_st_projector(vllm_config.model_config) if vllm_config else None
) )
self.head_dtype = vllm_config.model_config.head_dtype self.head_dtype = vllm_config.model_config.head_dtype
self.activation = PoolerNormalize()
def forward( def forward(
self, self,
pooled_data: list[torch.Tensor] | torch.Tensor, pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolerHeadOutput:
if isinstance(pooled_data, list): if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data) pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension] # pooled_data shape: [batchsize, hidden_dimension]
...@@ -509,7 +507,7 @@ class SimplePooler(Pooler): ...@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`. 3. Returns structured results as `PoolerOutput`.
""" """
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
super().__init__() super().__init__()
self.pooling = pooling self.pooling = pooling
...@@ -523,9 +521,9 @@ class SimplePooler(Pooler): ...@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
def forward( def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolerHeadOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data return pooled_data
...@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler): ...@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
def forward( def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list): if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data) pooled_data = torch.stack(pooled_data)
...@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler): ...@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
return scores return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): class TokensPoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output multiple tokens."""
@abstractmethod
def forward( def forward(
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams self,
) -> PoolerOutput: pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
raise NotImplementedError
class TokenEmbeddingPoolerHead(TokensPoolerHead):
def __init__(self) -> None:
super().__init__()
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
)
self.head_dtype = vllm_config.model_config.head_dtype
self.activation = PoolerNormalize()
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
# for unfinished chunked prefill # for unfinished chunked prefill
if pooled_data is None: if pooled_data is None:
return None return None
...@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): ...@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
return pooled_data return pooled_data
class TokenClassifierPoolerHead(nn.Module): class TokenClassifierPoolerHead(TokensPoolerHead):
def __init__( def __init__(
self, self,
classifier: ClassifierFn | None, classifier: ClassifierFn | None,
act_fn: PoolerActivation | str | None = None, act_fn: PoolerActivation | str | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.classifier = classifier self.classifier = classifier
self.act_fn = ClassifierPooler.resolve_act_fn(
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
)
self.logit_bias: float | None = ( self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias vllm_config.model_config.pooler_config.logit_bias
) )
self.head_dtype = vllm_config.model_config.head_dtype self.head_dtype = vllm_config.model_config.head_dtype
def get_supported_tasks(self) -> Set[PoolingTask]: self.activation = ClassifierPooler.resolve_act_fn(
return {"token_classify"} vllm_config.model_config, static_num_labels=False, act_fn=act_fn
)
def forward( def forward(
self, self,
hidden_states: torch.Tensor | None, pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams, pooling_param: PoolingParams,
) -> PoolerOutput: ) -> TokensPoolerHeadOutput:
# for unfinished chunked prefill # for unfinished chunked prefill
if hidden_states is None: if pooled_data is None:
return None return None
hidden_states = hidden_states.to(self.head_dtype) pooled_data = pooled_data.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size] # hidden_states shape: [n_token, hidden_size]
if self.classifier is not None: if self.classifier is not None:
scores = self.classifier(hidden_states) scores = self.classifier(pooled_data)
else: else:
scores = hidden_states scores = pooled_data
# scores shape: [n_token, num_labels] # scores shape: [n_token, num_labels]
if self.logit_bias is not None: if self.logit_bias is not None:
scores -= self.logit_bias scores -= self.logit_bias
if pooling_param.use_activation: if pooling_param.use_activation:
scores = self.act_fn(scores) scores = self.activation(scores)
# scores shape: [n_token, num_labels] # scores shape: [n_token, num_labels]
return scores return scores
class AllPooler(Pooler): class AllPooler(Pooler):
def __init__(self, head: nn.Module | PoolerHead) -> None: def __init__(self, head: TokensPoolerHead) -> None:
super().__init__() super().__init__()
self.pooling = AllPool() self.pooling = AllPool()
...@@ -712,17 +735,16 @@ class AllPooler(Pooler): ...@@ -712,17 +735,16 @@ class AllPooler(Pooler):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokensPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params) assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
return pooled_data
class StepPooler(Pooler): class StepPooler(Pooler):
def __init__(self, head: nn.Module | PoolerHead) -> None: def __init__(self, head: TokensPoolerHead) -> None:
super().__init__() super().__init__()
self.pooling = AllPool() self.pooling = AllPool()
...@@ -730,14 +752,14 @@ class StepPooler(Pooler): ...@@ -730,14 +752,14 @@ class StepPooler(Pooler):
def extract_states( def extract_states(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> list[torch.Tensor | None]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata) pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids() prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
pooled_data: PoolerOutput = [] pooled_data = list[torch.Tensor | None]()
for data, token_id, pooling_param in zip( for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params pooled_data_lst, prompt_token_ids, pooling_params
): ):
...@@ -766,15 +788,14 @@ class StepPooler(Pooler): ...@@ -766,15 +788,14 @@ class StepPooler(Pooler):
def forward( def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokensPoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata) pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params) assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
return pooled_data
class DispatchPooler(Pooler): class DispatchPooler(Pooler):
...@@ -800,12 +821,12 @@ class DispatchPooler(Pooler): ...@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
def forward( def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> PoolerOutput:
poolers_by_task = self.poolers_by_task poolers_by_task = self.poolers_by_task
outputs = list[torch.Tensor]() outputs = list[torch.Tensor | None]()
offset = 0 offset = 0
for task, group in groupby(pooling_metadata.tasks): for task, group in groupby(pooling_metadata.tasks):
if not (pooler := poolers_by_task.get(task)): if not (pooler := poolers_by_task.get(task)):
......
...@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import ( ...@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod, PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType, PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant from .interfaces import SupportsCrossEncoding, SupportsQuant
...@@ -97,24 +100,26 @@ class BertPooler(Pooler): ...@@ -97,24 +100,26 @@ class BertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor): def head(
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]: ) -> TokenPoolerHeadOutput:
pooled_output = self.pooling(hidden_states, pooling_metadata) if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
if isinstance(pooled_output, list): pooled_data = self.dense(pooled_data)
pooled_output = [self._head(output) for output in pooled_output] pooled_data = self.activation(pooled_data)
else: return pooled_data
pooled_output = self._head(pooled_output)
return pooled_output def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
class BertEncoder(nn.Module): class BertEncoder(nn.Module):
......
...@@ -4,21 +4,22 @@ from collections.abc import Set ...@@ -4,21 +4,22 @@ from collections.abc import Set
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
DispatchPooler, DispatchPooler,
Pooler, Pooler,
PoolerHead,
PoolerNormalize, PoolerNormalize,
PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
) )
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import PoolerOutput from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
...@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type ...@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
logger = init_logger(__name__) logger = init_logger(__name__)
class GritLMMeanPool(nn.Module): class GritLMMeanPool(PoolingMethod):
"""As `MeanPool`, but only includes non-instruction tokens.""" """As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig): def __init__(self, model_config: ModelConfig):
...@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module): ...@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
return instruction_len return instruction_len
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed"} return {"embed"}
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True) return PoolingParamsUpdate(requires_token_ids=True)
def forward( def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor: ) -> TokenPoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor( instr_lens = torch.tensor(
[ [
...@@ -178,7 +179,7 @@ class GritLMPooler(Pooler): ...@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
super().__init__() super().__init__()
self.pooling = GritLMMeanPool(model_config) self.pooling = GritLMMeanPool(model_config)
self.head = PoolerHead(PoolerNormalize()) self.activation = PoolerNormalize()
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks() return self.pooling.get_supported_tasks()
...@@ -186,11 +187,18 @@ class GritLMPooler(Pooler): ...@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
return self.activation(pooled_data)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> PoolerOutput: ) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata) pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data return pooled_data
......
...@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import ( ...@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod, PoolingMethod,
PoolingParamsUpdate, PoolingParamsUpdate,
PoolingType, PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding from .interfaces import SupportsCrossEncoding
...@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler): ...@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task) return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor): def head(
pooled_output = pooled_output.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_output)))
def forward(
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]: ) -> TokenPoolerHeadOutput:
pooled_output = self.pooling(hidden_states, pooling_metadata) if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
if isinstance(pooled_output, list): pooled_data = pooled_data.to(self.dense.weight.dtype)
pooled_output = [self._head(output) for output in pooled_output] return self.norm(self.act(self.dense(pooled_data)))
else:
pooled_output = self._head(pooled_output)
return pooled_output def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("CLS") @default_pooling_type("CLS")
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple from typing import TYPE_CHECKING, NamedTuple, TypeAlias
import numpy as np import numpy as np
import torch import torch
...@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple): ...@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>] # [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used # The shape of each element depends on the pooler used
PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput
@dataclass @dataclass
......
...@@ -90,6 +90,12 @@ class PoolingMetadata: ...@@ -90,6 +90,12 @@ class PoolingMetadata:
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def get_pooling_cursor(self) -> PoolingCursor:
pooling_cursor = self.pooling_cursor
assert pooling_cursor is not None, "Should call `build_pooling_cursor` first"
return pooling_cursor
def build_pooling_cursor( def build_pooling_cursor(
self, self,
num_scheduled_tokens_np: np.ndarray, num_scheduled_tokens_np: np.ndarray,
......
...@@ -4680,7 +4680,7 @@ class GPUModelRunner( ...@@ -4680,7 +4680,7 @@ class GPUModelRunner(
for task in supported_pooling_tasks: for task in supported_pooling_tasks:
# Run a full batch with each task to ensure none of them OOMs # Run a full batch with each task to ensure none of them OOMs
output = self._dummy_pooler_run_task(hidden_states, task) output = self._dummy_pooler_run_task(hidden_states, task)
output_size[task] = sum(o.nbytes for o in output) output_size[task] = sum(o.nbytes for o in output if o is not None)
del output # Allow GC del output # Allow GC
max_task = max(output_size.items(), key=lambda x: x[1])[0] max_task = max(output_size.items(), key=lambda x: x[1])[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment