Commit 317a82e2 authored by chenych

Add QWQ-32B

parent 37b0ad9f
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -86,19 +86,25 @@ class StringFormatter(Formatter):
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
-                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

        return elements


@dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
    def __post_init__(self):
+        super().__post_init__()
        self.tool_utils = get_tool_utils(self.tool_format)

    @override
    def apply(self, **kwargs) -> SLOTS:
-        content = kwargs.pop("content")
+        content: str = kwargs.pop("content")
+        regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
+        thought = re.search(regex, content)
+        if thought:
+            content = content.replace(thought.group(0), "")

        functions: List["FunctionCall"] = []
        try:
            tool_calls = json.loads(content)

@@ -111,16 +117,13 @@ class FunctionFormatter(Formatter):
            )
        except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.")  # flat string

-        elements = []
-        for slot in self.slots:
-            if slot == "{{content}}":
-                elements += self.tool_utils.function_formatter(functions)
-            else:
-                elements.append(slot)
-
-        return elements
+        function_str = self.tool_utils.function_formatter(functions)
+        if thought:
+            function_str = thought.group(1) + function_str
+
+        return super().apply(content=function_str)


@dataclass
@@ -135,7 +138,7 @@ class ToolFormatter(Formatter):
            tools = json.loads(content)
            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
-            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")  # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

    @override
    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
...
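
A minimal standalone sketch (not part of the commit) of the behavior the FunctionFormatter change above introduces: an optional <think>...</think> reasoning block is stripped before the tool-call JSON is parsed, and the captured thought can be re-attached afterwards. The helper name below is illustrative, not the project's API.

import json
import re


def split_thought_and_calls(content: str):
    # Illustrative helper: separate an optional <think>...</think> block from the tool-call JSON.
    regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
    thought = re.search(regex, content)
    if thought:
        content = content.replace(thought.group(0), "")

    tool_calls = json.loads(content)  # raises json.JSONDecodeError on malformed input
    return (thought.group(1) if thought else ""), tool_calls


thought, calls = split_thought_and_calls(
    '<think>check the weather first</think>[{"name": "get_weather", "arguments": {"city": "Paris"}}]'
)
print(thought)           # check the weather first
print(calls[0]["name"])  # get_weather
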
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
-from .aligner import align_dataset
+from .converter import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
-from .preprocess import get_preprocess_and_print_func
+from .processor import (
+    FeedbackDatasetProcessor,
+    PackedSupervisedDatasetProcessor,
+    PairwiseDatasetProcessor,
+    PretrainDatasetProcessor,
+    SupervisedDatasetProcessor,
+    UnsupervisedDatasetProcessor,
+)


if TYPE_CHECKING:
@@ -35,6 +42,7 @@ if TYPE_CHECKING:
    from ..hparams import DataArguments, ModelArguments
    from .data_utils import DatasetModule
    from .parser import DatasetAttr
+    from .processor import DatasetProcessor
    from .template import Template
@@ -156,21 +164,67 @@ def _get_merged_dataset(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+    merge: bool = True,
+) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
    r"""
-    Gets the merged datasets in the standard format.
+    Returns the merged datasets in the standard format.
    """
    if dataset_names is None:
        return None

-    datasets = []
-    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
+    datasets = {}
+    for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
            raise ValueError("The dataset is not applicable in the current training stage.")

-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
-
-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+        datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
+
+    if merge:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+    else:
+        return datasets
+
+
+def _get_dataset_processor(
+    data_args: "DataArguments",
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
+) -> "DatasetProcessor":
+    r"""
+    Returns the corresponding dataset processor.
+    """
+    if stage == "pt":
+        dataset_processor_class = PretrainDatasetProcessor
+    elif stage == "sft" and not do_generate:
+        if data_args.packing:
+            if data_args.neat_packing:  # hack datasets to have int32 attention mask
+                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
+
+                def __init__(self, data, **kwargs):
+                    return TypedSequence.__init__(
+                        self,
+                        data,
+                        type=kwargs.pop("type", None),
+                        try_type=kwargs.pop("try_type", None),
+                        optimized_int_type=kwargs.pop("optimized_int_type", None),
+                    )
+
+                OptimizedTypedSequence.__init__ = __init__
+
+            dataset_processor_class = PackedSupervisedDatasetProcessor
+        else:
+            dataset_processor_class = SupervisedDatasetProcessor
+    elif stage == "rm":
+        dataset_processor_class = PairwiseDatasetProcessor
+    elif stage == "kto":
+        dataset_processor_class = FeedbackDatasetProcessor
+    else:
+        dataset_processor_class = UnsupervisedDatasetProcessor
+
+    return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
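
For context, a hedged reading of the neat_packing branch above: datasets' OptimizedTypedSequence normally downcasts well-known columns such as attention_mask to a smaller integer type when writing Arrow data, which would clobber the packed-sequence indices the packing processor stores in that column; the patch routes construction through the plain TypedSequence.__init__ so no per-column downcast is applied. A minimal sketch of the same pattern, with explanatory comments added:

from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence


def _plain_init(self, data, **kwargs):
    # Bypass OptimizedTypedSequence's per-column integer optimization so that
    # columns such as attention_mask keep the int32 values produced upstream.
    TypedSequence.__init__(
        self,
        data,
        type=kwargs.pop("type", None),
        try_type=kwargs.pop("try_type", None),
        optimized_int_type=kwargs.pop("optimized_int_type", None),
    )


OptimizedTypedSequence.__init__ = _plain_init
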
def _get_preprocessed_dataset(
@@ -189,7 +243,7 @@ def _get_preprocessed_dataset(
    if dataset is None:
        return None

-    preprocess_func, print_function = get_preprocess_and_print_func(
+    dataset_processor = _get_dataset_processor(
        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
    )
    column_names = list(next(iter(dataset)).keys())
@@ -202,7 +256,7 @@ )
    )

    dataset = dataset.map(
-        preprocess_func,
+        dataset_processor.preprocess_dataset,
        batched=True,
        batch_size=data_args.preprocessing_batch_size,
        remove_columns=column_names,
@@ -212,7 +266,7 @@
    if training_args.should_log:
        try:
            print("eval example:" if is_eval else "training example:")
-            print_function(next(iter(dataset)))
+            dataset_processor.print_data_example(next(iter(dataset)))
        except StopIteration:
            if stage == "pt":
                raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
@@ -234,7 +288,7 @@ def get_dataset(
    r"""
    Gets the train dataset and optionally gets the evaluation dataset.
    """
-    # Load tokenized dataset
+    # Load tokenized dataset if path exists
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
@@ -249,7 +303,7 @@ def get_dataset(
            if "validation" in tokenized_data:
                dataset_module["eval_dataset"] = tokenized_data["validation"]

-            else:  # Dataset
+            else:  # single dataset
                dataset_module["train_dataset"] = tokenized_data

            if data_args.streaming:
@@ -263,15 +317,23 @@ def get_dataset(
    # Load and preprocess dataset
    with training_args.main_process_first(desc="load dataset"):
        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+        )

    with training_args.main_process_first(desc="pre-process dataset"):
        dataset = _get_preprocessed_dataset(
            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
        )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )

        if data_args.val_size > 1e-6:
            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
@@ -284,17 +346,20 @@ def get_dataset(
            dataset_dict["train"] = dataset

        if eval_dataset is not None:
-            if data_args.streaming:
-                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+            if isinstance(eval_dataset, dict):
+                dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+            else:
+                if data_args.streaming:
+                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)

                dataset_dict["validation"] = eval_dataset

        dataset_dict = DatasetDict(dataset_dict)

-    if data_args.tokenized_path is not None:
+    if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
        if training_args.should_save:
            dataset_dict.save_to_disk(data_args.tokenized_path)
-            logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+            logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
            logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

        sys.exit(0)
@@ -305,5 +370,13 @@ def get_dataset(
    if "validation" in dataset_dict:
        dataset_module["eval_dataset"] = dataset_dict["validation"]
+    else:
+        eval_dataset = {}
+        for key in dataset_dict.keys():
+            if key.startswith("validation_"):
+                eval_dataset[key[len("validation_") :]] = dataset_dict[key]
+
+        if len(eval_dataset):
+            dataset_module["eval_dataset"] = eval_dataset

    return dataset_module
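
A small runnable sketch (dataset names are made up) of the naming convention the new eval handling relies on: each evaluation set is stored as a validation_<name> split and recovered into a dict of eval datasets afterwards.

from datasets import Dataset, DatasetDict

eval_sets = {
    "alpaca_demo": Dataset.from_dict({"text": ["a"]}),
    "identity_demo": Dataset.from_dict({"text": ["b"]}),
}
dataset_dict = DatasetDict({f"validation_{name}": data for name, data in eval_sets.items()})

# Recover the per-dataset eval dict from the saved splits.
recovered = {
    key[len("validation_") :]: dataset_dict[key] for key in dataset_dict.keys() if key.startswith("validation_")
}
print(sorted(recovered.keys()))  # ['alpaca_demo', 'identity_demo']
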
+import inspect
import math
import re
from copy import deepcopy
+from dataclasses import dataclass
from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union

import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override

-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.packages import (
+    is_librosa_available,
+    is_pillow_available,
+    is_pyav_available,
+    is_transformers_version_greater_than,
+)
+
+
+if is_librosa_available():
+    import librosa


if is_pillow_available():
@@ -31,7 +42,9 @@ if is_transformers_version_greater_than("4.45.0"):

if TYPE_CHECKING:
    from av.stream import Stream
+    from numpy.typing import NDArray
    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
    from transformers.image_processing_utils import BaseImageProcessor

    class EncodedImage(TypedDict):
@@ -40,6 +53,7 @@ if TYPE_CHECKING:

    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
    VideoInput = str
+    AudioInput = Union[str, NDArray]


def _get_paligemma_token_type_ids(
@@ -59,20 +73,25 @@ def _get_paligemma_token_type_ids(
    return batch_token_type_ids
-class BasePlugin:
-    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
-        self.image_token = image_token
-        self.video_token = video_token
-        self.expand_mm_tokens = True
+@dataclass
+class MMPluginMixin:
+    image_token: Optional[str]
+    video_token: Optional[str]
+    audio_token: Optional[str]
+    expand_mm_tokens: bool = True

    def _validate_input(
        self,
+        processor: Optional["ProcessorMixin"],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
    ) -> None:
        r"""
        Validates if this model accepts the input modalities.
        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
        if len(images) != 0 and self.image_token is None:
            raise ValueError(
                "This model does not support image input. Please check whether the correct `template` is used."
@@ -83,31 +102,54 @@ class BasePlugin:
                "This model does not support video input. Please check whether the correct `template` is used."
            )

+        if len(audios) != 0 and self.audio_token is None:
+            raise ValueError(
+                "This model does not support audio input. Please check whether the correct `template` is used."
+            )
+
+        if self.image_token is not None and processor is None:
+            raise ValueError("Processor was not found, please check and update your processor config.")
+
+        if self.image_token is not None and image_processor is None:
+            raise ValueError("Image processor was not found, please check and update your processor config.")
+
+        if self.audio_token is not None and feature_extractor is None:
+            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+
-    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+    def _preprocess_image(
+        self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
+    ) -> "ImageObject":
        r"""
        Pre-processes a single image.
        """
-        image_resolution: int = kwargs.get("image_resolution")
-        if (image.width * image.height) > image_resolution:
-            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
+        if (image.width * image.height) > image_max_pixels:
+            resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
-            image = image.resize((width, height), resample=Image.NEAREST)
+            image = image.resize((width, height))
+
+        if (image.width * image.height) < image_min_pixels:
+            resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height))

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image
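
A worked standalone example (the pixel bounds are assumed, not the project's mandatory defaults) of the sqrt scaling used in _preprocess_image: multiplying both sides by sqrt(max_pixels / current_pixels) caps the total pixel count while keeping the aspect ratio.

from PIL import Image

image = Image.new("RGB", (4000, 3000))  # 12,000,000 pixels
image_max_pixels = 768 * 768            # 589,824-pixel budget

if image.width * image.height > image_max_pixels:
    resize_factor = (image_max_pixels / (image.width * image.height)) ** 0.5
    image = image.resize((int(image.width * resize_factor), int(image.height * resize_factor)))

print(image.size)  # (886, 665), just under the pixel budget and still roughly 4:3
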
-    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+    def _get_video_sample_indices(
+        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
+    ) -> List[int]:
        r"""
-        Computes video sample frames according to fps.
+        Computes video sample indices according to fps.
        """
-        video_fps: float = kwargs.get("video_fps")
-        video_maxlen: int = kwargs.get("video_maxlen")
        total_frames = video_stream.frames
-        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        if total_frames == 0:  # infinite video
+            return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
+
+        sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
        sample_frames = min(total_frames, video_maxlen, sample_frames)
-        return math.floor(sample_frames)
+        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
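
A worked example with synthetic numbers for the new index-based sampling: the number of kept frames is min(total_frames, video_maxlen, floor(duration * fps)), and np.linspace spreads those indices evenly over the stream.

import math

import numpy as np

total_frames = 900   # e.g. a 30-second clip at 30 fps
duration = 30.0      # seconds
video_fps = 2.0      # sample two frames per second
video_maxlen = 128

sample_frames = min(total_frames, video_maxlen, math.floor(duration * video_fps))
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
print(sample_frames)       # 60
print(sample_indices[:5])  # [ 0 15 30 45 60]
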
    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
        r"""
@@ -126,7 +168,7 @@ class BasePlugin:
                image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
-                raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
+                raise ValueError(f"Expect input is a list of images, but got {type(image)}.")

            results.append(self._preprocess_image(image, **kwargs))
@@ -140,9 +182,7 @@ class BasePlugin:
        for video in videos:
            container = av.open(video, "r")
            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            total_frames = video_stream.frames
-            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
-            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
            frames: List["ImageObject"] = []
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
@@ -154,10 +194,27 @@ class BasePlugin:

        return results

+    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
+        r"""
+        Regularizes audios to avoid error. Including reading and resampling.
+        """
+        results = []
+        for audio in audios:
+            if isinstance(audio, str):
+                audio = librosa.load(audio, sr=sampling_rate)[0]
+
+            if not isinstance(audio, np.ndarray):
+                raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
+
+            results.append(audio)
+
+        return results
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
@@ -172,47 +229,65 @@ class BasePlugin:
        It holds num_patches == torch.prod(image_grid_thw)
        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
-        input_dict = {"images": None}  # default key
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
-                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
-            input_dict["images"] = images
+            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
-                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
-                video_maxlen=getattr(processor, "video_maxlen", 64),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
-            input_dict["videos"] = videos
-
-        mm_inputs = {}
-        if image_processor != video_processor:
-            if input_dict.get("images") is not None:
-                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
-            if input_dict.get("videos") is not None:
-                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
-        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
-            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
+            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
+                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
+            else:  # for llava_next_video
+                mm_inputs.update(video_processor(videos, return_tensors="pt"))
+
+        if len(audios) != 0:
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
+            mm_inputs.update(
+                feature_extractor(
+                    audios,
+                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+                    return_attention_mask=True,
+                    padding="max_length",
+                    return_tensors="pt",
+                )
+            )
+            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs
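
A hedged sketch of the audio branch above, using WhisperFeatureExtractor purely as a stand-in (the real extractor comes from the model's processor): the extractor returns its own attention_mask over audio frames, so the code renames it to feature_attention_mask to keep it from colliding with the text attention mask built later by the data collator.

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()  # stand-in extractor with default settings
audios = [np.zeros(16000, dtype=np.float32)]   # one second of silence at 16 kHz

mm_inputs = dict(
    feature_extractor(
        audios,
        sampling_rate=16000,
        return_attention_mask=True,
        padding="max_length",
        return_tensors="pt",
    )
)
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
print(sorted(mm_inputs.keys()))  # ['feature_attention_mask', 'input_features']
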
+@dataclass
+class BasePlugin(MMPluginMixin):
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Pre-processes input messages before tokenization for VLMs.
        """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        return messages

    def process_token_ids(
@@ -221,21 +296,24 @@ class BasePlugin:
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Pre-processes token ids after tokenization for VLMs.
        """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        return input_ids, labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
+        audlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
@@ -247,13 +325,15 @@ class BasePlugin:
            videos: a list of video inputs, shape (num_videos,)
            imglens: number of images in each sample, shape (batch_size,)
            vidlens: number of videos in each sample, shape (batch_size,)
+            audlens: number of audios in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
            processor: a processor for pre-processing images and videos
        """
-        self._validate_input(images, videos)
+        self._validate_input(processor, images, videos, audios)
        return {}
@dataclass
class LlavaPlugin(BasePlugin): class LlavaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -261,9 +341,10 @@ class LlavaPlugin(BasePlugin): ...@@ -261,9 +341,10 @@ class LlavaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1 image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages) messages = deepcopy(messages)
...@@ -285,15 +366,18 @@ class LlavaPlugin(BasePlugin): ...@@ -285,15 +366,18 @@ class LlavaPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, processor) return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaNextPlugin(BasePlugin): class LlavaNextPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -301,16 +385,15 @@ class LlavaNextPlugin(BasePlugin): ...@@ -301,16 +385,15 @@ class LlavaNextPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
if "pixel_values" in mm_inputs: if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages: for message in messages:
...@@ -319,7 +402,7 @@ class LlavaNextPlugin(BasePlugin): ...@@ -319,7 +402,7 @@ class LlavaNextPlugin(BasePlugin):
if self.expand_mm_tokens: if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes) orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default": if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1 image_seqlen -= 1
else: else:
image_seqlen = 1 image_seqlen = 1
...@@ -339,15 +422,18 @@ class LlavaNextPlugin(BasePlugin): ...@@ -339,15 +422,18 @@ class LlavaNextPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, processor) return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaNextVideoPlugin(BasePlugin): class LlavaNextVideoPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -355,14 +441,15 @@ class LlavaNextVideoPlugin(BasePlugin): ...@@ -355,14 +441,15 @@ class LlavaNextVideoPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs: if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"]) image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages: for message in messages:
content = message["content"] content = message["content"]
...@@ -370,7 +457,7 @@ class LlavaNextVideoPlugin(BasePlugin): ...@@ -370,7 +457,7 @@ class LlavaNextVideoPlugin(BasePlugin):
if self.expand_mm_tokens: if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes) orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default": if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1 image_seqlen -= 1
else: else:
image_seqlen = 1 image_seqlen = 1
...@@ -381,12 +468,15 @@ class LlavaNextVideoPlugin(BasePlugin): ...@@ -381,12 +468,15 @@ class LlavaNextVideoPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token) message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs: if "pixel_values_videos" in mm_inputs:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) if self.expand_mm_tokens:
height, width = get_image_size(pixel_values_video[0]) pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim height, width = get_image_size(pixel_values_video[0])
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = video_seqlen if self.expand_mm_tokens else 1 video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
else:
video_seqlen = 1
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
...@@ -408,15 +498,18 @@ class LlavaNextVideoPlugin(BasePlugin): ...@@ -408,15 +498,18 @@ class LlavaNextVideoPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, processor) return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class MiniCPMVPlugin(BasePlugin): class MiniCPMVPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -424,26 +517,27 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -424,26 +517,27 @@ class MiniCPMVPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
num_video_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {} mm_inputs = {}
audio_inputs = {}
if len(images) != 0 and len(videos) != 0: if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.") raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
if len(videos) != 0: if len(videos) != 0:
max_slice_nums = 2 max_slice_nums = 2
use_image_id = False use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, processor) mm_inputs = self._get_mm_inputs([], videos, [], processor)
else: else:
max_slice_nums = image_processor.max_slice_nums max_slice_nums = image_processor.max_slice_nums
use_image_id = image_processor.use_image_id use_image_id = image_processor.use_image_id
for message in messages: for i, message in enumerate(messages):
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
...@@ -454,15 +548,24 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -454,15 +548,24 @@ class MiniCPMVPlugin(BasePlugin):
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1 num_video_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)") while AUDIO_PLACEHOLDER in content:
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
"{{audio}}", "(<audio>./</audio>)"
)
if num_image_tokens > 0: if num_image_tokens > 0:
mm_inputs = self._get_mm_inputs(images, [], processor) mm_inputs = self._get_mm_inputs(images, [], [], processor)
if num_audio_tokens > 0:
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
if mm_inputs: if mm_inputs:
pattern = "(<image>./</image>)" pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"] image_sizes = mm_inputs["image_sizes"]
idx = 0
for index, message in enumerate(messages): for index, message in enumerate(messages):
text = message["content"] text = message["content"]
image_tags = re.findall(pattern, text) image_tags = re.findall(pattern, text)
...@@ -473,9 +576,26 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -473,9 +576,26 @@ class MiniCPMVPlugin(BasePlugin):
final_text final_text
+ text_chunks[i] + text_chunks[i]
+ image_processor.get_slice_image_placeholder( + image_processor.get_slice_image_placeholder(
image_sizes[0][i], i, max_slice_nums, use_image_id image_sizes[0][idx], idx, max_slice_nums, use_image_id
) )
) )
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if audio_inputs:
pattern = "(<audio>./</audio>)"
idx = 0
for index, message in enumerate(messages):
text = message["content"]
audio_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(audio_tags)):
audio_placeholder = audio_inputs["audio_phs"][0][idx]
final_text = final_text + text_chunks[i] + audio_placeholder
idx += 1
final_text += text_chunks[-1] final_text += text_chunks[-1]
messages[index]["content"] = final_text messages[index]["content"] = final_text
...@@ -486,6 +606,9 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -486,6 +606,9 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != num_video_tokens: if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.") raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
return messages return messages
@override @override
...@@ -493,15 +616,18 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -493,15 +616,18 @@ class MiniCPMVPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin", processor: "ProcessorMixin",
**kwargs, **kwargs,
) -> Dict[str, "torch.Tensor"]: ) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
mm_inputs = {} mm_inputs = {}
if len(images) != 0: if len(images) != 0:
images = self._regularize_images( images = self._regularize_images(
images, images,
image_resolution=getattr(processor, "image_resolution", 512 * 512), image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
) )
if "valid_image_nums_ls" in kwargs: if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"] valid_image_nums_ls = kwargs["valid_image_nums_ls"]
...@@ -521,13 +647,39 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -521,13 +647,39 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != 0: if len(videos) != 0:
videos = self._regularize_videos( videos = self._regularize_videos(
videos, videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128), image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0), video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64), video_maxlen=getattr(processor, "video_maxlen", 128),
) )
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs) mm_inputs.update(video_inputs)
if len(audios) != 0:
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
)
if "valid_audio_nums_ls" in kwargs:
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
audios_ls = []
idx = 0
for valid_audio_nums in valid_audio_nums_ls:
audios_ls.append(audios[idx : idx + valid_audio_nums])
idx += valid_audio_nums
else:
audios_ls = [audios]
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
audios_ls,
chunk_input=True,
sampling_rate=16000,
)
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs})
return mm_inputs return mm_inputs
@override @override
...@@ -535,15 +687,18 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -535,15 +687,18 @@ class MiniCPMVPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
# image bound
image_bounds_list = [] image_bounds_list = []
valid_image_nums_ls = [] valid_image_nums_ls = []
for input_ids in batch_ids: for i, input_ids in enumerate(batch_ids):
input_ids_ = torch.tensor(input_ids) input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id input_ids_ == processor.tokenizer.slice_start_id
...@@ -552,21 +707,51 @@ class MiniCPMVPlugin(BasePlugin): ...@@ -552,21 +707,51 @@ class MiniCPMVPlugin(BasePlugin):
image_start_tokens = torch.where(start_cond)[0] image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1 image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0] image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) valid_image_nums_ls.append(imglens[i])
valid_image_nums_ls.append(valid_image_nums)
image_bounds = torch.hstack( image_bounds = torch.hstack(
[ [
image_start_tokens[:valid_image_nums].unsqueeze(-1), image_start_tokens.unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1), image_end_tokens.unsqueeze(-1),
] ]
) )
image_bounds_list.append(image_bounds) image_bounds_list.append(image_bounds)
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls) mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
if "tgt_sizes" not in mm_inputs:
dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
mm_inputs.update({"image_bound": image_bounds_list}) mm_inputs.update({"image_bound": image_bounds_list})
if len(audios) > 0:
# audio bound
audio_bounds_ls = []
spk_bounds_ls = []
valid_audio_nums_ls = []
for input_ids, audiolen in zip(batch_ids, audlens):
input_ids_ = torch.tensor(input_ids)
audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
assert len(audio_start_idx) == len(audio_end_idx)
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
audio_bounds_ls.append(audio_bounds)
valid_audio_nums_ls.append(audiolen)
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
assert len(spk_start_idx) == len(spk_end_idx)
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
spk_bounds_ls.append(spk_bounds)
audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
mm_inputs.update(audio_inputs)
mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
return mm_inputs return mm_inputs
@dataclass
class MllamaPlugin(BasePlugin): class MllamaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -574,9 +759,10 @@ class MllamaPlugin(BasePlugin): ...@@ -574,9 +759,10 @@ class MllamaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
...@@ -594,8 +780,9 @@ class MllamaPlugin(BasePlugin): ...@@ -594,8 +780,9 @@ class MllamaPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin", processor: "ProcessorMixin",
**kwargs, imglens: List[int],
) -> Dict[str, "torch.Tensor"]: ) -> Dict[str, "torch.Tensor"]:
r""" r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]]. Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
...@@ -609,43 +796,56 @@ class MllamaPlugin(BasePlugin): ...@@ -609,43 +796,56 @@ class MllamaPlugin(BasePlugin):
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
imglens: List[int] = kwargs["imglens"] mm_inputs = {}
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512)) if len(images) > 0:
batch_images = [] images = self._regularize_images(
for image_length in imglens: images,
batch_images.append(images[:image_length]) image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
images = images[image_length:] image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
images = images[image_length:]
mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
return image_processor(batch_images, return_tensors="pt") return mm_inputs
@override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens) mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
num_tiles = mm_inputs.pop("num_tiles") if mm_inputs:
image_token_id = getattr(processor, "image_token_id") num_tiles = mm_inputs.pop("num_tiles")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles") image_token_id = getattr(processor, "image_token_id")
cross_attention_token_mask = [ max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids cross_attention_token_mask = [
] get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
mm_inputs["cross_attention_mask"] = torch.from_numpy( ]
convert_sparse_cross_attention_mask_to_dense( mm_inputs["cross_attention_mask"] = torch.from_numpy(
cross_attention_token_mask, convert_sparse_cross_attention_mask_to_dense(
num_tiles=num_tiles, cross_attention_token_mask,
max_num_tiles=max_image_tiles, num_tiles=num_tiles,
length=max(len(input_ids) for input_ids in batch_ids), max_num_tiles=max_image_tiles,
) length=max(len(input_ids) for input_ids in batch_ids),
) # shape: (batch_size, length, max_num_images, max_num_tiles) )
) # shape: (batch_size, length, max_num_images, max_num_tiles)
return mm_inputs return mm_inputs
@dataclass
class PaliGemmaPlugin(BasePlugin): class PaliGemmaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -653,9 +853,10 @@ class PaliGemmaPlugin(BasePlugin): ...@@ -653,9 +853,10 @@ class PaliGemmaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
for message in messages: for message in messages:
...@@ -678,10 +879,11 @@ class PaliGemmaPlugin(BasePlugin): ...@@ -678,10 +879,11 @@ class PaliGemmaPlugin(BasePlugin):
labels: Optional[List[int]], labels: Optional[List[int]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]: ) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_images = len(images) num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
...@@ -696,18 +898,21 @@ class PaliGemmaPlugin(BasePlugin): ...@@ -696,18 +898,21 @@ class PaliGemmaPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids] seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs return mm_inputs
@dataclass
class PixtralPlugin(BasePlugin): class PixtralPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -715,9 +920,10 @@ class PixtralPlugin(BasePlugin): ...@@ -715,9 +920,10 @@ class PixtralPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size") patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token") image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token") image_break_token = getattr(processor, "image_break_token")
...@@ -725,17 +931,15 @@ class PixtralPlugin(BasePlugin): ...@@ -725,17 +931,15 @@ class PixtralPlugin(BasePlugin):
num_image_tokens = 0 num_image_tokens = 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_input_sizes = mm_inputs.get("image_sizes", None) if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")
if self.expand_mm_tokens: if self.expand_mm_tokens:
image_size = image_input_sizes[0][num_image_tokens] height, width = next(image_sizes)
height, width = image_size
num_height_tokens = height // patch_size num_height_tokens = height // patch_size
num_width_tokens = width // patch_size num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
...@@ -760,47 +964,105 @@ class PixtralPlugin(BasePlugin): ...@@ -760,47 +964,105 @@ class PixtralPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
mm_inputs.pop("image_sizes", None) mm_inputs.pop("image_sizes", None)
return mm_inputs return mm_inputs
class Qwen2vlPlugin(BasePlugin): @dataclass
class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs([], [], audios, processor)
if "feature_attention_mask" in mm_inputs:
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
num_audio_tokens = 0
for message in messages:
content = message["content"]
while AUDIO_PLACEHOLDER in content:
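# The two integer divisions below mirror the downsampling in Qwen2-Audio's encoder
# (a stride-2 convolution followed by a stride-2 average pooling), so one audio
# placeholder expands to the number of tokens the audio tower will actually emit.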
if self.expand_mm_tokens:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
audio_seqlen = (input_length - 2) // 2 + 1
else:
audio_seqlen = 1
content = content.replace(
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
)
num_audio_tokens += 1
message["content"] = content
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class Qwen2VLPlugin(BasePlugin):
@override @override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
image = super()._preprocess_image(image, **kwargs) image = super()._preprocess_image(image, **kwargs)
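# The resizes below keep the Qwen2-VL processor's constraints: both sides must be at
# least 28 px (14-px patches with a 2x2 spatial merge), and aspect ratios above 200
# are clamped down to 180 by resizing the longer side.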
if min(image.width, image.height) < 28: if min(image.width, image.height) < 28:
width, height = max(image.width, 28), max(image.height, 28) width, height = max(image.width, 28), max(image.height, 28)
image = image.resize((width, height), resample=Image.NEAREST) image = image.resize((width, height))
if image.width / image.height > 200: if image.width / image.height > 200:
width, height = image.height * 180, image.height width, height = image.height * 180, image.height
image = image.resize((width, height), resample=Image.NEAREST) image = image.resize((width, height))
if image.height / image.width > 200: if image.height / image.width > 200:
width, height = image.width, image.width * 180 width, height = image.width, image.width * 180
image = image.resize((width, height), resample=Image.NEAREST) image = image.resize((width, height))
return image return image
@override @override
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]: def _regularize_videos(
results = [] self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]:
results, fps_per_video = [], []
for video in videos: for video in videos:
container = av.open(video, "r") container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video") video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = [] frames: List["ImageObject"] = []
container.seek(0) container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)): for frame_idx, frame in enumerate(container.decode(video_stream)):
...@@ -812,8 +1074,43 @@ class Qwen2vlPlugin(BasePlugin): ...@@ -812,8 +1074,43 @@ class Qwen2vlPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs) frames = self._regularize_images(frames, **kwargs)
results.append(frames) results.append(frames)
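# Record the effective sampling rate of the extracted frames; when the container
# reports no duration we fall back to 2.0 fps, matching the default `video_fps`
# used in `_get_mm_inputs` below.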
if video_stream.duration is None:
fps_per_video.append(2.0)
else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
return results return results, fps_per_video
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos, fps_per_video = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
mm_inputs["fps_per_video"] = fps_per_video
return mm_inputs
@override @override
def process_messages( def process_messages(
...@@ -821,17 +1118,23 @@ class Qwen2vlPlugin(BasePlugin): ...@@ -821,17 +1118,23 @@ class Qwen2vlPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2 merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = self._get_mm_inputs(images, videos, processor) if self.expand_mm_tokens:
image_grid_thw = mm_inputs.get("image_grid_thw", []) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
video_grid_thw = mm_inputs.get("video_grid_thw", []) image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
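# When `expand_mm_tokens` is disabled, each image/video keeps a single placeholder
# token, so the real grid sizes are not needed here.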
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
for message in messages: for message in messages:
content = message["content"] content = message["content"]
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
...@@ -869,15 +1172,24 @@ class Qwen2vlPlugin(BasePlugin): ...@@ -869,15 +1172,24 @@ class Qwen2vlPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
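# Qwen2.5-VL expects `second_per_grid_ts`: the number of seconds covered by each
# temporal grid step, i.e. temporal_patch_size / fps for every sampled video.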
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
return mm_inputs
@dataclass
class VideoLlavaPlugin(BasePlugin): class VideoLlavaPlugin(BasePlugin):
@override @override
def process_messages( def process_messages(
...@@ -885,12 +1197,13 @@ class VideoLlavaPlugin(BasePlugin): ...@@ -885,12 +1197,13 @@ class VideoLlavaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0 num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages) messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_frames = 0 num_frames = 0
has_images = "pixel_values_images" in mm_inputs has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs has_videos = "pixel_values_videos" in mm_inputs
...@@ -907,7 +1220,7 @@ class VideoLlavaPlugin(BasePlugin): ...@@ -907,7 +1220,7 @@ class VideoLlavaPlugin(BasePlugin):
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1 image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default": if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1 image_seqlen -= 1
else: else:
image_seqlen, video_seqlen = 1, 1 image_seqlen, video_seqlen = 1, 1
...@@ -938,13 +1251,15 @@ class VideoLlavaPlugin(BasePlugin): ...@@ -938,13 +1251,15 @@ class VideoLlavaPlugin(BasePlugin):
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
videos: Sequence["VideoInput"], videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int], imglens: Sequence[int],
vidlens: Sequence[int], vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]], batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos) self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, processor) return self._get_mm_inputs(images, videos, audios, processor)
PLUGINS = { PLUGINS = {
...@@ -956,18 +1271,32 @@ PLUGINS = { ...@@ -956,18 +1271,32 @@ PLUGINS = {
"mllama": MllamaPlugin, "mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin, "paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin, "pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin, "qwen2_audio": Qwen2AudioPlugin,
"qwen2_vl": Qwen2VLPlugin,
"video_llava": VideoLlavaPlugin, "video_llava": VideoLlavaPlugin,
} }
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
r"""
Registers a multimodal plugin.
"""
if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.")
PLUGINS[name] = plugin_class
def get_mm_plugin( def get_mm_plugin(
name: str, name: str,
image_token: Optional[str] = None, image_token: Optional[str] = None,
video_token: Optional[str] = None, video_token: Optional[str] = None,
audio_token: Optional[str] = None,
) -> "BasePlugin": ) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None) r"""
if plugin_class is None: Gets plugin for multimodal inputs.
"""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.") raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token) return PLUGINS[name](image_token, video_token, audio_token)
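# Example usage (the custom plugin name and class are hypothetical; the Qwen2-VL tokens
# follow its tokenizer conventions):
#   register_mm_plugin("my_plugin", MyPlugin)
#   plugin = get_mm_plugin("qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")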
# Copyright 2024 the LlamaFactory team. # Copyright 2025 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -44,7 +44,8 @@ class DatasetAttr: ...@@ -44,7 +44,8 @@ class DatasetAttr:
tools: Optional[str] = None tools: Optional[str] = None
images: Optional[str] = None images: Optional[str] = None
videos: Optional[str] = None videos: Optional[str] = None
# rlhf columns audios: Optional[str] = None
# dpo columns
chosen: Optional[str] = None chosen: Optional[str] = None
rejected: Optional[str] = None rejected: Optional[str] = None
kto_tag: Optional[str] = None kto_tag: Optional[str] = None
...@@ -70,6 +71,26 @@ class DatasetAttr: ...@@ -70,6 +71,26 @@ class DatasetAttr:
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
self.set_attr("split", attr, default="train")
self.set_attr("folder", attr)
self.set_attr("num_samples", attr)
if "columns" in attr:
column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
for column_name in column_names:
self.set_attr(column_name, attr["columns"])
if "tags" in attr:
tag_names = ["role_tag", "content_tag"]
tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
for tag in tag_names:
self.set_attr(tag, attr["tags"])
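# A typical `dataset_info.json` entry consumed by `join()` (field values are illustrative):
# "my_dataset": {
#     "file_name": "my_dataset.json",
#     "formatting": "sharegpt",
#     "columns": {"messages": "conversations", "images": "images"},
#     "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"}
# }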
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r""" r"""
...@@ -127,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - ...@@ -127,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
else: else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") dataset_attr.join(dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("split", dataset_info[name], default="train")
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
tag_names = (
"role_tag",
"content_tag",
"user_tag",
"assistant_tag",
"observation_tag",
"function_tag",
"system_tag",
)
for tag in tag_names:
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_list.append(dataset_attr) dataset_list.append(dataset_attr)
return dataset_list return dataset_list
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
data_args: "DataArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
return TypedSequence.__init__(
self,
data,
type=kwargs.pop("type", None),
try_type=kwargs.pop("try_type", None),
optimized_int_type=kwargs.pop("optimized_int_type", None),
)
OptimizedTypedSequence.__init__ = __init__
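# The override above drops the per-column `optimized_int_type` defaults, so the packed
# attention mask (which stores segment indices under neat packing) is not down-cast
# to a smaller integer type when the dataset is written to Arrow.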
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
elif stage == "kto":
preprocess_func = partial(
preprocess_feedback_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)
if self.template.efficient_eos:
response_ids += [self.tokenizer.eos_token_id]
kl_response_ids += [self.tokenizer.eos_token_id]
prompt_ids, _ = self.template.mm_plugin.process_token_ids(
prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
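# `infer_seqlen` splits `cutoff_len` between prompt and response roughly in proportion
# to their original lengths (see processor_utils), so both sides are truncated together.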
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(
len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
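# A valid KTO example has a prompt ending with a user turn (odd number of messages)
# and a two-slot response holding the desired/undesired answers.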
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
rejected_messages = self.template.mm_plugin.process_messages(
prompt + [response[1]], images, videos, audios, self.processor
)
prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
_, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)
if self.template.efficient_eos:
chosen_ids += [self.tokenizer.eos_token_id]
rejected_ids += [self.tokenizer.eos_token_id]
prompt_ids, _ = self.template.mm_plugin.process_token_ids(
prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
# the response is considered more important, so budget the cutoff length against the longer of the two responses
source_len, target_len = infer_seqlen(
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print(
"chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
)
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print(
"rejected_inputs:\n{}".format(
self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
)
)
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not self.data_args.packing:
if getattr(self.tokenizer, "add_bos_token", False):
text_examples = [self.tokenizer.bos_token + example for example in text_examples]
result = self.tokenizer(
text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
)
else:
tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = self.data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
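# If the tokenizer normally prepends a BOS token, overwrite the first token of every
# packed block so each block starts like a fresh sequence.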
if getattr(self.tokenizer, "add_bos_token", False):
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = self.tokenizer.bos_token_id
return result
def print_data_example(self, example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team. # Copyright 2025 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,7 +13,42 @@ ...@@ -13,7 +13,42 @@
# limitations under the License. # limitations under the License.
import bisect import bisect
from typing import List, Sequence, Tuple from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@dataclass
class DatasetProcessor(ABC):
r"""
An abstract base class for dataset processors.
"""
template: "Template"
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Prints a data example to stdout.
"""
...
def search_for_fit(numbers: Sequence[int], capacity: int) -> int: def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
)
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
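# With `mask_history`, turns are encoded from last to first so the final turn survives
# truncation, and the pairs are stitched back together in reverse order below.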
if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= self.data_args.cutoff_len:
break
source_len, target_len = infer_seqlen(
len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if self.data_args.train_on_prompt:
source_label = source_ids
elif self.template.efficient_eos:
source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if self.data_args.mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if self.data_args.mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if self.template.efficient_eos:
input_ids += [self.tokenizer.eos_token_id]
labels += [self.tokenizer.eos_token_id]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
length = len(input_ids)
if length > self.data_args.cutoff_len:
logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
batch_audios.append(examples["_audios"][i] or [])
valid_num += 1
model_inputs = defaultdict(list)
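# `greedy_knapsack` groups example lengths into bins that fit within `cutoff_len`;
# each bin is concatenated into one packed sequence, and neat packing tags every
# segment with its 1-based index in the attention mask.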
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # pad so flash_attn does not drop the attention mask
pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
return model_inputs
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
if self.template.efficient_eos:
labels += [self.tokenizer.eos_token_id]
input_ids, _ = self.template.mm_plugin.process_token_ids(
input_ids, None, images, videos, audios, self.tokenizer, self.processor
)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_feedback_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
# get paligemma token type ids for computing loss
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_pairwise_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
# consider the response to be more important when allocating the length budget
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_supervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
break
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if train_on_prompt:
source_label = source_ids
elif template.efficient_eos:
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
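# Illustrative sketch (editor's addition): with `mask_history=True` the turn pairs are
# visited last-to-first and prepended, so only the final turn keeps real labels. With two
# turns, hypothetical ids, train_on_prompt=False and efficient_eos=False:
#     last turn : source [7, 8], target [9] -> labels [IGNORE_INDEX, IGNORE_INDEX, 9]
#     first turn: source [1, 2], target [3] -> labels [IGNORE_INDEX] * 3
#     input_ids == [1, 2, 3, 7, 8, 9]
#     labels    == [IGNORE_INDEX] * 5 + [9]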
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_unsupervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team. # Copyright 2025 the LlamaFactory team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import override from typing_extensions import override
...@@ -47,6 +47,7 @@ class Template: ...@@ -47,6 +47,7 @@ class Template:
format_prefix: "Formatter" format_prefix: "Formatter"
default_system: str default_system: str
stop_words: List[str] stop_words: List[str]
thought_words: Tuple[str, str]
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool replace_jinja_template: bool
...@@ -67,8 +68,8 @@ class Template: ...@@ -67,8 +68,8 @@ class Template:
for encoded_ids in encoded_messages[:-1]: for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids prompt_ids += encoded_ids
answer_ids = encoded_messages[-1] response_ids = encoded_messages[-1]
return prompt_ids, answer_ids return prompt_ids, response_ids
def encode_multiturn( def encode_multiturn(
self, self,
...@@ -99,6 +100,27 @@ class Template: ...@@ -99,6 +100,27 @@ class Template:
return list(stop_token_ids) return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
token_ids = []
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
...@@ -109,7 +131,7 @@ class Template: ...@@ -109,7 +131,7 @@ class Template:
r""" r"""
Encodes formatted inputs to pairs of token ids. Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp Turn 0: prefix + system + query resp
Turn t: sep + query resp Turn t: query resp
""" """
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
...@@ -137,26 +159,179 @@ class Template: ...@@ -137,26 +159,179 @@ class Template:
return encoded_messages return encoded_messages
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: @staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r""" r"""
Converts elements to token ids. Adds or replaces the eos token in the tokenizer.
""" """
token_ids = [] is_added = tokenizer.eos_token_id is None
for elem in elements: num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids if is_added:
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
else:
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escapes single quotes in content.
"""
return content.replace("'", r"\'")
@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to jinja template.
"""
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
jinja_template = ""
if prefix:
jinja_template += "{{ " + prefix + " }}"
if self.default_system:
jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
"{% if system_message is defined %}{{ " + system + " }}{% endif %}"
"{% for message in loop_messages %}"
"{% set content = message['content'] %}"
"{% if message['role'] == 'user' %}"
"{{ " + user + " }}"
"{% elif message['role'] == 'assistant' %}"
"{{ " + assistant + " }}"
"{% endif %}"
"{% endfor %}"
)
return jinja_template
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
@staticmethod
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append(slot_pieces[0])
if len(slot_pieces) > 1:
slot_items.append("{{ " + placeholder + " }}")
if slot_pieces[1]:
slot_items.append(slot_pieces[1])
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append(tokenizer.bos_token)
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append(tokenizer.eos_token)
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
return (
f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
TODO: support function calling.
"""
modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
if self.default_system:
modelfile += f'SYSTEM """{self.default_system}"""\n\n'
for stop_token_id in self.get_stop_token_ids(tokenizer):
modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
modelfile += "PARAMETER num_ctx 4096\n"
return modelfile
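# Illustrative sketch (editor's addition, not part of the diff): the generated modelfile
# has roughly the following shape; the stop token shown is only an example and depends
# on the tokenizer and template.
#
#     # ollama modelfile auto-generated by llamafactory
#
#     FROM .
#
#     TEMPLATE """<rendered chat template>"""
#
#     SYSTEM """<default system prompt, if any>"""
#
#     PARAMETER stop "<|im_end|>"
#     PARAMETER num_ctx 4096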
@dataclass @dataclass
...@@ -169,11 +344,6 @@ class Llama2Template(Template): ...@@ -169,11 +344,6 @@ class Llama2Template(Template):
system: str, system: str,
tools: str, tools: str,
) -> List[List[int]]: ) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system system = system or self.default_system
encoded_messages = [] encoded_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
...@@ -201,11 +371,41 @@ class Llama2Template(Template): ...@@ -201,11 +371,41 @@ class Llama2Template(Template):
return encoded_messages return encoded_messages
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system_message = self._convert_slots_to_jinja(
self.format_system.apply(), tokenizer, placeholder="system_message"
)
user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
jinja_template = ""
if prefix:
jinja_template += "{{ " + prefix + " }}"
if self.default_system:
jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
"{% for message in loop_messages %}"
"{% if loop.index0 == 0 and system_message is defined %}"
"{% set content = " + system_message + " + message['content'] %}"
"{% else %}{% set content = message['content'] %}{% endif %}"
"{% if message['role'] == 'user' %}"
"{{ " + user_message + " }}"
"{% elif message['role'] == 'assistant' %}"
"{{ " + assistant_message + " }}"
"{% endif %}"
"{% endfor %}"
)
return jinja_template
TEMPLATES: Dict[str, "Template"] = {} TEMPLATES: Dict[str, "Template"] = {}
def _register_template( def register_template(
name: str, name: str,
format_user: Optional["Formatter"] = None, format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None, format_assistant: Optional["Formatter"] = None,
...@@ -216,10 +416,12 @@ def _register_template( ...@@ -216,10 +416,12 @@ def _register_template(
format_prefix: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None,
default_system: str = "", default_system: str = "",
stop_words: Optional[Sequence[str]] = None, stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = False, replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
) -> None: ) -> None:
r""" r"""
Registers a chat template. Registers a chat template.
...@@ -234,7 +436,7 @@ def _register_template( ...@@ -234,7 +436,7 @@ def _register_template(
The corresponding code should be: The corresponding code should be:
``` ```
_register_template( register_template(
name="custom", name="custom",
format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]), format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
format_assistant=StringFormatter(slots=["{{content}}</s>\n"]), format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
...@@ -242,7 +444,9 @@ def _register_template( ...@@ -242,7 +444,9 @@ def _register_template(
) )
``` ```
""" """
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template if name in TEMPLATES:
raise ValueError(f"Template {name} already exists.")
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"]) default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=default_slots) default_assistant_formatter = StringFormatter(slots=default_slots)
...@@ -259,6 +463,7 @@ def _register_template( ...@@ -259,6 +463,7 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter, format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system, default_system=default_system,
stop_words=stop_words or [], stop_words=stop_words or [],
thought_words=thought_words or ("<think>", "</think>"),
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template, replace_jinja_template=replace_jinja_template,
...@@ -266,97 +471,83 @@ def _register_template( ...@@ -266,97 +471,83 @@ def _register_template(
) )
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else:
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
return content.replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return " + ".join(slot_items)
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
r""" r"""
Returns the jinja template. Extracts a chat template from the tokenizer.
""" """
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
if prefix:
jinja_template += "{{ " + prefix + " }}"
if template.default_system: def find_diff(short_str: str, long_str: str) -> str:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" i, j = 0, 0
diff = ""
jinja_template += ( while i < len(short_str) and j < len(long_str):
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" if short_str[i] == long_str[j]:
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" i += 1
j += 1
else:
diff += long_str[j]
j += 1
return diff
prefix = tokenizer.decode(tokenizer.encode(""))
messages = [{"role": "system", "content": "{{content}}"}]
system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
user_slot_empty_system = user_slot_empty_system[len(prefix) :]
messages = [{"role": "user", "content": "{{content}}"}]
user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
user_slot = user_slot[len(prefix) :]
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
if len(user_slot) > len(user_slot_empty_system):
default_system = find_diff(user_slot_empty_system, user_slot)
sole_system = system_slot.replace("{{content}}", default_system, 1)
user_slot = user_slot[len(sole_system) :]
else: # if default_system is empty, user_slot_empty_system will be longer than user_slot
default_system = ""
return Template(
format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]),
format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
format_observation=StringFormatter(slots=[user_slot]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
default_system=default_system,
stop_words=[],
thought_words=("<think>", "</think>"),
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="base"),
) )
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in loop_messages %}"
jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
jinja_template += "{{ " + assistant_message + " }}"
jinja_template += "{% endif %}"
jinja_template += "{% endfor %}"
return jinja_template
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r""" r"""
Gets chat template and fixes the tokenizer. Gets chat template and fixes the tokenizer.
""" """
if data_args.template is None: if data_args.template is None:
template = TEMPLATES["empty"] # placeholder if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
template = parse_template(tokenizer)
else:
logger.warning_rank0("`template` was not specified, use `empty` template.")
template = TEMPLATES["empty"] # placeholder
else: else:
template = TEMPLATES.get(data_args.template, None) if data_args.template not in TEMPLATES:
if template is None:
raise ValueError(f"Template {data_args.template} does not exist.") raise ValueError(f"Template {data_args.template} does not exist.")
template = TEMPLATES[data_args.template]
if template.mm_plugin.__class__.__name__ != "BasePlugin": if template.mm_plugin.__class__.__name__ != "BasePlugin":
check_version("transformers>=4.45.0") check_version("transformers>=4.45.0")
...@@ -369,39 +560,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ...@@ -369,39 +560,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format) template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words template.fix_special_tokens(tokenizer)
if template.replace_eos: template.fix_jinja_template(tokenizer)
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
if tokenizer.chat_template is None or template.replace_jinja_template:
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
return template return template
_register_template( register_template(
name="alpaca", name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
...@@ -412,7 +576,7 @@ _register_template( ...@@ -412,7 +576,7 @@ _register_template(
) )
_register_template( register_template(
name="aquila", name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}###"]), format_assistant=StringFormatter(slots=["{{content}}###"]),
...@@ -425,7 +589,7 @@ _register_template( ...@@ -425,7 +589,7 @@ _register_template(
) )
_register_template( register_template(
name="atom", name="atom",
format_user=StringFormatter( format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
...@@ -434,21 +598,31 @@ _register_template( ...@@ -434,21 +598,31 @@ _register_template(
) )
_register_template( register_template(
name="baichuan", name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]), format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True, efficient_eos=True,
) )
_register_template( register_template(
name="baichuan2", name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]), format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True, efficient_eos=True,
) )
_register_template( register_template(
name="bailing",
format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}"]),
format_observation=StringFormatter(slots=["<role>OBSERVATION</role>{{content}}<role>ASSISTANT</role>"]),
stop_words=["<|endoftext|>"],
efficient_eos=True,
)
register_template(
name="belle", name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
...@@ -456,13 +630,13 @@ _register_template( ...@@ -456,13 +630,13 @@ _register_template(
) )
_register_template( register_template(
name="bluelm", name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
) )
_register_template( register_template(
name="breeze", name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
...@@ -470,7 +644,7 @@ _register_template( ...@@ -470,7 +644,7 @@ _register_template(
) )
_register_template( register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
...@@ -478,7 +652,7 @@ _register_template( ...@@ -478,7 +652,7 @@ _register_template(
) )
_register_template( register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
...@@ -494,7 +668,7 @@ _register_template( ...@@ -494,7 +668,7 @@ _register_template(
) )
_register_template( register_template(
name="chatml", name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -507,7 +681,7 @@ _register_template( ...@@ -507,7 +681,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="chatml_de", name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -520,13 +694,13 @@ _register_template( ...@@ -520,13 +694,13 @@ _register_template(
) )
_register_template( register_template(
name="codegeex2", name="codegeex2",
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
) )
_register_template( register_template(
name="codegeex4", name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
...@@ -543,7 +717,7 @@ _register_template( ...@@ -543,7 +717,7 @@ _register_template(
) )
_register_template( register_template(
name="cohere", name="cohere",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -558,7 +732,7 @@ _register_template( ...@@ -558,7 +732,7 @@ _register_template(
) )
_register_template( register_template(
name="cpm", name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]), format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
...@@ -566,7 +740,7 @@ _register_template( ...@@ -566,7 +740,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="cpm3", name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -577,7 +751,7 @@ _register_template( ...@@ -577,7 +751,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="dbrx", name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -602,7 +776,7 @@ _register_template( ...@@ -602,7 +776,7 @@ _register_template(
) )
_register_template( register_template(
name="deepseek", name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]), format_system=StringFormatter(slots=["{{content}}\n\n"]),
...@@ -610,14 +784,14 @@ _register_template( ...@@ -610,14 +784,14 @@ _register_template(
) )
_register_template( register_template(
name="deepseek3", name="deepseek3",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]), format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
) )
_register_template( register_template(
name="deepseekcoder", name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]), format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
...@@ -631,7 +805,7 @@ _register_template( ...@@ -631,7 +805,7 @@ _register_template(
) )
_register_template( register_template(
name="default", name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
...@@ -639,13 +813,13 @@ _register_template( ...@@ -639,13 +813,13 @@ _register_template(
) )
_register_template( register_template(
name="empty", name="empty",
efficient_eos=True, format_assistant=StringFormatter(slots=["{{content}}"]),
) )
_register_template( register_template(
name="exaone", name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]), format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]), format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
...@@ -653,7 +827,7 @@ _register_template( ...@@ -653,7 +827,7 @@ _register_template(
) )
_register_template( register_template(
name="falcon", name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]),
...@@ -661,14 +835,14 @@ _register_template( ...@@ -661,14 +835,14 @@ _register_template(
) )
_register_template( register_template(
name="fewshot", name="fewshot",
format_assistant=StringFormatter(slots=["{{content}}\n\n"]), format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
efficient_eos=True, efficient_eos=True,
) )
_register_template( register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]), format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
...@@ -679,7 +853,7 @@ _register_template( ...@@ -679,7 +853,7 @@ _register_template(
) )
_register_template( register_template(
name="glm4", name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]), format_assistant=StringFormatter(slots=["\n{{content}}"]),
...@@ -693,7 +867,7 @@ _register_template( ...@@ -693,7 +867,7 @@ _register_template(
) )
_register_template( register_template(
name="granite3", name="granite3",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -705,7 +879,7 @@ _register_template( ...@@ -705,7 +879,7 @@ _register_template(
) )
_register_template( register_template(
name="index", name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]), format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]), format_system=StringFormatter(slots=["<unk>{{content}}"]),
...@@ -713,54 +887,59 @@ _register_template( ...@@ -713,54 +887,59 @@ _register_template(
) )
_register_template( register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]), format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<eoa>"], stop_words=["<eoa>"],
) )
_register_template( register_template(
name="intern2", name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
) )
# copied from intern2 template register_template(
_register_template(
name="intern3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
_register_template(
name="llama2", name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]), format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
template_class=Llama2Template,
) )
# copied from llama2 template # copied from llama2 template
_register_template( register_template(
name="llama2_zh", name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]), format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。", default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
template_class=Llama2Template,
) )
_register_template( register_template(
name="llama3", name="llama3",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -788,7 +967,7 @@ _register_template( ...@@ -788,7 +967,7 @@ _register_template(
# copied from llama3 template # copied from llama3 template
_register_template( register_template(
name="mllama", name="mllama",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -816,8 +995,20 @@ _register_template( ...@@ -816,8 +995,20 @@ _register_template(
) )
register_template(
name="moonlight",
format_user=StringFormatter(
slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
default_system="You are a helpful assistant provided by Moonshot-AI.",
stop_words=["<|im_end|>"],
)
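For reference, a hand-assembled single turn from the moonlight slots above (built by hand from the slots shown, not by calling the library):

# Hand-assembled sketch of one moonlight turn from the slots above.
system = "You are a helpful assistant provided by Moonshot-AI."
user = "Hello"
prompt = (
    f"<|im_system|>system<|im_middle|>{system}<|im_end|>"
    f"<|im_user|>user<|im_middle|>{user}<|im_end|>"
    "<|im_assistant|>assistant<|im_middle|>"
)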
# copied from vicuna template # copied from vicuna template
_register_template( register_template(
name="llava", name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=( default_system=(
...@@ -829,7 +1020,7 @@ _register_template( ...@@ -829,7 +1020,7 @@ _register_template(
# copied from vicuna template # copied from vicuna template
_register_template( register_template(
name="llava_next", name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=( default_system=(
...@@ -841,7 +1032,7 @@ _register_template( ...@@ -841,7 +1032,7 @@ _register_template(
# copied from llama3 template # copied from llama3 template
_register_template( register_template(
name="llava_next_llama3", name="llava_next_llama3",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -870,21 +1061,22 @@ _register_template( ...@@ -870,21 +1061,22 @@ _register_template(
# copied from mistral template # copied from mistral template
_register_template( register_template(
name="llava_next_mistral", name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]), format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"), format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"), mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
template_class=Llama2Template,
) )
# copied from chatml template # copied from qwen template
_register_template( register_template(
name="llava_next_qwen", name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -901,7 +1093,7 @@ _register_template( ...@@ -901,7 +1093,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="llava_next_yi", name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -912,7 +1104,7 @@ _register_template( ...@@ -912,7 +1104,7 @@ _register_template(
# copied from vicuna template # copied from vicuna template
_register_template( register_template(
name="llava_next_video", name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=( default_system=(
...@@ -924,21 +1116,22 @@ _register_template( ...@@ -924,21 +1116,22 @@ _register_template(
# copied from mistral template # copied from mistral template
_register_template( register_template(
name="llava_next_video_mistral", name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]), format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"), format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"), mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
template_class=Llama2Template,
) )
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="llava_next_video_yi", name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -949,7 +1142,7 @@ _register_template( ...@@ -949,7 +1142,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="marco", name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -965,43 +1158,83 @@ _register_template( ...@@ -965,43 +1158,83 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="minicpm_v", name="minicpm_v",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
default_system="You are a helpful assistant.",
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"), mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
) )
_register_template( # copied from minicpm_v template
register_template(
name="minicpm_o",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
)
# mistral tokenizer v3 tekken
register_template(
name="ministral",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
)
# mistral tokenizer v3
register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]), format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"), format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
)
# mistral tokenizer v7 tekken (copied from ministral)
register_template(
name="mistral_small",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
) )
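Side by side, the three Mistral-family templates above differ mainly in whitespace and in how the system prompt is carried. Hand-assembled single turns, with "<s>" standing in for bos_token and assuming Llama2Template folds the system text into the first user turn for mistral/ministral:

# Hand-assembled sketches from the slots above; "<s>" stands in for bos_token.
system, user = "Be concise.", "Hello"
mistral = f"<s>[INST] {system}\n\n{user}[/INST]"        # v3 tokenizer: space after [INST]
ministral = f"<s>[INST]{system}\n\n{user}[/INST]"       # v3 tekken: no space
mistral_small = f"<s>[SYSTEM_PROMPT]{system}[/SYSTEM_PROMPT][INST]{user}[/INST]"  # v7 tekken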
_register_template( register_template(
name="olmo", name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_prefix=EmptyFormatter(slots=[{"eos_token"}]), format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
) )
_register_template( register_template(
name="openchat", name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
) )
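For reference, the openchat slots above place the tokenizer's eos_token between the user text and the assistant cue; a hand-assembled turn with placeholder special tokens:

# Hand-assembled sketch of one openchat turn; "<s>"/"</s>" are placeholders
# for the tokenizer's actual bos/eos tokens.
bos, eos = "<s>", "</s>"
user = "Hello"
prompt = f"{bos}GPT4 Correct User: {user}{eos}GPT4 Correct Assistant:"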
_register_template( register_template(
name="openchat-3.6", name="openchat-3.6",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -1017,7 +1250,7 @@ _register_template( ...@@ -1017,7 +1250,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="opencoder", name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -1028,16 +1261,24 @@ _register_template( ...@@ -1028,16 +1261,24 @@ _register_template(
) )
_register_template( register_template(
name="orion", name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
) )
# copied from gemma template register_template(
_register_template(
name="paligemma", name="paligemma",
format_user=StringFormatter(slots=["{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
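Note the contrast with the chat variant registered next: the base paligemma template emits the raw prompt followed by a newline (after bos), while paligemma_chat wraps turns in gemma-style markers. A minimal hand-assembled base prompt:

# Hand-assembled sketch of the base paligemma prompt from the slots above;
# "<s>" stands in for bos_token, and image placeholders are expected to be
# supplied by the registered mm_plugin rather than written here.
user = "caption en"
prompt = f"<s>{user}\n"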
# copied from gemma template
register_template(
name="paligemma_chat",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]), format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_observation=StringFormatter( format_observation=StringFormatter(
...@@ -1048,7 +1289,7 @@ _register_template( ...@@ -1048,7 +1289,7 @@ _register_template(
) )
_register_template( register_template(
name="phi", name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
...@@ -1057,7 +1298,7 @@ _register_template( ...@@ -1057,7 +1298,7 @@ _register_template(
) )
_register_template( register_template(
name="phi_small", name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
...@@ -1067,7 +1308,7 @@ _register_template( ...@@ -1067,7 +1308,7 @@ _register_template(
) )
_register_template( register_template(
name="phi4", name="phi4",
format_user=StringFormatter( format_user=StringFormatter(
slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"] slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
...@@ -1078,17 +1319,22 @@ _register_template( ...@@ -1078,17 +1319,22 @@ _register_template(
) )
_register_template( # copied from ministral template
register_template(
name="pixtral", name="pixtral",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]), format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]), format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"), mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
template_class=Llama2Template,
) )
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -1104,7 +1350,19 @@ _register_template( ...@@ -1104,7 +1350,19 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="qwen2_audio",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
)
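All ChatML-derived templates in this file (qwen, qwen2_audio, qwen2_vl, the llava_next_* and minicpm variants, yi, and so on) share the same turn structure; a hand-assembled turn from the slots above:

# Hand-assembled sketch of one ChatML turn from the slots above.
system = "You are a helpful assistant."
user = "Hello"
prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n"
)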
# copied from qwen template
register_template(
name="qwen2_vl", name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -1120,7 +1378,7 @@ _register_template( ...@@ -1120,7 +1378,7 @@ _register_template(
) )
_register_template( register_template(
name="sailor", name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]), format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -1134,7 +1392,7 @@ _register_template( ...@@ -1134,7 +1392,7 @@ _register_template(
# copied from llama3 template # copied from llama3 template
_register_template( register_template(
name="skywork_o1", name="skywork_o1",
format_user=StringFormatter( format_user=StringFormatter(
slots=[ slots=[
...@@ -1168,7 +1426,7 @@ _register_template( ...@@ -1168,7 +1426,7 @@ _register_template(
) )
_register_template( register_template(
name="solar", name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]), format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
...@@ -1176,7 +1434,7 @@ _register_template( ...@@ -1176,7 +1434,7 @@ _register_template(
) )
_register_template( register_template(
name="starchat", name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
...@@ -1185,14 +1443,14 @@ _register_template( ...@@ -1185,14 +1443,14 @@ _register_template(
) )
_register_template( register_template(
name="telechat", name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]), format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
) )
_register_template( register_template(
name="telechat2", name="telechat2",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}"]), format_system=StringFormatter(slots=["<_system>{{content}}"]),
...@@ -1202,7 +1460,7 @@ _register_template( ...@@ -1202,7 +1460,7 @@ _register_template(
) )
_register_template( register_template(
name="vicuna", name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=( default_system=(
...@@ -1213,7 +1471,7 @@ _register_template( ...@@ -1213,7 +1471,7 @@ _register_template(
) )
_register_template( register_template(
name="video_llava", name="video_llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=( default_system=(
...@@ -1224,7 +1482,7 @@ _register_template( ...@@ -1224,7 +1482,7 @@ _register_template(
) )
_register_template( register_template(
name="xuanyuan", name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=( default_system=(
...@@ -1235,13 +1493,13 @@ _register_template( ...@@ -1235,13 +1493,13 @@ _register_template(
) )
_register_template( register_template(
name="xverse", name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
) )
_register_template( register_template(
name="yayi", name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_assistant=StringFormatter(slots=["{{content}}\n\n"]), format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
...@@ -1262,7 +1520,7 @@ _register_template( ...@@ -1262,7 +1520,7 @@ _register_template(
# copied from chatml template # copied from chatml template
_register_template( register_template(
name="yi", name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
...@@ -1271,7 +1529,7 @@ _register_template( ...@@ -1271,7 +1529,7 @@ _register_template(
) )
_register_template( register_template(
name="yi_vl", name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]),
...@@ -1288,7 +1546,7 @@ _register_template( ...@@ -1288,7 +1546,7 @@ _register_template(
) )
_register_template( register_template(
name="yuan", name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]), format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]), format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
...@@ -1296,7 +1554,7 @@ _register_template( ...@@ -1296,7 +1554,7 @@ _register_template(
) )
_register_template( register_template(
name="zephyr", name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
...@@ -1304,7 +1562,7 @@ _register_template( ...@@ -1304,7 +1562,7 @@ _register_template(
) )
_register_template( register_template(
name="ziya", name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]), format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]),
......