Unverified Commit 214efc2c authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Support Cross encoder models (#10400)


Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
parent 49628fe1
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
import torch.types import torch.types
from PIL.Image import Image from PIL.Image import Image
from typing_extensions import TypeAlias from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves from vllm.utils import JSONTree, is_list_of, json_map_leaves
...@@ -208,6 +208,9 @@ class MultiModalInputsV2(TypedDict): ...@@ -208,6 +208,9 @@ class MultiModalInputsV2(TypedDict):
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The processed token IDs which includes placeholder tokens.""" """The processed token IDs which includes placeholder tokens."""
token_type_ids: NotRequired[List[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargs mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
......
...@@ -60,7 +60,6 @@ class EmbeddingOutput: ...@@ -60,7 +60,6 @@ class EmbeddingOutput:
embedding: The embedding vector, which is a list of floats. The embedding: The embedding vector, which is a list of floats. The
length of vector depends on the model as listed in the embedding guide. length of vector depends on the model as listed in the embedding guide.
""" """
embedding: List[float] embedding: List[float]
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -363,6 +362,50 @@ class EmbeddingRequestOutput: ...@@ -363,6 +362,50 @@ class EmbeddingRequestOutput:
f"finished={self.finished})") f"finished={self.finished})")
@dataclass
class ScoreOutput:
    """The output data of one score output of a request.

    Args:
        index: The index of the text that this score corresponds to.
        score: The score, stored as a list of floats.
    """
    index: int
    score: List[float]

    def __repr__(self) -> str:
        # The closing parenthesis belongs after the last field only, so the
        # rendered text stays balanced: ScoreOutput(score=..., index=...).
        return (f"ScoreOutput("
                f"score={self.score}, "
                f"index={self.index})")
class ScoreRequestOutput:
    """
    The output data of a score request to the LLM.

    Args:
        request_id (str): A unique identifier for the score request.
        outputs (ScoreOutput): The score results for the given input.
    """

    def __init__(self, request_id: str, outputs: "ScoreOutput"):
        self.request_id = request_id
        self.outputs = outputs

    def __repr__(self):
        """
        Returns a string representation of a ScoreRequestOutput instance.

        The representation includes the request_id and the repr of the
        outputs, providing a quick overview of the score request's results.

        Returns:
            str: A string representation of the ScoreRequestOutput instance.
        """
        # The trailing ")" closes the "ScoreRequestOutput(" opened on the
        # first fragment so the rendered text is balanced.
        return (f"ScoreRequestOutput(request_id='{self.request_id}', "
                f"outputs={repr(self.outputs)})")
class RequestOutputFactory: class RequestOutputFactory:
@staticmethod @staticmethod
......
...@@ -449,6 +449,10 @@ class Sequence: ...@@ -449,6 +449,10 @@ class Sequence:
def prompt_embeds(self) -> Optional[torch.Tensor]: def prompt_embeds(self) -> Optional[torch.Tensor]:
return self.inputs.prompt_embeds return self.inputs.prompt_embeds
# Read-only accessor: returns the token type IDs stored on this sequence's
# inputs object (``self.inputs.token_type_ids``) without modification.
@property
def token_type_ids(self) -> List[int]:
return self.inputs.token_type_ids
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.multi_modal_data return self.inputs.multi_modal_data
...@@ -687,6 +691,10 @@ class SequenceGroup: ...@@ -687,6 +691,10 @@ class SequenceGroup:
return (self.encoder_seq.prompt_token_ids return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
# Read-only accessor: delegates to the ``token_type_ids`` property of
# ``self.first_seq`` and returns its value unchanged.
@property
def token_type_ids(self) -> Optional[List[int]]:
return self.first_seq.token_type_ids
@property @property
def multi_modal_data(self) -> MultiModalDataDict: def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data return self.first_seq.multi_modal_data
...@@ -909,6 +917,7 @@ class SequenceGroupMetadata( ...@@ -909,6 +917,7 @@ class SequenceGroupMetadata(
default_factory=lambda: SequenceGroupState()) default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec # "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts. # doesn't allow to have union of 2 different dicts.
token_type_ids: Optional[List[int]] = None
multi_modal_data: Optional[Any] = None multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
......
...@@ -9,6 +9,7 @@ from huggingface_hub import (file_exists, hf_hub_download, ...@@ -9,6 +9,7 @@ from huggingface_hub import (file_exists, hf_hub_download,
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError) RevisionNotFoundError)
from torch import nn
from transformers import GenerationConfig, PretrainedConfig from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import ( from transformers.models.auto.image_processing_auto import (
get_image_processor_config) get_image_processor_config)
...@@ -31,6 +32,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, ...@@ -31,6 +32,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
UltravoxConfig) UltravoxConfig)
# yapf: enable # yapf: enable
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
if VLLM_USE_MODELSCOPE: if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig from modelscope import AutoConfig
...@@ -577,3 +579,16 @@ def try_get_generation_config( ...@@ -577,3 +579,16 @@ def try_get_generation_config(
return GenerationConfig.from_model_config(config) return GenerationConfig.from_model_config(config)
except OSError: # Not found except OSError: # Not found
return None return None
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation module to apply to cross-encoder outputs.

    If ``config.sbert_ce_default_activation_function`` exists and is not
    ``None``, it is treated as a fully qualified name, resolved with
    ``resolve_obj_by_qualname`` and instantiated with no arguments.
    Otherwise ``nn.Sigmoid()`` is returned when ``config.num_labels == 1``
    and ``nn.Identity()`` in every other case.

    Args:
        config: The model configuration to inspect.

    Returns:
        An instance of the selected activation.

    Raises:
        ValueError: If the configured name does not start with
            ``"torch.nn.modules."``.
    """
    function_name = getattr(config, "sbert_ce_default_activation_function",
                            None)
    if function_name is None:
        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()

    # The name comes from the model's config, so the restriction must be
    # enforced with an explicit raise: an ``assert`` is removed when Python
    # runs with -O, which would silently disable this security check.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons")
    return resolve_obj_by_qualname(function_name)()
...@@ -50,6 +50,9 @@ class CPUEmbeddingModelRunner( ...@@ -50,6 +50,9 @@ class CPUEmbeddingModelRunner(
] ]
model_executable = self.model model_executable = self.model
cross_enc_kwargs = {}
if model_input.token_type_ids is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": "input_ids":
model_input.input_tokens, model_input.input_tokens,
...@@ -61,6 +64,7 @@ class CPUEmbeddingModelRunner( ...@@ -61,6 +64,7 @@ class CPUEmbeddingModelRunner(
model_input.attn_metadata, model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
**cross_enc_kwargs,
"intermediate_tensors": "intermediate_tensors":
intermediate_tensors, intermediate_tensors,
} }
......
...@@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase): ...@@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
""" """
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
token_type_ids: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None virtual_engine: Optional[int] = None
...@@ -54,6 +55,7 @@ class ModelInputForCPU(ModelRunnerInputBase): ...@@ -54,6 +55,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
...@@ -83,6 +85,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): ...@@ -83,6 +85,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
...@@ -112,6 +115,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -112,6 +115,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.input_tokens: List[int] = [] self.input_tokens: List[int] = []
self.input_positions: Optional[ self.input_positions: Optional[
List[int]] = [] if not self.use_mrope else None List[int]] = [] if not self.use_mrope else None
self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = [] self.seq_lens: List[int] = []
self.query_lens: List[int] = [] self.query_lens: List[int] = []
self.prefill_block_tables: List[List[int]] = [] self.prefill_block_tables: List[List[int]] = []
...@@ -165,6 +169,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -165,6 +169,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
if not input_data.use_mrope else input_data.input_mrope_positions, if not input_data.use_mrope else input_data.input_mrope_positions,
dtype=torch.long, dtype=torch.long,
device="cpu") device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids,
dtype=torch.long,
device="cpu") \
if input_data.token_type_ids else None
# For multi-modal models # For multi-modal models
multi_modal_kwargs = None multi_modal_kwargs = None
...@@ -178,6 +186,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -178,6 +186,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens, input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
token_type_ids=token_type_ids,
seq_lens=input_data.seq_lens, seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens, query_lens=input_data.query_lens,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -285,6 +294,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -285,6 +294,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
tokens = seq_data.get_token_ids() tokens = seq_data.get_token_ids()
tokens = tokens[context_len:seq_len] tokens = tokens[context_len:seq_len]
token_positions = range(context_len, seq_len) token_positions = range(context_len, seq_len)
token_types = seq_group_metadata.token_type_ids
# For encoder-only models, the block_table is None, # For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping. # and there is no need to initialize the slot_mapping.
...@@ -301,6 +311,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -301,6 +311,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
if data.input_positions is not None: if data.input_positions is not None:
data.input_positions.extend(token_positions) data.input_positions.extend(token_positions)
if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else [])
# Update fields # Update fields
data.input_tokens.extend(tokens) data.input_tokens.extend(tokens)
data.num_prefills += 1 data.num_prefills += 1
......
...@@ -97,6 +97,10 @@ class EmbeddingModelRunner( ...@@ -97,6 +97,10 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record() model_forward_start.record()
cross_enc_kwargs = {}
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types
with set_forward_context(model_input.attn_metadata, self.vllm_config): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
...@@ -105,7 +109,8 @@ class EmbeddingModelRunner( ...@@ -105,7 +109,8 @@ class EmbeddingModelRunner(
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device)) device=self.device),
**cross_enc_kwargs)
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
......
...@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
""" """
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None lora_mapping: Optional["LoRAMapping"] = None
...@@ -200,6 +201,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -200,6 +201,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self): def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore self.input_tokens[0].clear() # type: ignore
self.input_positions[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore
...@@ -226,6 +228,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -226,6 +228,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions. # Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None, input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window). # The sequence length (may be capped to the sliding window).
...@@ -291,6 +294,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -291,6 +294,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear() self.input_positions[seq_id].clear()
if token_types:
self.token_types = token_types
else:
for seq_id in range(len(self.seq_ids)):
self.token_types[seq_id].clear()
self.mrope_input_positions = None self.mrope_input_positions = None
if seq_lens: if seq_lens:
...@@ -354,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -354,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.input_positions = input_positions or [] self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or [] self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or [] self.orig_seq_lens = orig_seq_lens or []
...@@ -386,6 +396,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -386,6 +396,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)]
self.token_types = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
...@@ -498,12 +509,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -498,12 +509,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute tokens. # Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len] tokens = seq_data.get_token_ids()[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
inter_data.query_lens[seq_idx] = seq_len - context_len inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None: if seq_data.mrope_position_delta is not None:
...@@ -561,6 +575,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -561,6 +575,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][uncomputed_start:] seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:] seq_idx][uncomputed_start:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
uncomputed_start:]
context_len = prefix_cache_len context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
...@@ -575,6 +591,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -575,6 +591,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][-1:] seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:] seq_idx][-1:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
-1:]
inter_data.query_lens[seq_idx] = 1 inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
...@@ -803,9 +821,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -803,9 +821,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = [] input_tokens = []
token_types = []
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens) input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if not input_tokens: if not input_tokens:
# This may happen when all prefill requests hit # This may happen when all prefill requests hit
...@@ -874,6 +895,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -874,6 +895,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device,
self.runner.pin_memory) \
if token_types else None
if mrope_input_positions is not None: if mrope_input_positions is not None:
for idx in range(3): for idx in range(3):
mrope_input_positions[idx].extend( mrope_input_positions[idx].extend(
...@@ -952,6 +979,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -952,6 +979,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
seq_lens=seq_lens, seq_lens=seq_lens,
query_lens=query_lens, query_lens=query_lens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment