Unverified commit c18b3b8e, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Add `use_cross_encoder` flag to use correct activation in `ClassifierPooler` (#20527)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 9528e3a0
...@@ -1204,7 +1204,7 @@ class LLM: ...@@ -1204,7 +1204,7 @@ class LLM:
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
pooling_params = PoolingParams() pooling_params = PoolingParams(use_cross_encoder=True)
tokenization_kwargs: dict[str, Any] = {} tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len, _validate_truncation_size(self.llm_engine.model_config.max_model_len,
......
...@@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel): ...@@ -1156,8 +1156,9 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params] # --8<-- [end:score-extra-params]
def to_pooling_params(self): def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(use_cross_encoder=use_cross_encoder,
additional_data=self.additional_data)
class RerankRequest(OpenAIBaseModel): class RerankRequest(OpenAIBaseModel):
...@@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel): ...@@ -1182,8 +1183,9 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params] # --8<-- [end:rerank-extra-params]
def to_pooling_params(self): def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(additional_data=self.additional_data) return PoolingParams(use_cross_encoder=use_cross_encoder,
additional_data=self.additional_data)
class RerankDocument(BaseModel): class RerankDocument(BaseModel):
......
...@@ -25,9 +25,7 @@ from vllm.logger import init_logger ...@@ -25,9 +25,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.utils import make_async, merge_async_iterators from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,7 +48,7 @@ class ServingScores(OpenAIServing): ...@@ -50,7 +48,7 @@ class ServingScores(OpenAIServing):
async def _embedding_score( async def _embedding_score(
self, self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: AnyTokenizer,
texts_1: list[str], texts_1: list[str],
texts_2: list[str], texts_2: list[str],
request: Union[RerankRequest, ScoreRequest], request: Union[RerankRequest, ScoreRequest],
...@@ -141,7 +139,7 @@ class ServingScores(OpenAIServing): ...@@ -141,7 +139,7 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score( async def _cross_encoding_score(
self, self,
tokenizer: Union[AnyTokenizer], tokenizer: AnyTokenizer,
texts_1: list[str], texts_1: list[str],
texts_2: list[str], texts_2: list[str],
request: Union[RerankRequest, ScoreRequest], request: Union[RerankRequest, ScoreRequest],
...@@ -190,7 +188,7 @@ class ServingScores(OpenAIServing): ...@@ -190,7 +188,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params(use_cross_encoder=True)
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.pooling_metadata import ( # noqa: E501 ...@@ -15,6 +15,7 @@ from vllm.model_executor.pooling_metadata import ( # noqa: E501
from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
get_classification_activation_function,
get_cross_encoder_activation_function) get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
...@@ -388,15 +389,14 @@ class ClassifierPooler(nn.Module): ...@@ -388,15 +389,14 @@ class ClassifierPooler(nn.Module):
self.classifier = classifier self.classifier = classifier
self.pooler = pooler self.pooler = pooler
if config.task == "score": self.classification_act_fn = get_classification_activation_function(
self.default_activation_function = \ config.hf_config)
get_cross_encoder_activation_function(config.hf_config) self.cross_encoder_act_fn = get_cross_encoder_activation_function(
elif config.task == "classify": config.hf_config)
self.default_activation_function = nn.Sigmoid() \
if config.hf_config.num_labels == 1 else nn.Softmax() def _get_act_fn(self, use_cross_encoder: bool):
else: return (self.cross_encoder_act_fn
raise NotImplementedError(f"task={config.task!r} is not supported" if use_cross_encoder else self.classification_act_fn)
" with the classification pooler")
def get_prompt_lens( def get_prompt_lens(
self, self,
...@@ -446,8 +446,28 @@ class ClassifierPooler(nn.Module): ...@@ -446,8 +446,28 @@ class ClassifierPooler(nn.Module):
# apply classifier once on the full batch if possible # apply classifier once on the full batch if possible
pooled_output = self.classifier(pooled_output) pooled_output = self.classifier(pooled_output)
# shape: (batch_size, num_labels) if isinstance(pooling_metadata, V0PoolingMetadata):
scores = self.default_activation_function(pooled_output) use_cross_encoder_list = [
pooling_param.use_cross_encoder
for _, pooling_param in pooling_metadata.seq_groups
]
else:
use_cross_encoder_list = [
pooling_param.use_cross_encoder
for pooling_param in pooling_metadata.pooling_params
]
# shape of scores: (batch_size, num_labels)
if all(use_cross_encoder == use_cross_encoder_list[0]
for use_cross_encoder in use_cross_encoder_list):
act_fn = self._get_act_fn(use_cross_encoder_list[0])
scores = act_fn(pooled_output)
else:
scores = torch.stack([
self._get_act_fn(use_cross_encoder)(vecs)
for use_cross_encoder, vecs in zip(use_cross_encoder_list,
pooled_output)
])
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
return PoolerOutput(outputs=pooled_outputs) return PoolerOutput(outputs=pooled_outputs)
...@@ -25,8 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -25,8 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix from .utils import WeightsMapper, maybe_prefix
...@@ -462,9 +460,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, ...@@ -462,9 +460,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.default_activation_function = \
get_cross_encoder_activation_function(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.bert = BertModel(vllm_config=vllm_config, self.bert = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"), prefix=maybe_prefix(prefix, "bert"),
......
...@@ -18,8 +18,6 @@ from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel ...@@ -18,8 +18,6 @@ from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from .bert_with_rope import BertWithRope, JinaRobertaModel from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only from .interfaces import SupportsCrossEncoding, SupportsV0Only
...@@ -178,9 +176,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -178,9 +176,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.default_activation_function = \
get_cross_encoder_activation_function(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config, self.roberta = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "bert"), prefix=maybe_prefix(prefix, "bert"),
......
...@@ -24,12 +24,14 @@ class PoolingParams( ...@@ -24,12 +24,14 @@ class PoolingParams(
""" """
dimensions: Optional[int] = None dimensions: Optional[int] = None
use_cross_encoder: bool = False
additional_data: Optional[Any] = None additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
def clone(self) -> "PoolingParams": def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance.""" """Returns a deep copy of the PoolingParams instance."""
return PoolingParams(dimensions=self.dimensions, return PoolingParams(dimensions=self.dimensions,
use_cross_encoder=self.use_cross_encoder,
additional_data=self.additional_data) additional_data=self.additional_data)
def verify(self, model_config: "ModelConfig") -> None: def verify(self, model_config: "ModelConfig") -> None:
...@@ -54,6 +56,7 @@ class PoolingParams( ...@@ -54,6 +56,7 @@ class PoolingParams(
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"PoolingParams(" return (f"PoolingParams("
f"dimensions={self.dimensions}, " f"dimensions={self.dimensions}, "
f"use_cross_encoder={self.use_cross_encoder}, "
f"additional_metadata={self.additional_data})") f"additional_metadata={self.additional_data})")
def __post_init__(self) -> None: def __post_init__(self) -> None:
......
...@@ -866,24 +866,26 @@ def try_get_generation_config( ...@@ -866,24 +866,26 @@ def try_get_generation_config(
return None return None
def get_cross_encoder_activation_function(config: PretrainedConfig): def get_classification_activation_function(config: PretrainedConfig):
return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax()
def get_cross_encoder_activation_function(config: PretrainedConfig):
function_name: Optional[str] = None function_name: Optional[str] = None
if hasattr(config, "sentence_transformers") and "activation_fn" in \ if (hasattr(config, "sentence_transformers")
config.sentence_transformers: and "activation_fn" in config.sentence_transformers):
function_name = config.sentence_transformers["activation_fn"] function_name = config.sentence_transformers["activation_fn"]
elif (hasattr(config, "sbert_ce_default_activation_function") elif (hasattr(config, "sbert_ce_default_activation_function")
and config.sbert_ce_default_activation_function is not None): and config.sbert_ce_default_activation_function is not None):
function_name = config.sbert_ce_default_activation_function function_name = config.sbert_ce_default_activation_function
if function_name is not None: if function_name is not None:
assert function_name.startswith("torch.nn.modules."), \ assert function_name.startswith("torch.nn.modules."), (
"Loading of activation functions is restricted to " \ "Loading of activation functions is restricted to "
"torch.nn.modules for security reasons" "torch.nn.modules for security reasons")
return resolve_obj_by_qualname(function_name)() return resolve_obj_by_qualname(function_name)()
else:
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
def try_get_safetensors_metadata( def try_get_safetensors_metadata(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment