Unverified commit f54f8512, authored by wang.yuqi and committed by GitHub
Browse files

[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)


Signed-off-by: wang.yuqi <noooop@126.com>
parent d4d1a602
......@@ -1748,16 +1748,19 @@ async def init_app_state(
else None
)
state.openai_serving_pooling = (
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
(
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
supported_tasks=supported_tasks,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
)
if "encode" in supported_tasks
if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
else None
)
state.openai_serving_embedding = (
......
......@@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
When using IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data
"""
softmax: bool = True
activation: bool = False
embed_dtype: str = Field(
default="float32",
......@@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
)
def to_pooling_params(self):
return PoolingParams(task="encode", softmax=self.softmax)
return PoolingParams(task="token_classify", activation=self.activation)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
......
......@@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
......@@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
engine_client: EngineClient,
models: OpenAIServingModels,
*,
supported_tasks: tuple[SupportedTask, ...],
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
......@@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack=log_error_stack,
)
self.supported_tasks = supported_tasks
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
......@@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
try:
pooling_params = request.to_pooling_params()
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."
)
try:
pooling_params.verify("encode", self.model_config)
pooling_params.verify(pooling_task, self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
......
......@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
params.requires_token_ids = self.requires_token_ids
class Pooler(nn.Module, ABC):
    """Abstract base implemented by every pooler used in vLLM pooling models.

    The static ``for_*`` factories assemble the stock pooler for one task
    from a ``PoolerConfig``. Concrete subclasses must provide
    ``get_supported_tasks`` and ``forward``.
    """

    @staticmethod
    def for_encode(pooler_config: PoolerConfig):
        """Build the pooler that serves the ``"encode"`` task."""
        # STEP pooling has its own pooler; any other configured pooling type
        # is served with ALL pooling for this task.
        if pooler_config.pooling_type != "STEP":
            return SimplePooler.from_config(
                ResolvedPoolingConfig(task="encode", pooling_type=PoolingType.ALL)
            )
        return StepPooler()

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the pooler that serves the ``"embed"`` task."""
        return SimplePooler.from_config(
            ResolvedPoolingConfig.from_config(
                task="embed",
                pooler_config=pooler_config,
            )
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
    ):
        """Build a ``ClassifierPooler`` for the ``"classify"`` task.

        The pooling method is taken from the pooling type resolved for the
        task; ``classifier`` is handed to the pooler unchanged.
        """
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation returns a default ``PoolingParamsUpdate``.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
def get_prompt_lens(
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
......@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
self,
......@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
self,
......@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode"}
return {"token_embed", "token_classify"}
def forward_all(
self,
......@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed", "classify", "score"}
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
self,
......@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
return self.fn(pooled_data)
class Pooler(nn.Module, ABC):
    """Abstract base implemented by every pooler used in vLLM pooling models.

    The static ``for_*`` factories assemble the stock pooler for one task
    from a ``PoolerConfig``. Concrete subclasses must provide
    ``get_supported_tasks`` and ``forward``.
    """

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build the per-token embedding pooler (``"token_embed"`` task)."""
        head = TokenEmbeddingPoolerHead()
        # STEP pooling gets the step-aware pooler; everything else pools all
        # tokens.
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build the per-token classification pooler (``"token_classify"``).

        ``classifier`` and ``act_fn`` are forwarded to
        ``TokenClassifierPoolerHead`` unchanged.
        """
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build the sequence-level embedding pooler (``"embed"`` task)."""
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a ``ClassifierPooler`` using the pooling type resolved for
        the ``"classify"`` task.

        ``classifier`` and ``act_fn`` are handed to the pooler unchanged.
        """
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.

        The base implementation returns a default ``PoolingParamsUpdate``.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` according to ``pooling_metadata``."""
        raise NotImplementedError
class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
......@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
super().__init__(activation=PoolerNormalize())
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
......@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
return pooled_data
class RewardPoolerHead(PoolerHead):
    """Pooler head that optionally applies the classification activation,
    controlled per request by ``PoolingParams.softmax``."""

    def __init__(self) -> None:
        super().__init__(activation=PoolerClassify(static_num_labels=False))

        # The dtype the head computes in comes from the active model config.
        self.head_dtype = get_current_vllm_config().model_config.head_dtype

    def forward(
        self,
        pooled_data: list[torch.Tensor] | torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ):
        """Cast ``pooled_data`` to the head dtype and apply the activation to
        the requests whose ``softmax`` flag is set."""
        if not isinstance(pooled_data, list):
            pooled_data = pooled_data.to(self.head_dtype)
        else:
            pooled_data = [item.to(self.head_dtype) for item in pooled_data]

        # for softmax
        softmax_flags = [
            params.softmax for params in get_pooling_params(pooling_metadata)
        ]

        if len(set(softmax_flags)) != 1:
            # Flags differ between requests: decide entry by entry.
            return [
                self.activation(entry) if enabled else entry
                for entry, enabled in zip(pooled_data, softmax_flags)
            ]

        # All requests agree, so the whole batch is handled in one call.
        if softmax_flags[0]:
            return self.activation(pooled_data)
        return pooled_data
class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states.
......@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
"""
@classmethod
def from_config(
    cls,
    pooler_config: ResolvedPoolingConfig,
) -> "SimplePooler":
    """Create a pooler whose pooling method and head follow ``pooler_config``.

    Raises:
        NotImplementedError: If the task is neither ``"embed"`` nor
            ``"encode"``.
    """
    pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)

    # Each supported task maps to the head class that post-processes its
    # pooled output.
    head_by_task = {
        "embed": EmbeddingPoolerHead,
        "encode": RewardPoolerHead,
    }
    head_cls = head_by_task.get(pooler_config.task)
    if head_cls is None:
        raise NotImplementedError(f"Unknown task: {pooler_config.task}")

    return cls(pooling, head_cls())
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
super().__init__()
......@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
return pooled_data
class StepPooler(Pooler):
    """Pooler for the ``"encode"`` task that can narrow each request's pooled
    states using ``step_tag_id`` and ``returned_token_ids``."""

    def __init__(
        self,
    ) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = RewardPoolerHead()

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Pool every request, then apply that request's step filters."""
        states_per_request = self.pooling(hidden_states, pooling_metadata)
        token_ids_per_request = get_prompt_token_ids(pooling_metadata)
        params_per_request = get_pooling_params(pooling_metadata)

        extracted: list[torch.Tensor] = []
        for states, token_ids, params in zip(
            states_per_request, token_ids_per_request, params_per_request
        ):
            keep_columns = params.returned_token_ids
            if keep_columns is not None and len(keep_columns) > 0:
                # Keep only the requested indices along the last axis.
                states = states[:, keep_columns]

            tag = params.step_tag_id
            if tag is not None:
                # Keep only the positions whose prompt token is the step tag.
                states = states[token_ids == tag]

            extracted.append(states)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"encode"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # The step-tag filter compares against prompt token ids, so they must
        # be made available.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the per-request states and run them through the head."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        return self.head(extracted, pooling_metadata)
class ClassifierPooler(Pooler):
"""A pooling layer for classification tasks.
......@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
"""
@staticmethod
def act_fn_for_seq_cls(config: ModelConfig):
return get_classification_activation_function(config.hf_config)
def act_fn_for_seq_cls(model_config: ModelConfig):
return get_classification_activation_function(model_config.hf_config)
@staticmethod
def act_fn_for_cross_encoder(model_config: ModelConfig):
return get_cross_encoder_activation_function(model_config.hf_config)
@staticmethod
def act_fn_for_cross_encoder(config: ModelConfig):
return get_cross_encoder_activation_function(config.hf_config)
def resolve_act_fn(
    model_config: ModelConfig,
    static_num_labels: bool = True,
    act_fn: PoolerActivation | str | None = None,
):
    """Turn an activation specification into a concrete activation callable.

    Args:
        model_config: Model configuration handed to the named-activation
            lookups.
        static_num_labels: Forwarded to ``PoolerClassify`` when ``act_fn`` is
            ``None``.
        act_fn: ``"classify"`` or ``"score"`` to look the activation up from
            ``model_config``; ``None`` for a default ``PoolerClassify``; or an
            already-built callable, which is returned unchanged.

    Returns:
        The resolved activation callable.

    Raises:
        ValueError: If ``act_fn`` is a string other than ``"classify"`` or
            ``"score"``.
        TypeError: If ``act_fn`` is neither a string, ``None``, nor callable.
    """
    if act_fn is None:
        return PoolerClassify(static_num_labels=static_num_labels)

    if isinstance(act_fn, str):
        if act_fn == "classify":
            return ClassifierPooler.act_fn_for_seq_cls(model_config)
        if act_fn == "score":
            return ClassifierPooler.act_fn_for_cross_encoder(model_config)
        raise ValueError(f"act_fn [{act_fn=}] not supported.")

    # Validate explicitly rather than with `assert`, which is removed when
    # Python runs with -O and would let a bad value through silently.
    if not callable(act_fn):
        raise TypeError(
            "act_fn must be a str, None, or a callable, "
            f"got {type(act_fn).__name__}"
        )
    return act_fn
def __init__(
self,
pooling: PoolingFn,
classifier: ClassifierFn | None,
act_fn: PoolerActivation | None = None,
act_fn: PoolerActivation | str | None = None,
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
self.pooling = pooling
self.classifier = classifier
self.act_fn = act_fn or PoolerClassify()
self.act_fn = self.resolve_act_fn(
vllm_config.model_config, static_num_labels=True, act_fn=act_fn
)
self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias
)
......@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
    """Embedding head applied to the per-token states of a single request."""

    def forward(
        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
    ) -> torch.Tensor:
        """Project, truncate and optionally normalize one request's tokens.

        Shape notes are carried over from the original comments:
        ``[n_tokens, hidden_dimension]`` on input and
        ``[n_tokens, embedding_dimension]`` on output.
        """
        token_vectors = pooled_data.to(self.head_dtype)

        # Apply ST projector
        projector = self.projector
        if projector is not None:
            token_vectors = projector(token_vectors)

        # for matryoshka representation: keep the leading `dimensions`
        # features (a `None` bound keeps all of them).
        token_vectors = token_vectors[..., : pooling_param.dimensions]

        # for normalize
        if not pooling_param.normalize:
            return token_vectors
        return self.activation(token_vectors)
class TokenClassifierPoolerHead(nn.Module):
    """Head that scores the per-token states of a single request.

    Each call casts the states to the head dtype, runs the optional
    classifier, subtracts the optional logit bias and, when the request asks
    for it, applies the activation function.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        """
        Args:
            classifier: Callable applied to the hidden states to produce
                scores, or ``None`` to use the hidden states as the scores.
            act_fn: Activation specification resolved through
                ``ClassifierPooler.resolve_act_fn``.
        """
        super().__init__()
        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.act_fn = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_param: PoolingParams,
    ) -> torch.Tensor:
        """Score one request's tokens.

        Shape notes are carried over from the original comments:
        ``[n_token, hidden_size]`` on input and ``[n_token, num_labels]`` on
        output.

        Args:
            hidden_states: Per-token hidden states of one request. The tensor
                is never modified in place.
            pooling_param: The request's pooling parameters; only
                ``activation`` is read.

        Returns:
            The per-token scores.
        """
        hidden_states = hidden_states.to(self.head_dtype)

        if self.classifier is not None:
            scores = self.classifier(hidden_states)
        else:
            scores = hidden_states

        if self.logit_bias is not None:
            # Subtract out of place. With no classifier, and a `.to()` that
            # returned the input unchanged, `scores` is the caller's tensor,
            # and `-=` would silently overwrite the caller's hidden states.
            scores = scores - self.logit_bias

        if pooling_param.activation:
            scores = self.act_fn(scores)

        return scores
class AllPooler(Pooler):
    """Pooler that runs ``AllPool`` and then applies ``head`` to each pooled
    entry with its matching pooling parameters."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool the hidden states and run the head entry by entry."""
        pooled_entries = self.pooling(hidden_states, pooling_metadata)
        params_per_entry = get_pooling_params(pooling_metadata)

        # The head is called once per (entry, params) pair, so both sequences
        # must line up.
        assert len(pooled_entries) == len(params_per_entry)

        return [
            self.head(entry, params)
            for entry, params in zip(pooled_entries, params_per_entry)
        ]
class StepPooler(Pooler):
    """Pooler that runs ``AllPool``, narrows each entry using
    ``step_tag_id`` / ``returned_token_ids``, then applies ``head`` to each
    entry with its matching pooling parameters."""

    def __init__(self, head: nn.Module | PoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Pool every request, then apply that request's step filters."""
        states_per_request = self.pooling(hidden_states, pooling_metadata)
        token_ids_per_request = get_prompt_token_ids(pooling_metadata)
        params_per_request = get_pooling_params(pooling_metadata)

        extracted: list[torch.Tensor] = []
        for states, token_ids, params in zip(
            states_per_request, token_ids_per_request, params_per_request
        ):
            keep_columns = params.returned_token_ids
            if keep_columns is not None and len(keep_columns) > 0:
                # Keep only the requested indices along the last axis.
                states = states[:, keep_columns]

            tag = params.step_tag_id
            if tag is not None:
                # Keep only the positions whose prompt token is the step tag.
                states = states[token_ids == tag]

            extracted.append(states)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # The step-tag filter compares against prompt token ids, so they must
        # be made available.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Extract the per-request states and run the head entry by entry."""
        extracted = self.extract_states(hidden_states, pooling_metadata)
        params_per_entry = get_pooling_params(pooling_metadata)

        # The head is called once per (entry, params) pair, so both sequences
        # must line up.
        assert len(extracted) == len(params_per_entry)

        return [
            self.head(entry, params)
            for entry, params in zip(extracted, params_per_entry)
        ]
class DispatchPooler(Pooler):
"""Dispatches calls to a sub-pooler based on the pooling task."""
......
......@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
},
)
......@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler,
Pooler,
PoolingMethod,
PoolingType,
)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors
......@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
model_config.hidden_size,
config.num_labels,
bias=False,
params_dtype=torch.float32,
params_dtype=vllm_config.model_config.head_dtype,
quant_config=quant_config,
return_bias=False,
prefix=maybe_prefix(prefix, "score"),
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
pooling_type_str = pooler_config.pooling_type
assert pooling_type_str is not None
pooling_type = PoolingType[pooling_type_str]
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"classify": ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x
def forward(
self,
input_ids: torch.Tensor,
......@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)},
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
)
}
)
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
......
......@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
......@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": SPLADESparsePooler(
mlm_head=self.mlm_head,
cls_token_id=cls_id,
......@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
),
}
)
......@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)
......
......@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.new.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
),
}
)
......
......@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
......
......@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"classify": Pooler.for_classify(pooler_config, classifier=self.score),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
......
......@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
if pooler_config is not None:
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config),
}
)
......@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)},
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
def forward(
......
......@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config,
classifier=self.score,
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
......@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
self.score = JinaVLScorer(vllm_config.model_config)
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"classify": Pooler.for_classify(pooler_config, classifier=self.score),
"score": Pooler.for_classify(pooler_config, classifier=self.score),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
......
......@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.pooling,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=self.pooling,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
pooling=self.pooling, classifier=self.classifier, act_fn="score"
),
}
)
......@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)
......
......@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)},
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
......@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)})
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
......@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
@default_pooling_type("CLS")
class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""A model that uses Roberta to provide embedding functionalities."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
......@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
)
......
......@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)},
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
def get_input_embeddings(
......
......@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
......@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(),
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
)
......
......@@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config import ModelConfig, PoolerConfig
class PoolingParams(
......@@ -30,7 +30,6 @@ class PoolingParams(
if model support matryoshka representation.
activation: Whether to apply activation function to
the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
"""
# --8<-- [start:common-pooling-params]
......@@ -48,32 +47,19 @@ class PoolingParams(
activation: bool | None = None
# --8<-- [end:classification-pooling-params]
## for reward models
softmax: bool | None = None
## for step pooling models
step_tag_id: int | None = None
returned_token_ids: list[int] | None = None
## Internal use only
task: PoolingTask | None = None
"""Internal use only."""
requires_token_ids: bool = False
"""Internal use only."""
extra_kwargs: dict[str, Any] | None = None
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@property
def all_parameters(self) -> list[str]:
return [
"dimensions",
"normalize",
"activation",
"softmax",
"step_tag_id",
"returned_token_ids",
]
return ["dimensions", "normalize", "activation"]
@property
def valid_parameters(self):
......@@ -81,7 +67,8 @@ class PoolingParams(
"embed": ["dimensions", "normalize"],
"classify": ["activation"],
"score": ["activation"],
"encode": ["softmax", "step_tag_id", "returned_token_ids"],
"token_embed": ["dimensions", "normalize"],
"token_classify": ["activation"],
}
def clone(self) -> "PoolingParams":
......@@ -100,7 +87,6 @@ class PoolingParams(
# NOTE: Task validation needs to be done against the model instance,
# which is not available in model config. So, it's not included
# in this method
self._merge_default_parameters(model_config)
self._set_default_parameters(model_config)
self._verify_valid_parameters()
......@@ -125,8 +111,34 @@ class PoolingParams(
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))
self._verify_step_pooling(pooler_config, valid_parameters)
def _verify_step_pooling(
    self, pooler_config: "PoolerConfig", valid_parameters: list[str]
):
    """Check or default the step-pooling parameters against the pooler config.

    With ``pooling_type == "STEP"``, any of ``step_tag_id`` /
    ``returned_token_ids`` left unset on ``self`` is filled in from
    ``pooler_config`` when the config provides a value. With any other
    pooling type, setting either parameter is an error.

    Raises:
        ValueError: If a step-pooling parameter is set while the pooling type
            is not ``"STEP"``.
    """
    step_pooling_parameters = ("step_tag_id", "returned_token_ids")

    if pooler_config.pooling_type == "STEP":
        for name in step_pooling_parameters:
            config_value = getattr(pooler_config, name, None)
            # A value supplied on the request takes precedence over the
            # config default.
            if config_value is not None and getattr(self, name, None) is None:
                setattr(self, name, config_value)
        return

    invalid_parameters = [
        name
        for name in step_pooling_parameters
        if getattr(self, name, None) is not None
    ]
    if invalid_parameters:
        raise ValueError(
            f"Task {self.task} only supports {valid_parameters} "
            f"parameters, does not support "
            f"{invalid_parameters} parameters"
        )
def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
if self.task == "embed":
if self.task in ["embed", "token_embed"]:
if self.normalize is None:
self.normalize = True
......@@ -150,13 +162,9 @@ class PoolingParams(
elif self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
elif self.task in ["classify", "score"]:
elif self.task in ["classify", "score", "token_classify"]:
if self.activation is None:
self.activation = True
elif self.task == "encode":
if self.softmax is None:
self.softmax = True
else:
raise ValueError(f"Unknown pooling task: {self.task}")
......@@ -185,7 +193,6 @@ class PoolingParams(
f"normalize={self.normalize}, "
f"dimensions={self.dimensions}, "
f"activation={self.activation}, "
f"softmax={self.softmax}, "
f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids}, "
......
......@@ -5,7 +5,7 @@ from typing import Literal, get_args
GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)
PoolingTask = Literal["encode", "embed", "classify", "score"]
PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment