Unverified commit 214efc2c, authored by Maximilien de Bayser, committed by GitHub
Browse files

Support Cross encoder models (#10400)


Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
parent 49628fe1
......@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.types
from PIL.Image import Image
from typing_extensions import TypeAlias
from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves
......@@ -208,6 +208,9 @@ class MultiModalInputsV2(TypedDict):
prompt_token_ids: List[int]
"""The processed token IDs which includes placeholder tokens."""
token_type_ids: NotRequired[List[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
......
......@@ -60,7 +60,6 @@ class EmbeddingOutput:
embedding: The embedding vector, which is a list of floats. The
length of vector depends on the model as listed in the embedding guide.
"""
embedding: List[float]
def __repr__(self) -> str:
......@@ -363,6 +362,50 @@ class EmbeddingRequestOutput:
f"finished={self.finished})")
@dataclass
class ScoreOutput:
    """The output data of one score output of a request.

    Args:
        index: The index of the text that the score corresponds to.
        score: The score, which is a list of floats.
    """
    index: int
    score: List[float]

    def __repr__(self) -> str:
        # Keep both fields inside a single pair of parentheses so the
        # representation is balanced and easy to read in logs.
        return (f"ScoreOutput("
                f"score={self.score}, "
                f"index={self.index})")
class ScoreRequestOutput:
    """
    The output data of a score request to the LLM.

    Args:
        request_id (str): A unique identifier for the score request.
        outputs (ScoreOutput): The score results for the given input.
    """

    def __init__(self, request_id: str, outputs: "ScoreOutput"):
        self.request_id = request_id
        self.outputs = outputs

    def __repr__(self):
        """
        Returns a string representation of a ScoreRequestOutput instance.

        The representation includes the request_id and the repr of the
        outputs, providing a quick overview of the score request's results.

        Returns:
            str: A string representation of the ScoreRequestOutput instance.
        """
        # The trailing ")" closes the "ScoreRequestOutput(" opened on the
        # first fragment, keeping the representation balanced.
        return (f"ScoreRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)})")
class RequestOutputFactory:
@staticmethod
......
......@@ -449,6 +449,10 @@ class Sequence:
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self.inputs.prompt_embeds
@property
def token_type_ids(self) -> List[int]:
    """Token type IDs, read from this object's ``inputs``."""
    seq_inputs = self.inputs
    return seq_inputs.token_type_ids
@property
def multi_modal_data(self) -> "MultiModalDataDict":
    """Multi-modal data, read from this object's ``inputs``."""
    seq_inputs = self.inputs
    return seq_inputs.multi_modal_data
......@@ -687,6 +691,10 @@ class SequenceGroup:
return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None)
@property
def token_type_ids(self) -> Optional[List[int]]:
    """Token type IDs, taken from ``first_seq``."""
    leading_seq = self.first_seq
    return leading_seq.token_type_ids
@property
def multi_modal_data(self) -> MultiModalDataDict:
    """Multi-modal data, taken from ``first_seq``."""
    leading_seq = self.first_seq
    return leading_seq.multi_modal_data
......@@ -909,6 +917,7 @@ class SequenceGroupMetadata(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
token_type_ids: Optional[List[int]] = None
multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
......
......@@ -9,6 +9,7 @@ from huggingface_hub import (file_exists, hf_hub_download,
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError)
from torch import nn
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
......@@ -31,6 +32,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
......@@ -577,3 +579,16 @@ def try_get_generation_config(
return GenerationConfig.from_model_config(config)
except OSError: # Not found
return None
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation module to apply to cross-encoder outputs.

    If the config carries a non-None
    ``sbert_ce_default_activation_function``, that qualified name is
    resolved and instantiated with no arguments. Otherwise the fallback
    is ``nn.Sigmoid()`` when ``config.num_labels == 1`` and
    ``nn.Identity()`` for any other label count.

    Args:
        config: The model configuration to inspect.

    Returns:
        An instance of the selected activation.

    Raises:
        ValueError: If the configured function name does not start with
            ``torch.nn.modules.``.
    """
    function_name = getattr(config, "sbert_ce_default_activation_function",
                            None)
    if function_name is None:
        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()

    # The name comes from a model config file, so the allow-list check has
    # to be a real raise: an ``assert`` is removed when Python runs with
    # ``-O`` and would leave arbitrary objects loadable.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")
    return resolve_obj_by_qualname(function_name)()
......@@ -50,6 +50,9 @@ class CPUEmbeddingModelRunner(
]
model_executable = self.model
cross_enc_kwargs = {}
if model_input.token_type_ids is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
......@@ -61,6 +64,7 @@ class CPUEmbeddingModelRunner(
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,
}
......
......@@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_type_ids: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
......@@ -54,6 +55,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
......@@ -83,6 +85,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
......@@ -112,6 +115,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.input_tokens: List[int] = []
self.input_positions: Optional[
List[int]] = [] if not self.use_mrope else None
self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.prefill_block_tables: List[List[int]] = []
......@@ -165,6 +169,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
if not input_data.use_mrope else input_data.input_mrope_positions,
dtype=torch.long,
device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids,
dtype=torch.long,
device="cpu") \
if input_data.token_type_ids else None
# For multi-modal models
multi_modal_kwargs = None
......@@ -178,6 +186,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
token_type_ids=token_type_ids,
seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens,
attn_metadata=attn_metadata,
......@@ -285,6 +294,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
tokens = seq_data.get_token_ids()
tokens = tokens[context_len:seq_len]
token_positions = range(context_len, seq_len)
token_types = seq_group_metadata.token_type_ids
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
......@@ -301,6 +311,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
if data.input_positions is not None:
data.input_positions.extend(token_positions)
if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else [])
# Update fields
data.input_tokens.extend(tokens)
data.num_prefills += 1
......
......@@ -97,6 +97,10 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
cross_enc_kwargs = {}
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
......@@ -105,7 +109,8 @@ class EmbeddingModelRunner(
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device))
device=self.device),
**cross_enc_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
......
......@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
......@@ -200,6 +201,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
......@@ -226,6 +228,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window).
......@@ -291,6 +294,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()
if token_types:
self.token_types = token_types
else:
for seq_id in range(len(self.seq_ids)):
self.token_types[seq_id].clear()
self.mrope_input_positions = None
if seq_lens:
......@@ -354,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
......@@ -386,6 +396,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
self.token_types = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
......@@ -498,12 +509,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None:
......@@ -561,6 +575,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
uncomputed_start:]
context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len
......@@ -575,6 +591,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
......@@ -803,9 +821,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""
# Combine and flatten intermediate data.
input_tokens = []
token_types = []
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if not input_tokens:
# This may happen when all prefill requests hit
......@@ -874,6 +895,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
self.runner.pin_memory)
token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device,
self.runner.pin_memory) \
if token_types else None
if mrope_input_positions is not None:
for idx in range(3):
mrope_input_positions[idx].extend(
......@@ -952,6 +979,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment