Unverified commit f54f8512, authored by wang.yuqi and committed by GitHub
Browse files

[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)


Signed-off-by: wang.yuqi <noooop@126.com>
parent d4d1a602
...@@ -1748,16 +1748,19 @@ async def init_app_state( ...@@ -1748,16 +1748,19 @@ async def init_app_state(
else None else None
) )
state.openai_serving_pooling = ( state.openai_serving_pooling = (
OpenAIServingPooling( (
engine_client, OpenAIServingPooling(
state.openai_serving_models, engine_client,
request_logger=request_logger, state.openai_serving_models,
chat_template=resolved_chat_template, supported_tasks=supported_tasks,
chat_template_content_format=args.chat_template_content_format, request_logger=request_logger,
trust_request_chat_template=args.trust_request_chat_template, chat_template=resolved_chat_template,
log_error_stack=args.log_error_stack, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
) )
if "encode" in supported_tasks if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
else None else None
) )
state.openai_serving_embedding = ( state.openai_serving_embedding = (
......
...@@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ...@@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
When using plugins IOProcessor plugins, the actual input is processed When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data by the plugin itself. Hence, we use a generic type for the request data
""" """
softmax: bool = True activation: bool = False
embed_dtype: str = Field( embed_dtype: str = Field(
default="float32", default="float32",
...@@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ...@@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
) )
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(task="encode", softmax=self.softmax) return PoolingParams(task="token_classify", activation=self.activation)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]): class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
......
...@@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig ...@@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils import merge_async_iterators from vllm.utils import merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
supported_tasks: tuple[SupportedTask, ...],
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
...@@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
self.supported_tasks = supported_tasks
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
...@@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
try: try:
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
)
try: try:
pooling_params.verify("encode", self.model_config) pooling_params.verify(pooling_task, self.model_config)
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
...@@ -64,66 +64,6 @@ class PoolingParamsUpdate: ...@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
params.requires_token_ids = self.requires_token_ids params.requires_token_ids = self.requires_token_ids
class Pooler(nn.Module, ABC):
    """Base interface implemented by every pooler used in vLLM pooling models.

    The static ``for_*`` factories assemble the stock pooler for a task from
    a ``PoolerConfig``. Subclasses must provide ``get_supported_tasks`` and
    ``forward``.
    """

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler for the ``"encode"`` task.

        A ``"STEP"`` pooling type gives a ``StepPooler``; anything else gives
        a ``SimplePooler`` resolved with ``PoolingType.ALL``.
        """
        if pooler_config.pooling_type != "STEP":
            encode_config = ResolvedPoolingConfig(
                task="encode", pooling_type=PoolingType.ALL
            )
            return SimplePooler.from_config(encode_config)
        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the ``SimplePooler`` for the ``"embed"`` task."""
        embed_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler.from_config(embed_config)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
    ):
        """Build the ``ClassifierPooler`` for the ``"classify"`` task."""
        classify_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(classify_config.pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update to apply for a supported task.

        The base implementation returns a default-constructed
        ``PoolingParamsUpdate``.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
def get_prompt_lens( def get_prompt_lens(
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
...@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC): ...@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod): class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward_all(
self, self,
...@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod): ...@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod): class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward_all(
self, self,
...@@ -265,7 +205,7 @@ class LastPool(PoolingMethod): ...@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod): class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"} return {"token_embed", "token_classify"}
def forward_all( def forward_all(
self, self,
...@@ -284,7 +224,7 @@ class AllPool(PoolingMethod): ...@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod): class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"} return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all( def forward_all(
self, self,
...@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation): ...@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
return self.fn(pooled_data) return self.fn(pooled_data)
class Pooler(nn.Module, ABC):
    """Base interface implemented by every pooler used in vLLM pooling models.

    The static ``for_*`` factories assemble the stock pooler for a task from
    a ``PoolerConfig``. Subclasses must provide ``get_supported_tasks`` and
    ``forward``.
    """

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build the per-token embedding pooler.

        Uses ``StepPooler`` when the configured pooling type is ``"STEP"``
        and ``AllPooler`` otherwise, both around a
        ``TokenEmbeddingPoolerHead``.
        """
        head = TokenEmbeddingPoolerHead()
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the per-token classification pooler.

        Uses ``StepPooler`` when the configured pooling type is ``"STEP"``
        and ``AllPooler`` otherwise, both around a
        ``TokenClassifierPoolerHead`` built from ``classifier`` and ``act_fn``.
        """
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the ``SimplePooler`` for the ``"embed"`` task."""
        embed_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(embed_config.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the ``ClassifierPooler`` for the ``"classify"`` task."""
        classify_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(classify_config.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling-parameter update to apply for a supported task.

        The base implementation returns a default-constructed
        ``PoolingParamsUpdate``.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
class PoolerHead(nn.Module): class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None: def __init__(self, activation: PoolerActivation) -> None:
super().__init__() super().__init__()
...@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead): ...@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
super().__init__(activation=PoolerNormalize()) super().__init__(activation=PoolerNormalize())
# Load ST projector if available # Load ST projector if available
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = ( self.projector: nn.Module | None = (
_load_st_projector(vllm_config.model_config) if vllm_config else None _load_st_projector(vllm_config.model_config) if vllm_config else None
...@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead): ...@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
return pooled_data return pooled_data
class RewardPoolerHead(PoolerHead):
    """Pooler head that casts pooled data to the head dtype and applies the
    classification activation to requests whose ``softmax`` flag is set."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))

        model_config = get_current_vllm_config().model_config
        self.head_dtype = model_config.head_dtype

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        """Cast ``pooled_data`` to ``head_dtype`` and optionally activate it.

        When every request shares the same ``softmax`` flag the activation is
        applied (or skipped) for the whole batch at once; otherwise it is
        applied request by request and a list is returned.
        """
        if not isinstance(pooled_data, list):
            pooled_data = pooled_data.to(self.head_dtype)
        else:
            pooled_data = [item.to(self.head_dtype) for item in pooled_data]

        # for softmax
        softmax_flags = [
            param.softmax for param in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            # Flags differ between requests: decide per request.
            return [
                self.activation(vecs) if wants_softmax else vecs
                for vecs, wants_softmax in zip(pooled_data, softmax_flags)
            ]

        if softmax_flags[0]:
            pooled_data = self.activation(pooled_data)
        return pooled_data
class SimplePooler(Pooler): class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states. """A layer that pools specific information from hidden states.
...@@ -513,20 +495,6 @@ class SimplePooler(Pooler): ...@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`. 3. Returns structured results as `PoolerOutput`.
""" """
@classmethod
def from_config(
    cls,
    pooler_config: ResolvedPoolingConfig,
) -> "SimplePooler":
    """Build a pooler from a resolved config.

    The pooling method comes from ``pooler_config.pooling_type``; the head
    is ``EmbeddingPoolerHead`` for the ``"embed"`` task and
    ``RewardPoolerHead`` for ``"encode"``. Any other task raises
    ``NotImplementedError``.
    """
    pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
    task = pooler_config.task
    if task not in ("embed", "encode"):
        raise NotImplementedError(f"Unknown task: {pooler_config.task}")
    head = EmbeddingPoolerHead() if task == "embed" else RewardPoolerHead()
    return cls(pooling, head)
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
super().__init__() super().__init__()
...@@ -549,58 +517,6 @@ class SimplePooler(Pooler): ...@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
return pooled_data return pooled_data
class StepPooler(Pooler):
    """Pooler for the ``"encode"`` task that keeps all token states, then
    narrows them per request by returned token ids and step-tag positions
    before running a ``RewardPoolerHead``."""

    def __init__(
        self,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    @staticmethod
    def _select_step_states(states, token_ids, pooling_param):
        """Apply one request's column and step-tag row filters to its states."""
        step_tag_id = pooling_param.step_tag_id
        returned_token_ids = pooling_param.returned_token_ids

        if returned_token_ids is not None and len(returned_token_ids) > 0:
            states = states[:, returned_token_ids]
        if step_tag_id is not None:
            # Keep only the rows whose prompt token equals the step tag.
            states = states[token_ids == step_tag_id]
        return states

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Return one filtered tensor per request."""
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        per_request_token_ids = get_prompt_token_ids(pooling_metadata)
        per_request_params = get_pooling_params(pooling_metadata)

        return [
            self._select_step_states(states, token_ids, param)
            for states, token_ids, param in zip(
                per_request_states, per_request_token_ids, per_request_params
            )
        ]

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # extract_states compares prompt token ids against the step tag.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Filter the token states, then pass them through the head."""
        selected = self.extract_states(hidden_states, pooling_metadata)
        return self.head(selected, pooling_metadata)
class ClassifierPooler(Pooler): class ClassifierPooler(Pooler):
"""A pooling layer for classification tasks. """A pooling layer for classification tasks.
...@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler): ...@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
""" """
@staticmethod @staticmethod
def act_fn_for_seq_cls(config: ModelConfig): def act_fn_for_seq_cls(model_config: ModelConfig):
return get_classification_activation_function(config.hf_config) return get_classification_activation_function(model_config.hf_config)
@staticmethod
def act_fn_for_cross_encoder(model_config: ModelConfig):
return get_cross_encoder_activation_function(model_config.hf_config)
@staticmethod @staticmethod
def act_fn_for_cross_encoder(config: ModelConfig): def resolve_act_fn(
return get_cross_encoder_activation_function(config.hf_config) model_config: ModelConfig,
static_num_labels: bool = True,
act_fn: PoolerActivation | str | None = None,
):
if isinstance(act_fn, str):
if act_fn == "classify":
return ClassifierPooler.act_fn_for_seq_cls(model_config)
elif act_fn == "score":
return ClassifierPooler.act_fn_for_cross_encoder(model_config)
else:
raise ValueError(f"act_fn [{act_fn=}] not supported.")
elif act_fn is None:
return PoolerClassify(static_num_labels=static_num_labels)
else:
assert callable(act_fn)
return act_fn
def __init__( def __init__(
self, self,
pooling: PoolingFn, pooling: PoolingFn,
classifier: ClassifierFn | None, classifier: ClassifierFn | None,
act_fn: PoolerActivation | None = None, act_fn: PoolerActivation | str | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.pooling = pooling self.pooling = pooling
self.classifier = classifier self.classifier = classifier
self.act_fn = act_fn or PoolerClassify() self.act_fn = self.resolve_act_fn(
vllm_config.model_config, static_num_labels=True, act_fn=act_fn
)
self.logit_bias: float | None = ( self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias vllm_config.model_config.pooler_config.logit_bias
) )
...@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler): ...@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
return scores return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to the token states of a single request."""

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        """Project, truncate and optionally normalize one request's tokens.

        Steps: cast to ``head_dtype``; run the projector when one is set;
        slice the last axis to ``pooling_param.dimensions``; apply the
        activation when ``pooling_param.normalize`` is truthy.
        """
        # token_vecs shape: [n_tokens, hidden_dimension]
        token_vecs = pooled_data.to(self.head_dtype)

        # Apply ST projector
        projector = self.projector
        if projector is not None:
            token_vecs = projector(token_vecs)
        # token_vecs shape: [n_tokens, embedding_dimension]

        # for matryoshka representation
        token_vecs = token_vecs[..., : pooling_param.dimensions]

        # for normalize
        if not pooling_param.normalize:
            return token_vecs
        # result shape: [n_tokens, embedding_dimension]
        return self.activation(token_vecs)
class TokenClassifierPoolerHead(nn.Module):
    """Classification head applied to the token states of a single request.

    Args:
        classifier: Optional callable mapping token states to per-token
            scores. When ``None`` the (dtype-cast) states are used as scores.
        act_fn: Activation spec forwarded to
            ``ClassifierPooler.resolve_act_fn`` together with
            ``static_num_labels=False``.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        """Score one request's tokens.

        Casts the states to ``head_dtype``, runs the classifier when one is
        set, subtracts ``logit_bias`` when configured, and applies the
        activation when ``pooling_param.activation`` is truthy. The input
        tensor is never modified.
        """
        hidden_states = hidden_states.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Subtract out of place. Without a classifier, `scores` can be
            # the caller's own tensor (`Tensor.to` returns the input when the
            # dtype already matches), and `-=` would silently overwrite the
            # caller's hidden states.
            scores = scores - self.logit_bias

        if pooling_param.activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores
class AllPooler(Pooler):
    """Pooler that keeps every token's state and runs ``head`` once per
    request with that request's pooling parameters."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return a list holding the head output for each request."""
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        per_request_params = get_pooling_params(pooling_metadata)
        assert len(per_request_states) == len(per_request_params)

        outputs = []
        for states, param in zip(per_request_states, per_request_params):
            outputs.append(self.head(states, param))
        return outputs
class StepPooler(Pooler):
    """Pooler that keeps all token states, narrows them per request by
    returned token ids and step-tag positions, then runs ``head`` once per
    request with that request's pooling parameters."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return one filtered tensor per request."""
        per_request_states = self.pooling(hidden_states, pooling_metadata)
        per_request_token_ids = get_prompt_token_ids(pooling_metadata)
        selected = list[torch.Tensor]()
        per_request_params = get_pooling_params(pooling_metadata)

        triples = zip(per_request_states, per_request_token_ids, per_request_params)
        for states, token_ids, param in triples:
            tag = param.step_tag_id
            columns = param.returned_token_ids

            has_columns = columns is not None and len(columns) > 0
            if has_columns:
                states = states[:, columns]

            if tag is not None:
                # Keep only the rows whose prompt token equals the step tag.
                states = states[token_ids == tag]

            selected.append(states)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # extract_states compares prompt token ids against the step tag.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Filter the token states, then apply the head to each request."""
        selected = self.extract_states(hidden_states, pooling_metadata)
        per_request_params = get_pooling_params(pooling_metadata)
        assert len(selected) == len(per_request_params)

        outputs = []
        for states, param in zip(selected, per_request_params):
            outputs.append(self.head(states, param))
        return outputs
class DispatchPooler(Pooler): class DispatchPooler(Pooler):
"""Dispatches calls to a sub-pooler based on the pooling task.""" """Dispatches calls to a sub-pooler based on the pooling task."""
......
...@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T: ...@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
}, },
) )
...@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import # Lazy import
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler, DispatchPooler,
Pooler, Pooler,
PoolingMethod,
PoolingType,
) )
from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T: ...@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
model_config.hidden_size, model_config.hidden_size,
config.num_labels, config.num_labels,
bias=False, bias=False,
params_dtype=torch.float32, params_dtype=vllm_config.model_config.head_dtype,
quant_config=quant_config, quant_config=quant_config,
return_bias=False,
prefix=maybe_prefix(prefix, "score"), prefix=maybe_prefix(prefix, "score"),
) )
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
pooling_type_str = pooler_config.pooling_type
assert pooling_type_str is not None
pooling_type = PoolingType[pooling_type_str]
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
"classify": ClassifierPooler( pooler_config, classifier=self.score
pooling=PoolingMethod.from_pooling_type(pooling_type), ),
classifier=self._classifier, "classify": Pooler.for_classify(
act_fn=ClassifierPooler.act_fn_for_seq_cls( pooler_config, classifier=self.score, act_fn="classify"
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": Pooler.for_classify(
pooling=PoolingMethod.from_pooling_type(pooling_type), pooler_config, classifier=self.score, act_fn="score"
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T: ...@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
)
}
) )
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
......
...@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): ...@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler( return DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )
...@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): ...@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return DispatchPooler( return DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": SPLADESparsePooler( "embed": SPLADESparsePooler(
mlm_head=self.mlm_head, mlm_head=self.mlm_head,
cls_token_id=cls_id, cls_token_id=cls_id,
...@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu ...@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=self.bert.pooler, pooling=self.bert.pooler,
classifier=self.classifier, classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls( act_fn="classify",
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=self.bert.pooler, pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
...@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module): ...@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
} }
) )
......
...@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=self.new.pooler, pooling=self.new.pooler,
classifier=self.classifier, classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls( act_fn="classify",
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=self.new.pooler, pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
......
...@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )
......
...@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
"classify": Pooler.for_classify(pooler_config, classifier=self.score), pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
} }
) )
......
...@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM): ...@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
if pooler_config is not None: if pooler_config is not None:
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config), "embed": GritLMPooler(vllm_config.model_config),
} }
) )
...@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ...@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {"token_classify": Pooler.for_token_classify(pooler_config)}
) )
def forward( def forward(
......
...@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM): ...@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify( "classify": Pooler.for_classify(
pooler_config, pooler_config, classifier=self.score, act_fn="classify"
classifier=self.score, ),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
), ),
} }
) )
...@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification( ...@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
self.score = JinaVLScorer(vllm_config.model_config) self.score = JinaVLScorer(vllm_config.model_config)
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
"classify": Pooler.for_classify(pooler_config, classifier=self.score), pooler_config, classifier=self.score
"score": Pooler.for_classify(pooler_config, classifier=self.score), ),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
} }
) )
......
...@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=self.pooling, pooling=self.pooling, classifier=self.classifier, act_fn="classify"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=self.pooling, pooling=self.pooling, classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
...@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module): ...@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
} }
) )
......
...@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): ...@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {"token_classify": Pooler.for_token_classify(pooler_config)}
) )
...@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): ...@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)}) self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
...@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module): ...@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
@default_pooling_type("CLS") @default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel): class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities. """A model that uses Roberta to provide embedding functionalities."""
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
...@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config=pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
......
...@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
assert pooler_config is not None assert pooler_config is not None
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, {"token_classify": Pooler.for_token_classify(pooler_config)}
) )
def get_input_embeddings( def get_input_embeddings(
......
...@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase): ...@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config), "embed": Pooler.for_embed(pooler_config),
} }
) )
...@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase): ...@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{ {
"encode": Pooler.for_encode(pooler_config), "token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler( "classify": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
), ),
"score": ClassifierPooler( "score": ClassifierPooler(
pooling=CLSPool(), pooling=CLSPool(), classifier=self.classifier, act_fn="score"
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
), ),
} }
) )
......
...@@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind ...@@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig, PoolerConfig
class PoolingParams( class PoolingParams(
...@@ -30,7 +30,6 @@ class PoolingParams( ...@@ -30,7 +30,6 @@ class PoolingParams(
if the model supports Matryoshka representation. if the model supports Matryoshka representation.
activation: Whether to apply activation function to activation: Whether to apply activation function to
the classification outputs. the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
""" """
# --8<-- [start:common-pooling-params] # --8<-- [start:common-pooling-params]
...@@ -48,32 +47,19 @@ class PoolingParams( ...@@ -48,32 +47,19 @@ class PoolingParams(
activation: bool | None = None activation: bool | None = None
# --8<-- [end:classification-pooling-params] # --8<-- [end:classification-pooling-params]
## for reward models ## for step pooling models
softmax: bool | None = None
step_tag_id: int | None = None step_tag_id: int | None = None
returned_token_ids: list[int] | None = None returned_token_ids: list[int] | None = None
## Internal use only
task: PoolingTask | None = None task: PoolingTask | None = None
"""Internal use only."""
requires_token_ids: bool = False requires_token_ids: bool = False
"""Internal use only."""
extra_kwargs: dict[str, Any] | None = None extra_kwargs: dict[str, Any] | None = None
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@property @property
def all_parameters(self) -> list[str]: def all_parameters(self) -> list[str]:
return [ return ["dimensions", "normalize", "activation"]
"dimensions",
"normalize",
"activation",
"softmax",
"step_tag_id",
"returned_token_ids",
]
@property @property
def valid_parameters(self): def valid_parameters(self):
...@@ -81,7 +67,8 @@ class PoolingParams( ...@@ -81,7 +67,8 @@ class PoolingParams(
"embed": ["dimensions", "normalize"], "embed": ["dimensions", "normalize"],
"classify": ["activation"], "classify": ["activation"],
"score": ["activation"], "score": ["activation"],
"encode": ["softmax", "step_tag_id", "returned_token_ids"], "token_embed": ["dimensions", "normalize"],
"token_classify": ["activation"],
} }
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
...@@ -100,7 +87,6 @@ class PoolingParams( ...@@ -100,7 +87,6 @@ class PoolingParams(
# NOTE: Task validation needs to be done against the model instance, # NOTE: Task validation needs to be done against the model instance,
# which is not available in model config. So, it's not included # which is not available in model config. So, it's not included
# in this method # in this method
self._merge_default_parameters(model_config) self._merge_default_parameters(model_config)
self._set_default_parameters(model_config) self._set_default_parameters(model_config)
self._verify_valid_parameters() self._verify_valid_parameters()
...@@ -125,8 +111,34 @@ class PoolingParams( ...@@ -125,8 +111,34 @@ class PoolingParams(
if getattr(self, k, None) is None: if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k)) setattr(self, k, getattr(pooler_config, k))
self._verify_step_pooling(pooler_config, valid_parameters)
def _verify_step_pooling(
    self, pooler_config: "PoolerConfig", valid_parameters: list[str]
):
    """Validate or default the step-pooling parameters on this instance.

    The step-pooling parameters are ``step_tag_id`` and
    ``returned_token_ids``.

    * When ``pooler_config.pooling_type`` is ``"STEP"``, every step-pooling
      parameter that is still ``None`` on this instance is filled in from
      the same-named attribute of ``pooler_config``, provided the config
      value is itself not ``None``.
    * For any other pooling type, having a step-pooling parameter set on
      this instance is an error.

    Args:
        pooler_config: Pooler configuration supplying ``pooling_type`` and
            the optional step-pooling defaults.
        valid_parameters: Parameter names accepted by the current task;
            only used to build the error message.

    Raises:
        ValueError: If the pooling type is not ``"STEP"`` and at least one
            step-pooling parameter is set on this instance.
    """
    step_keys = ("step_tag_id", "returned_token_ids")

    if pooler_config.pooling_type == "STEP":
        # Values already set on this instance win; the config only fills gaps.
        for key in step_keys:
            fallback = getattr(pooler_config, key, None)
            if fallback is not None and getattr(self, key, None) is None:
                setattr(self, key, fallback)
        return

    rejected = [key for key in step_keys if getattr(self, key, None) is not None]
    if rejected:
        raise ValueError(
            f"Task {self.task} only supports {valid_parameters} "
            f"parameters, does not support "
            f"{rejected} parameters"
        )
def _set_default_parameters(self, model_config: Optional["ModelConfig"]): def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
if self.task == "embed": if self.task in ["embed", "token_embed"]:
if self.normalize is None: if self.normalize is None:
self.normalize = True self.normalize = True
...@@ -150,13 +162,9 @@ class PoolingParams( ...@@ -150,13 +162,9 @@ class PoolingParams(
elif self.dimensions < 1: elif self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0") raise ValueError("Dimensions must be greater than 0")
elif self.task in ["classify", "score"]: elif self.task in ["classify", "score", "token_classify"]:
if self.activation is None: if self.activation is None:
self.activation = True self.activation = True
elif self.task == "encode":
if self.softmax is None:
self.softmax = True
else: else:
raise ValueError(f"Unknown pooling task: {self.task}") raise ValueError(f"Unknown pooling task: {self.task}")
...@@ -185,7 +193,6 @@ class PoolingParams( ...@@ -185,7 +193,6 @@ class PoolingParams(
f"normalize={self.normalize}, " f"normalize={self.normalize}, "
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"activation={self.activation}, " f"activation={self.activation}, "
f"softmax={self.softmax}, "
f"step_tag_id={self.step_tag_id}, " f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, " f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, " f"requires_token_ids={self.requires_token_ids}, "
......
...@@ -5,7 +5,7 @@ from typing import Literal, get_args ...@@ -5,7 +5,7 @@ from typing import Literal, get_args
GenerationTask = Literal["generate", "transcription"] GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask) GENERATION_TASKS = get_args(GenerationTask)
PoolingTask = Literal["encode", "embed", "classify", "score"] PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
POOLING_TASKS = get_args(PoolingTask) POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask] SupportedTask = Literal[GenerationTask, PoolingTask]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment