Commit 317a82e2 authored by chenych

Add QWQ-32B

parent 37b0ad9f
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -86,19 +86,25 @@ class StringFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
return elements
@dataclass
class FunctionFormatter(Formatter):
class FunctionFormatter(StringFormatter):
def __post_init__(self):
super().__post_init__()
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
content: str = kwargs.pop("content")
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
thought = re.search(regex, content)
if thought:
content = content.replace(thought.group(0), "")
functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
......@@ -111,16 +117,13 @@ class FunctionFormatter(Formatter):
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
elements = []
for slot in self.slots:
if slot == "{{content}}":
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)
function_str = self.tool_utils.function_formatter(functions)
if thought:
function_str = thought.group(1) + function_str
return elements
return super().apply(content=function_str)
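To make the new control flow concrete: a minimal, self-contained sketch of the `<think>` extraction step that FunctionFormatter.apply now performs before rendering tool calls (the FunctionCall rendering itself is omitted here):

import re
from typing import Tuple

def split_thought(content: str) -> Tuple[str, str]:
    # Returns (thought, remainder); thought is "" when no <think> block exists.
    match = re.search(r"<think>(.*)</think>", content, re.DOTALL)
    if match is None:
        return "", content
    return match.group(1), content.replace(match.group(0), "")

thought, payload = split_thought('<think>plan the call</think>{"name": "f", "arguments": {}}')
assert thought == "plan the call"
assert payload == '{"name": "f", "arguments": {}}'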
@dataclass
......@@ -135,7 +138,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,10 +22,17 @@ from datasets import DatasetDict, load_dataset, load_from_disk
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
from .aligner import align_dataset
from .converter import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .processor import (
FeedbackDatasetProcessor,
PackedSupervisedDatasetProcessor,
PairwiseDatasetProcessor,
PretrainDatasetProcessor,
SupervisedDatasetProcessor,
UnsupervisedDatasetProcessor,
)
if TYPE_CHECKING:
......@@ -35,6 +42,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, ModelArguments
from .data_utils import DatasetModule
from .parser import DatasetAttr
from .processor import DatasetProcessor
from .template import Template
......@@ -156,21 +164,67 @@ def _get_merged_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
merge: bool = True,
) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
r"""
Gets the merged datasets in the standard format.
Returns the merged datasets in the standard format.
"""
if dataset_names is None:
return None
datasets = []
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
datasets = {}
for dataset_name, dataset_attr in zip(dataset_names, get_dataset_list(dataset_names, data_args.dataset_dir)):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
if merge:
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
else:
return datasets
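A toy sketch of the new `merge` switch, with dataset objects stubbed as strings: `merge=True` reproduces the old behavior (one concatenated dataset), while `merge=False` keeps each dataset addressable by name for per-dataset evaluation.

def get_merged(dataset_names, merge=True):
    datasets = {name: f"<dataset:{name}>" for name in dataset_names}
    if merge:
        return list(datasets.values())  # stand-in for merge_dataset(...)
    return datasets

assert get_merged(["alpaca", "gsm8k"], merge=False) == {
    "alpaca": "<dataset:alpaca>",
    "gsm8k": "<dataset:gsm8k>",
}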
def _get_dataset_processor(
data_args: "DataArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> "DatasetProcessor":
r"""
Returns the corresponding dataset processor.
"""
if stage == "pt":
dataset_processor_class = PretrainDatasetProcessor
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
return TypedSequence.__init__(
self,
data,
type=kwargs.pop("type", None),
try_type=kwargs.pop("try_type", None),
optimized_int_type=kwargs.pop("optimized_int_type", None),
)
OptimizedTypedSequence.__init__ = __init__
dataset_processor_class = PackedSupervisedDatasetProcessor
else:
dataset_processor_class = SupervisedDatasetProcessor
elif stage == "rm":
dataset_processor_class = PairwiseDatasetProcessor
elif stage == "kto":
dataset_processor_class = FeedbackDatasetProcessor
else:
dataset_processor_class = UnsupervisedDatasetProcessor
return merge_dataset(datasets, data_args, seed=training_args.seed)
return dataset_processor_class(template=template, tokenizer=tokenizer, processor=processor, data_args=data_args)
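The `neat_packing` branch above monkey-patches OptimizedTypedSequence.__init__ so the `datasets` arrow writer stops downcasting integer columns and packed attention masks keep a wide integer type. A toy demonstration of that patching pattern on a dummy class (not the real OptimizedTypedSequence semantics):

class Base:
    def __init__(self, data, type=None):
        self.data, self.type = data, type

class Optimized(Base):
    def __init__(self, data, **kwargs):
        super().__init__(data, type=kwargs.pop("type", "int8"))  # aggressive default

def patched_init(self, data, **kwargs):
    Base.__init__(self, data, type=kwargs.pop("type", None))  # keep the caller's type

Optimized.__init__ = patched_init
assert Optimized([1, 2, 3]).type is None  # the downcasting default is gone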
def _get_preprocessed_dataset(
......@@ -189,7 +243,7 @@ def _get_preprocessed_dataset(
if dataset is None:
return None
preprocess_func, print_function = get_preprocess_and_print_func(
dataset_processor = _get_dataset_processor(
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
)
column_names = list(next(iter(dataset)).keys())
......@@ -202,7 +256,7 @@ def _get_preprocessed_dataset(
)
dataset = dataset.map(
preprocess_func,
dataset_processor.preprocess_dataset,
batched=True,
batch_size=data_args.preprocessing_batch_size,
remove_columns=column_names,
......@@ -212,7 +266,7 @@ def _get_preprocessed_dataset(
if training_args.should_log:
try:
print("eval example:" if is_eval else "training example:")
print_function(next(iter(dataset)))
dataset_processor.print_data_example(next(iter(dataset)))
except StopIteration:
if stage == "pt":
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
......@@ -234,7 +288,7 @@ def get_dataset(
r"""
Gets the train dataset and optionally the evaluation dataset.
"""
# Load tokenized dataset
# Load tokenized dataset if path exists
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
......@@ -249,7 +303,7 @@ def get_dataset(
if "validation" in tokenized_data:
dataset_module["eval_dataset"] = tokenized_data["validation"]
else: # Dataset
else: # single dataset
dataset_module["train_dataset"] = tokenized_data
if data_args.streaming:
......@@ -263,15 +317,23 @@ def get_dataset(
# Load and preprocess dataset
with training_args.main_process_first(desc="load dataset"):
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(
data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
)
with training_args.main_process_first(desc="pre-process dataset"):
dataset = _get_preprocessed_dataset(
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
eval_dataset = _get_preprocessed_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
if isinstance(eval_dataset, dict):
for eval_name, eval_data in eval_dataset.items():
eval_dataset[eval_name] = _get_preprocessed_dataset(
eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
else:
eval_dataset = _get_preprocessed_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
if data_args.val_size > 1e-6:
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
......@@ -284,17 +346,20 @@ def get_dataset(
dataset_dict["train"] = dataset
if eval_dataset is not None:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
dataset_dict["validation"] = eval_dataset
dataset_dict["validation"] = eval_dataset
dataset_dict = DatasetDict(dataset_dict)
if data_args.tokenized_path is not None:
if data_args.tokenized_path is not None: # save tokenized dataset to disk and exit
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)
......@@ -305,5 +370,13 @@ def get_dataset(
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
else:
eval_dataset = {}
for key in dataset_dict.keys():
if key.startswith("validation_"):
eval_dataset[key[len("validation_") :]] = dataset_dict[key]
if len(eval_dataset):
dataset_module["eval_dataset"] = eval_dataset
return dataset_module
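Multiple evaluation datasets round-trip through the DatasetDict as `validation_<name>` keys and are unpacked back into a dict here; a small sketch of that naming round-trip:

splits = {"train": "...", "validation_alpaca": "...", "validation_gsm8k": "..."}
eval_dataset = {
    key[len("validation_"):]: data
    for key, data in splits.items()
    if key.startswith("validation_")
}
assert sorted(eval_dataset) == ["alpaca", "gsm8k"]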
import inspect
import math
import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import (
is_librosa_available,
is_pillow_available,
is_pyav_available,
is_transformers_version_greater_than,
)
if is_librosa_available():
import librosa
if is_pillow_available():
......@@ -31,7 +42,9 @@ if is_transformers_version_greater_than("4.45.0"):
if TYPE_CHECKING:
from av.stream import Stream
from numpy.typing import NDArray
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.image_processing_utils import BaseImageProcessor
class EncodedImage(TypedDict):
......@@ -40,6 +53,7 @@ if TYPE_CHECKING:
ImageInput = Union[str, bytes, EncodedImage, ImageObject]
VideoInput = str
AudioInput = Union[str, NDArray]
def _get_paligemma_token_type_ids(
......@@ -59,20 +73,25 @@ def _get_paligemma_token_type_ids(
return batch_token_type_ids
class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
self.expand_mm_tokens = True
@dataclass
class MMPluginMixin:
image_token: Optional[str]
video_token: Optional[str]
audio_token: Optional[str]
expand_mm_tokens: bool = True
def _validate_input(
self,
processor: Optional["ProcessorMixin"],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> None:
r"""
Validates whether this model accepts the input modalities.
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
if len(images) != 0 and self.image_token is None:
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
......@@ -83,31 +102,54 @@ class BasePlugin:
"This model does not support video input. Please check whether the correct `template` is used."
)
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
if len(audios) != 0 and self.audio_token is None:
raise ValueError(
"This model does not support audio input. Please check whether the correct `template` is used."
)
if self.image_token is not None and processor is None:
raise ValueError("Processor was not found, please check and update your processor config.")
if self.image_token is not None and image_processor is None:
raise ValueError("Image processor was not found, please check and update your processor config.")
if self.audio_token is not None and feature_extractor is None:
raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
def _preprocess_image(
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
) -> "ImageObject":
r"""
Pre-processes a single image.
"""
image_resolution: int = kwargs.get("image_resolution")
if (image.width * image.height) > image_resolution:
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
if (image.width * image.height) > image_max_pixels:
resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.NEAREST)
image = image.resize((width, height))
if (image.width * image.height) < image_min_pixels:
resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image
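Worked numbers for the new pixel clamps, assuming the defaults used in `_get_mm_inputs` below (768 * 768 max, 32 * 32 min):

import math

def clamp_resolution(width, height, max_pixels=768 * 768, min_pixels=32 * 32):
    if width * height > max_pixels:
        factor = math.sqrt(max_pixels / (width * height))
        width, height = int(width * factor), int(height * factor)
    if width * height < min_pixels:
        factor = math.sqrt(min_pixels / (width * height))
        width, height = int(width * factor), int(height * factor)
    return width, height

assert clamp_resolution(4096, 3072) == (886, 665)  # scaled down to <= 768 * 768
assert clamp_resolution(20, 20) == (32, 32)        # scaled up to >= 32 * 32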
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
def _get_video_sample_indices(
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
) -> List[int]:
r"""
Computes video sample frames according to fps.
Computes video sample indices according to fps.
"""
video_fps: float = kwargs.get("video_fps")
video_maxlen: int = kwargs.get("video_maxlen")
total_frames = video_stream.frames
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
if total_frames == 0: # infinite video
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
sample_frames = min(total_frames, video_maxlen, sample_frames)
return math.floor(sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
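A worked example of the index computation, assuming a 10-second clip at 30 fps with video_fps=2.0 and video_maxlen=128:

import math
import numpy as np

total_frames = 300                            # video_stream.frames
duration_s = 10.0                             # duration * time_base
sample_frames = math.floor(duration_s * 2.0)  # 20 frames at 2.0 fps
sample_frames = min(total_frames, 128, sample_frames)
indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
assert list(indices[:5]) == [0, 15, 31, 47, 62]  # evenly spaced over the clip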
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
r"""
......@@ -126,7 +168,7 @@ class BasePlugin:
image = Image.open(image["path"])
if not isinstance(image, ImageObject):
raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
raise ValueError(f"Expect input is a list of images, but got {type(image)}.")
results.append(self._preprocess_image(image, **kwargs))
......@@ -140,9 +182,7 @@ class BasePlugin:
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
......@@ -154,10 +194,27 @@ class BasePlugin:
return results
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
r"""
Regularizes audios to avoid errors, including reading and resampling.
"""
results = []
for audio in audios:
if isinstance(audio, str):
audio = librosa.load(audio, sr=sampling_rate)[0]
if not isinstance(audio, np.ndarray):
raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
results.append(audio)
return results
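Passing an explicit `sr` to librosa.load makes it decode and resample in one step, so every clip comes back as a mono float array at the feature extractor's rate. A minimal usage sketch (assumes librosa is installed and a "speech.wav" file exists):

import librosa

waveform, sr = librosa.load("speech.wav", sr=16000)  # decoded, mono, resampled
assert sr == 16000 and waveform.ndim == 1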
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
......@@ -172,47 +229,65 @@ class BasePlugin:
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
input_dict = {"images": None} # default key
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
input_dict["images"] = images
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
input_dict["videos"] = videos
mm_inputs = {}
if image_processor != video_processor:
if input_dict.get("images") is not None:
mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
if input_dict.get("videos") is not None:
mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl)
mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
else: # for llava_next_video
mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0:
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
)
mm_inputs.update(
feature_extractor(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
return_attention_mask=True,
padding="max_length",
return_tensors="pt",
)
)
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
return mm_inputs
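A toy illustration of why the audio mask is renamed above: when the collator later merges the text features with the extractor output, an unrenamed `attention_mask` key would silently overwrite the text mask.

text_batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
audio_batch = {"input_features": "...", "attention_mask": [[1] * 3000]}
audio_batch["feature_attention_mask"] = audio_batch.pop("attention_mask")
merged = {**text_batch, **audio_batch}  # no key collision now
assert merged["attention_mask"] == [[1, 1, 1]]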
@dataclass
class BasePlugin(MMPluginMixin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
r"""
Pre-processes input messages before tokenization for VLMs.
"""
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
return messages
def process_token_ids(
......@@ -221,21 +296,24 @@ class BasePlugin:
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
r"""
Pre-processes token ids after tokenization for VLMs.
"""
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
......@@ -247,13 +325,15 @@ class BasePlugin:
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
audlens: number of audios in each sample, shape (batch_size,)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images, videos and audios
"""
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
return {}
@dataclass
class LlavaPlugin(BasePlugin):
@override
def process_messages(
......@@ -261,9 +341,10 @@ class LlavaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
......@@ -285,15 +366,18 @@ class LlavaPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaNextPlugin(BasePlugin):
@override
def process_messages(
......@@ -301,16 +385,15 @@ class LlavaNextPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
......@@ -319,7 +402,7 @@ class LlavaNextPlugin(BasePlugin):
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
......@@ -339,15 +422,18 @@ class LlavaNextPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaNextVideoPlugin(BasePlugin):
@override
def process_messages(
......@@ -355,14 +441,15 @@ class LlavaNextVideoPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
......@@ -370,7 +457,7 @@ class LlavaNextVideoPlugin(BasePlugin):
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
......@@ -381,12 +468,15 @@ class LlavaNextVideoPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", self.image_token)
if "pixel_values_videos" in mm_inputs:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
if self.expand_mm_tokens:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
else:
video_seqlen = 1
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
......@@ -408,15 +498,18 @@ class LlavaNextVideoPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
......@@ -424,26 +517,27 @@ class MiniCPMVPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
audio_inputs = {}
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
if len(videos) != 0:
max_slice_nums = 2
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, processor)
mm_inputs = self._get_mm_inputs([], videos, [], processor)
else:
max_slice_nums = image_processor.max_slice_nums
use_image_id = image_processor.use_image_id
for message in messages:
for i, message in enumerate(messages):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
......@@ -454,15 +548,24 @@ class MiniCPMVPlugin(BasePlugin):
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
while AUDIO_PLACEHOLDER in content:
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
"{{audio}}", "(<audio>./</audio>)"
)
if num_image_tokens > 0:
mm_inputs = self._get_mm_inputs(images, [], processor)
mm_inputs = self._get_mm_inputs(images, [], [], processor)
if num_audio_tokens > 0:
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
if mm_inputs:
pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"]
idx = 0
for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
......@@ -473,9 +576,26 @@ class MiniCPMVPlugin(BasePlugin):
final_text
+ text_chunks[i]
+ image_processor.get_slice_image_placeholder(
image_sizes[0][i], i, max_slice_nums, use_image_id
image_sizes[0][idx], idx, max_slice_nums, use_image_id
)
)
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if audio_inputs:
pattern = "(<audio>./</audio>)"
idx = 0
for index, message in enumerate(messages):
text = message["content"]
audio_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(audio_tags)):
audio_placeholder = audio_inputs["audio_phs"][0][idx]
final_text = final_text + text_chunks[i] + audio_placeholder
idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
......@@ -486,6 +606,9 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
return messages
@override
......@@ -493,15 +616,18 @@ class MiniCPMVPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
......@@ -521,13 +647,39 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)
if len(audios) != 0:
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
)
if "valid_audio_nums_ls" in kwargs:
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
audios_ls = []
idx = 0
for valid_audio_nums in valid_audio_nums_ls:
audios_ls.append(audios[idx : idx + valid_audio_nums])
idx += valid_audio_nums
else:
audios_ls = [audios]
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
audios_ls,
chunk_input=True,
sampling_rate=16000,
)
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
if kwargs.get("ret_phs", False):
mm_inputs.update({"audio_phs": audio_phs})
return mm_inputs
@override
......@@ -535,15 +687,18 @@ class MiniCPMVPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
# image bound
image_bounds_list = []
valid_image_nums_ls = []
for input_ids in batch_ids:
for i, input_ids in enumerate(batch_ids):
input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id
......@@ -552,21 +707,51 @@ class MiniCPMVPlugin(BasePlugin):
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
valid_image_nums_ls.append(valid_image_nums)
valid_image_nums_ls.append(imglens[i])
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
image_start_tokens.unsqueeze(-1),
image_end_tokens.unsqueeze(-1),
]
)
image_bounds_list.append(image_bounds)
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
if "tgt_sizes" not in mm_inputs:
dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
mm_inputs.update({"image_bound": image_bounds_list})
if len(audios) > 0:
# audio bound
audio_bounds_ls = []
spk_bounds_ls = []
valid_audio_nums_ls = []
for input_ids, audiolen in zip(batch_ids, audlens):
input_ids_ = torch.tensor(input_ids)
audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
assert len(audio_start_idx) == len(audio_end_idx)
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
audio_bounds_ls.append(audio_bounds)
valid_audio_nums_ls.append(audiolen)
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
assert len(spk_start_idx) == len(spk_end_idx)
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
spk_bounds_ls.append(spk_bounds)
audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
mm_inputs.update(audio_inputs)
mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
return mm_inputs
@dataclass
class MllamaPlugin(BasePlugin):
@override
def process_messages(
......@@ -574,9 +759,10 @@ class MllamaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
......@@ -594,8 +780,9 @@ class MllamaPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
**kwargs,
imglens: List[int],
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
......@@ -609,43 +796,56 @@ class MllamaPlugin(BasePlugin):
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
imglens: List[int] = kwargs["imglens"]
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
images = images[image_length:]
mm_inputs = {}
if len(images) > 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
images = images[image_length:]
mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
return image_processor(batch_images, return_tensors="pt")
return mm_inputs
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = torch.from_numpy(
convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
) # shape: (batch_size, length, max_num_images, max_num_tiles)
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs:
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = torch.from_numpy(
convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
) # shape: (batch_size, length, max_num_images, max_num_tiles)
return mm_inputs
@dataclass
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
......@@ -653,9 +853,10 @@ class PaliGemmaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
......@@ -678,10 +879,11 @@ class PaliGemmaPlugin(BasePlugin):
labels: Optional[List[int]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
......@@ -696,18 +898,21 @@ class PaliGemmaPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
@dataclass
class PixtralPlugin(BasePlugin):
@override
def process_messages(
......@@ -715,9 +920,10 @@ class PixtralPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
......@@ -725,17 +931,15 @@ class PixtralPlugin(BasePlugin):
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_input_sizes = mm_inputs.get("image_sizes", None)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")
if self.expand_mm_tokens:
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
height, width = next(image_sizes)
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
......@@ -760,47 +964,105 @@ class PixtralPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if mm_inputs.get("pixel_values"):
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_sizes", None)
return mm_inputs
class Qwen2vlPlugin(BasePlugin):
@dataclass
class Qwen2AudioPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs([], [], audios, processor)
if "feature_attention_mask" in mm_inputs:
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
num_audio_tokens = 0
for message in messages:
content = message["content"]
while AUDIO_PLACEHOLDER in content:
if self.expand_mm_tokens:
audio_length = audio_lengths.pop(0)
input_length = (audio_length - 1) // 2 + 1
audio_seqlen = (input_length - 2) // 2 + 1
else:
audio_seqlen = 1
content = content.replace(
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
)
num_audio_tokens += 1
message["content"] = content
if len(audios) != num_audio_tokens:
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
return messages
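Worked numbers for the placeholder length above, assuming a fully padded 30-second clip (3000 mel frames, the padding used by Whisper-style feature extractors): two stride-2 stages reduce 3000 frames to 750 audio tokens.

audio_length = 3000                          # one feature_attention_mask.sum(-1) entry
input_length = (audio_length - 1) // 2 + 1   # after the first stride-2 stage: 1500
audio_seqlen = (input_length - 2) // 2 + 1   # after the second stride-2 stage: 750
assert (input_length, audio_seqlen) == (1500, 750)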
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class Qwen2VLPlugin(BasePlugin):
@override
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
image = super()._preprocess_image(image, **kwargs)
if min(image.width, image.height) < 28:
width, height = max(image.width, 28), max(image.height, 28)
image = image.resize((width, height), resample=Image.NEAREST)
image = image.resize((width, height))
if image.width / image.height > 200:
width, height = image.height * 180, image.height
image = image.resize((width, height), resample=Image.NEAREST)
image = image.resize((width, height))
if image.height / image.width > 200:
width, height = image.width, image.width * 180
image = image.resize((width, height), resample=Image.NEAREST)
image = image.resize((width, height))
return image
@override
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
results = []
def _regularize_videos(
self, videos: Sequence["VideoInput"], **kwargs
) -> Tuple[List[List["ImageObject"]], List[float]]:
results, fps_per_video = [], []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
......@@ -812,8 +1074,43 @@ class Qwen2vlPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
if video_stream.duration is None:
fps_per_video.append(2.0)
else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
return results
return results, fps_per_video
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
videos, fps_per_video = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
mm_inputs["fps_per_video"] = fps_per_video
return mm_inputs
@override
def process_messages(
......@@ -821,17 +1118,23 @@ class Qwen2vlPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
......@@ -869,15 +1172,24 @@ class Qwen2vlPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
return mm_inputs
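A quick check of the `second_per_grid_ts` arithmetic, assuming Qwen2-VL's default temporal_patch_size of 2 (an assumption; the real value comes from the image processor config):

temporal_patch_size = 2
assert temporal_patch_size / 2.0 == 1.0  # sampled at 2.0 fps: 1.0 s per temporal grid step
assert temporal_patch_size / 4.0 == 0.5  # sampled at 4.0 fps: 0.5 s per temporal grid step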
@dataclass
class VideoLlavaPlugin(BasePlugin):
@override
def process_messages(
......@@ -885,12 +1197,13 @@ class VideoLlavaPlugin(BasePlugin):
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_frames = 0
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
......@@ -907,7 +1220,7 @@ class VideoLlavaPlugin(BasePlugin):
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
......@@ -938,13 +1251,15 @@ class VideoLlavaPlugin(BasePlugin):
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
PLUGINS = {
......@@ -956,18 +1271,32 @@ PLUGINS = {
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin,
"qwen2_audio": Qwen2AudioPlugin,
"qwen2_vl": Qwen2VLPlugin,
"video_llava": VideoLlavaPlugin,
}
def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
r"""
Registers a multimodal plugin.
"""
if name in PLUGINS:
raise ValueError(f"Multimodal plugin {name} already exists.")
PLUGINS[name] = plugin_class
def get_mm_plugin(
name: str,
image_token: Optional[str] = None,
video_token: Optional[str] = None,
audio_token: Optional[str] = None,
) -> "BasePlugin":
plugin_class = PLUGINS.get(name, None)
if plugin_class is None:
r"""
Gets the plugin for multimodal inputs.
"""
if name not in PLUGINS:
raise ValueError(f"Multimodal plugin `{name}` not found.")
return plugin_class(image_token, video_token)
return PLUGINS[name](image_token, video_token, audio_token)
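A minimal usage sketch of the new registration hook; the plugin name and class below are hypothetical, and the import path assumes the standard `llamafactory` package layout:

from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin

class MyAudioPlugin(BasePlugin):
    pass

register_mm_plugin("my_audio", MyAudioPlugin)
plugin = get_mm_plugin("my_audio", audio_token="<audio>")
assert plugin.image_token is None and plugin.audio_token == "<audio>"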
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -44,7 +44,8 @@ class DatasetAttr:
tools: Optional[str] = None
images: Optional[str] = None
videos: Optional[str] = None
# rlhf columns
audios: Optional[str] = None
# dpo columns
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
......@@ -70,6 +71,26 @@ class DatasetAttr:
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
def join(self, attr: Dict[str, Any]) -> None:
self.set_attr("formatting", attr, default="alpaca")
self.set_attr("ranking", attr, default=False)
self.set_attr("subset", attr)
self.set_attr("split", attr, default="train")
self.set_attr("folder", attr)
self.set_attr("num_samples", attr)
if "columns" in attr:
column_names = ["prompt", "query", "response", "history", "messages", "system", "tools"]
column_names += ["images", "videos", "audios", "chosen", "rejected", "kto_tag"]
for column_name in column_names:
self.set_attr(column_name, attr["columns"])
if "tags" in attr:
tag_names = ["role_tag", "content_tag"]
tag_names += ["user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]
for tag in tag_names:
self.set_attr(tag, attr["tags"])
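A sketch of how `join` consumes one entry from dataset_info.json (the entry below is hypothetical):

info = {
    "file_name": "my_data.json",
    "formatting": "sharegpt",
    "columns": {"messages": "conversations", "images": "images"},
    "tags": {"role_tag": "from", "content_tag": "value"},
}
attr = DatasetAttr("file", dataset_name=info["file_name"])
attr.join(info)
assert attr.formatting == "sharegpt" and attr.messages == "conversations"
assert attr.role_tag == "from" and attr.split == "train"  # default split applied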
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
r"""
......@@ -127,36 +148,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("split", dataset_info[name], default="train")
dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("num_samples", dataset_info[name])
if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
tag_names = (
"role_tag",
"content_tag",
"user_tag",
"assistant_tag",
"observation_tag",
"function_tag",
"system_tag",
)
for tag in tag_names:
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_attr.join(dataset_info[name])
dataset_list.append(dataset_attr)
return dataset_list
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
data_args: "DataArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(
preprocess_pretrain_dataset,
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
def __init__(self, data, **kwargs):
return TypedSequence.__init__(
self,
data,
type=kwargs.pop("type", None),
try_type=kwargs.pop("try_type", None),
optimized_int_type=kwargs.pop("optimized_int_type", None),
)
OptimizedTypedSequence.__init__ = __init__
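# Editorial note: the stock `OptimizedTypedSequence.__init__` picks a narrow
# integer type by column name (e.g. int8 for `attention_mask`); the override
# above drops that per-column default, so the per-sequence indices written by
# neat packing survive in a wider integer type.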
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
else:
preprocess_func = partial(
preprocess_supervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
elif stage == "rm":
preprocess_func = partial(
preprocess_pairwise_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
elif stage == "kto":
preprocess_func = partial(
preprocess_feedback_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
return preprocess_func, print_function
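# Illustrative usage (assumed, not from the file): the returned pair is
# typically passed to `datasets.Dataset.map` in batched mode, e.g.
#
#   preprocess_func, print_function = get_preprocess_and_print_func(
#       data_args, stage="sft", template=template, tokenizer=tokenizer, processor=None
#   )
#   dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names)
#   print_function(next(iter(dataset)))
#
# where `column_names` stands for the original dataset columns to drop.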
from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
from .processor_utils import DatasetProcessor
from .supervised import PackedSupervisedDatasetProcessor, SupervisedDatasetProcessor
from .unsupervised import UnsupervisedDatasetProcessor
__all__ = [
"DatasetProcessor",
"FeedbackDatasetProcessor",
"PairwiseDatasetProcessor",
"PretrainDatasetProcessor",
"PackedSupervisedDatasetProcessor",
"SupervisedDatasetProcessor",
"UnsupervisedDatasetProcessor",
]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class FeedbackDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
kl_messages = self.template.mm_plugin.process_messages(kl_messages, images, videos, audios, self.processor)
prompt_ids, response_ids = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = self.template.encode_oneturn(self.tokenizer, kl_messages, system, tools)
if self.template.efficient_eos:
response_ids += [self.tokenizer.eos_token_id]
kl_response_ids += [self.tokenizer.eos_token_id]
prompt_ids, _ = self.template.mm_plugin.process_token_ids(
prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
kl_prompt_ids, _ = self.template.mm_plugin.process_token_ids(
kl_prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), self.data_args.cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(
len(kl_prompt_ids), len(kl_response_ids), self.data_args.cutoff_len
)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
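# Editorial note on the KL pairing above: reversing the batch pairs prompt i
# with response n - 1 - i, e.g. responses [r0, r1, r2] become [r2, r1, r0];
# the middle element of an odd-sized batch keeps its own response, while all
# other prompts get an unrelated one for estimating the KL reference term.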
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class PairwiseDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = self.template.mm_plugin.process_messages(
prompt + [response[0]], images, videos, audios, self.processor
)
rejected_messages = self.template.mm_plugin.process_messages(
prompt + [response[1]], images, videos, audios, self.processor
)
prompt_ids, chosen_ids = self.template.encode_oneturn(self.tokenizer, chosen_messages, system, tools)
_, rejected_ids = self.template.encode_oneturn(self.tokenizer, rejected_messages, system, tools)
if self.template.efficient_eos:
chosen_ids += [self.tokenizer.eos_token_id]
rejected_ids += [self.tokenizer.eos_token_id]
prompt_ids, _ = self.template.mm_plugin.process_token_ids(
prompt_ids, None, images, videos, audios, self.tokenizer, self.processor
)
# consider that the response is more important than the prompt
source_len, target_len = infer_seqlen(
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.data_args.cutoff_len
)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
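# Editorial note: `infer_seqlen` above receives the longer of the two
# candidate responses so that a single `source_len` applies to both; the
# chosen and rejected sequences therefore share an identical, equally
# truncated prompt prefix.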
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print(
"chosen_inputs:\n{}".format(self.tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False))
)
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print(f"chosen_labels:\n{self.tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print(
"rejected_inputs:\n{}".format(
self.tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)
)
)
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print(f"rejected_labels:\n{self.tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from itertools import chain
from typing import Any, Dict, List
from .processor_utils import DatasetProcessor
@dataclass
class PretrainDatasetProcessor(DatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not self.data_args.packing:
if getattr(self.tokenizer, "add_bos_token", False):
text_examples = [self.tokenizer.bos_token + example for example in text_examples]
result = self.tokenizer(
text_examples, add_special_tokens=False, truncation=True, max_length=self.data_args.cutoff_len
)
else:
tokenized_examples = self.tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = self.data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if getattr(self.tokenizer, "add_bos_token", False):
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = self.tokenizer.bos_token_id
return result
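# Worked example (editorial): with `cutoff_len = 4` and 10 concatenated token
# ids, `total_length` is rounded down to 8 and the loop above yields two
# blocks, ids[0:4] and ids[4:8]; the trailing two tokens are dropped.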
def print_data_example(self, example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,42 @@
# limitations under the License.
import bisect
from typing import List, Sequence, Tuple
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..template import Template
@dataclass
class DatasetProcessor(ABC):
r"""
A class for data processors.
"""
template: "Template"
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
data_args: "DataArguments"
@abstractmethod
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
r"""
Builds model inputs from the examples.
"""
...
@abstractmethod
def print_data_example(self, example: Dict[str, List[int]]) -> None:
r"""
Prints a data example to stdout.
"""
...
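# Minimal sketch (editorial, assumed usage) of a concrete processor built on
# the ABC above:
#
#   @dataclass
#   class MyProcessor(DatasetProcessor):
#       def preprocess_dataset(self, examples):
#           ids = [self.tokenizer.encode(m[0]["content"]) for m in examples["_prompt"]]
#           return {"input_ids": ids, "attention_mask": [[1] * len(x) for x in ids]}
#
#       def print_data_example(self, example):
#           print(self.tokenizer.decode(example["input_ids"]))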
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor
)
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= self.data_args.cutoff_len:
break
source_len, target_len = infer_seqlen(
len(source_ids), len(target_ids), self.data_args.cutoff_len - total_length
)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if self.data_args.train_on_prompt:
source_label = source_ids
elif self.template.efficient_eos:
source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if self.data_args.mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if self.data_args.mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if self.template.efficient_eos:
input_ids += [self.tokenizer.eos_token_id]
labels += [self.tokenizer.eos_token_id]
return input_ids, labels
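# Editorial note: with `mask_history` the turn pairs were reversed above, so
# prepending each encoded pair (source + target in front of `input_ids`)
# restores chronological order, while the `cutoff_len` budget is spent on the
# most recent turns first.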
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{self.tokenizer.decode(valid_labels, skip_special_tokens=False)}")
@dataclass
class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
length = len(input_ids)
if length > self.data_args.cutoff_len:
logger.warning_rank0(f"Dropped lengthy example with length {length} > {self.data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
batch_audios.append(examples["_audios"][i] or [])
valid_num += 1
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn dropping the attention mask
pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
return model_inputs
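# Illustration (editorial): packing sequences of lengths 3 and 2 with
# `cutoff_len = 8` and `neat_packing` enabled yields
#   attention_mask = [1, 1, 1, 2, 2, 0, 0, 0, 0]
# where each packed sequence carries its own index (starting at 1) and the
# padding is 0, letting the collator rebuild block-diagonal attention;
# without neat packing the mask is all ones.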
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ..data_utils import Role
from .processor_utils import DatasetProcessor, infer_seqlen
if TYPE_CHECKING:
from ..mm_plugin import AudioInput, ImageInput, VideoInput
logger = logging.get_logger(__name__)
class UnsupervisedDatasetProcessor(DatasetProcessor):
def _encode_data_example(
self,
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = self.template.mm_plugin.process_messages(messages, images, videos, audios, self.processor)
input_ids, labels = self.template.encode_oneturn(self.tokenizer, messages, system, tools)
if self.template.efficient_eos:
labels += [self.tokenizer.eos_token_id]
input_ids, _ = self.template.mm_plugin.process_token_ids(
input_ids, None, images, videos, audios, self.tokenizer, self.processor
)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), self.data_args.cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = self._encode_data_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
model_inputs["audios"].append(examples["_audios"][i])
return model_inputs
def print_data_example(self, example: Dict[str, List[int]]) -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(self.tokenizer.decode(example["labels"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_feedback_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
kl_response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
else: # undesired example
kto_tag = False
messages = prompt + [response[1]]
if kl_response[0]["content"]:
kl_messages = prompt + [kl_response[0]]
else:
kl_messages = prompt + [kl_response[1]]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
kl_response_ids = kl_response_ids[:kl_target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * source_len + response_ids
kl_input_ids = kl_prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["_response"][::-1]
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
kl_response=kl_response[i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs
from typing import TYPE_CHECKING, List, Sequence
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
# process visual inputs (currently only supports a single image)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
# get paligemma token type ids for computing loss
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
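# Example (editorial): with `image_seq_length = 4` and `input_len = 7`, the
# function returns [0, 0, 0, 0, 1, 1, 1]: image placeholder positions are
# marked 0 and text positions 1 for loss computation.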
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_pairwise_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
# consider that the response is more important than the prompt
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len]
chosen_ids = chosen_ids[:target_len]
rejected_ids = rejected_ids[:target_len]
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["chosen_input_ids"].append(chosen_input_ids)
model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
model_inputs["chosen_labels"].append(chosen_labels)
model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from ...hparams import DataArguments
def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
# build grouped texts with format `X1 X2 X3 ...` if packing is enabled
eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
if not data_args.packing:
if data_args.template == "gemma":
text_examples = [tokenizer.bos_token + example for example in text_examples]
result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
else:
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
if data_args.template == "gemma":
for i in range(len(result["input_ids"])):
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from .processor_utils import greedy_knapsack, infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_supervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
train_on_prompt: bool,
mask_history: bool,
) -> Tuple[List[int], List[int]]:
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
if mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
if total_length >= cutoff_len:
break
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
source_ids = source_ids[:source_len]
target_ids = target_ids[:target_len]
total_length += source_len + target_len
if train_on_prompt:
source_label = source_ids
elif template.efficient_eos:
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
source_label = [IGNORE_INDEX] * source_len
if mask_history and turn_idx != 0: # train on the last turn only
target_label = [IGNORE_INDEX] * target_len
else:
target_label = target_ids
if mask_history: # reversed sequences
input_ids = source_ids + target_ids + input_ids
labels = source_label + target_label + labels
else:
input_ids += source_ids + target_ids
labels += source_label + target_label
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
return input_ids, labels
def preprocess_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# TODO: use `position_ids` to achieve packing
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
lengths = []
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
train_on_prompt=data_args.train_on_prompt,
mask_history=data_args.mask_history,
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)
batch_input_ids.append(input_ids)
batch_labels.append(labels)
batch_images.append(examples["_images"][i] or [])
batch_videos.append(examples["_videos"][i] or [])
valid_num += 1
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos = [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
if data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < data_args.cutoff_len:
pad_length = data_args.cutoff_len - len(packed_input_ids)
packed_input_ids += [tokenizer.pad_token_id] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if data_args.neat_packing:
packed_attention_masks += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if len(packed_input_ids) != data_args.cutoff_len:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments
from ..mm_plugin import ImageInput, VideoInput
from ..template import Template
logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
prompt: Sequence[Dict[str, str]],
response: Sequence[Dict[str, str]],
system: Optional[str],
tools: Optional[str],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int]]:
if len(response) == 1:
messages = prompt + response
else:
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
if template.efficient_eos:
labels += [tokenizer.eos_token_id]
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len]
labels = labels[:target_len]
return input_ids, labels
def preprocess_unsupervised_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_unsupervised_example(
prompt=examples["_prompt"][i],
response=examples["_response"][i],
system=examples["_system"][i],
tools=examples["_tools"][i],
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
template=template,
tokenizer=tokenizer,
processor=processor,
cutoff_len=data_args.cutoff_len,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["images"].append(examples["_images"][i])
model_inputs["videos"].append(examples["_videos"][i])
return model_inputs
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import override
......@@ -47,6 +47,7 @@ class Template:
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
thought_words: Tuple[str, str]
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
......@@ -67,8 +68,8 @@ class Template:
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
answer_ids = encoded_messages[-1]
return prompt_ids, answer_ids
response_ids = encoded_messages[-1]
return prompt_ids, response_ids
def encode_multiturn(
self,
......@@ -99,6 +100,27 @@ class Template:
return list(stop_token_ids)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
r"""
Converts elements to token ids.
"""
token_ids = []
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
return token_ids
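# Example (editorial): for elements like
#   ["<|user|>\n", {"token": "<image>"}, {"eos_token"}]
# the string is tokenized with `add_special_tokens=False`, the dict slot is
# resolved via `convert_tokens_to_ids`, and the set slot maps to
# `tokenizer.eos_token_id`.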
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
......@@ -109,7 +131,7 @@ class Template:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query        resp
Turn t: sep + query                    resp
Turn t: query                          resp
"""
system = system or self.default_system
encoded_messages = []
......@@ -137,26 +159,179 @@ class Template:
return encoded_messages
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""
Converts elements to token ids.
Adds or replaces the eos token in the tokenizer.
"""
token_ids = []
for elem in elements:
if isinstance(elem, str):
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
return token_ids
if is_added:
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}.")
else:
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}.")
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Adds eos token and pad token to the tokenizer.
"""
stop_words = self.stop_words
if self.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
@staticmethod
def _jinja_escape(content: str) -> str:
r"""
Escapes single quotes in the content.
"""
return content.replace("'", r"\'")
@staticmethod
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
r"""
Converts slots to a jinja template.
"""
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return " + ".join(slot_items)
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
jinja_template = ""
if prefix:
jinja_template += "{{ " + prefix + " }}"
if self.default_system:
jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
"{% if system_message is defined %}{{ " + system + " }}{% endif %}"
"{% for message in loop_messages %}"
"{% set content = message['content'] %}"
"{% if message['role'] == 'user' %}"
"{{ " + user + " }}"
"{% elif message['role'] == 'assistant' %}"
"{{ " + assistant + " }}"
"{% endif %}"
"{% endfor %}"
)
return jinja_template
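# A hedged sketch: with format_user slots ["<user>{{content}}\n<model>"] and
# format_assistant slots ["{{content}}</s>\n"], the message loop of the
# generated template renders as
#   {% if message['role'] == 'user' %}{{ '<user>' + content + '\n<model>' }}
#   {% elif message['role'] == 'assistant' %}{{ content + '</s>\n' }}{% endif %}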
def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
r"""
Replaces the jinja template in the tokenizer.
"""
if tokenizer.chat_template is None or self.replace_jinja_template:
try:
tokenizer.chat_template = self._get_jinja_template(tokenizer)
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
@staticmethod
def _convert_slots_to_ollama(
slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
) -> str:
r"""
Converts slots to ollama template.
"""
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append(slot_pieces[0])
if len(slot_pieces) > 1:
slot_items.append("{{ " + placeholder + " }}")
if slot_pieces[1]:
slot_items.append(slot_pieces[1])
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append(tokenizer.bos_token)
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append(tokenizer.eos_token)
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return "".join(slot_items)
def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama template.
"""
prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
return (
f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
)
def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the ollama modelfile.
TODO: support function calling.
"""
modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
if self.default_system:
modelfile += f'SYSTEM """{self.default_system}"""\n\n'
for stop_token_id in self.get_stop_token_ids(tokenizer):
modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
modelfile += "PARAMETER num_ctx 4096\n"
return modelfile
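# A rough sketch of the generated file for a ChatML-like template (the stop
# words and system prompt here are illustrative, not fixed):
#
#   # ollama modelfile auto-generated by llamafactory
#   FROM .
#   TEMPLATE """{{ if .System }}<|im_start|>system
#   {{ .System }}<|im_end|>
#   {{ end }}{{ range .Messages }}...{{ end }}"""
#   SYSTEM """You are a helpful assistant."""
#   PARAMETER stop "<|im_end|>"
#   PARAMETER num_ctx 4096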
@dataclass
......@@ -169,11 +344,6 @@ class Llama2Template(Template):
system: str,
tools: str,
) -> List[List[int]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
......@@ -201,11 +371,41 @@ class Llama2Template(Template):
return encoded_messages
def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
system_message = self._convert_slots_to_jinja(
self.format_system.apply(), tokenizer, placeholder="system_message"
)
user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
jinja_template = ""
if prefix:
jinja_template += "{{ " + prefix + " }}"
if self.default_system:
jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
"{% for message in loop_messages %}"
"{% if loop.index0 == 0 and system_message is defined %}"
"{% set content = " + system_message + " + message['content'] %}"
"{% else %}{% set content = message['content'] %}{% endif %}"
"{% if message['role'] == 'user' %}"
"{{ " + user_message + " }}"
"{% elif message['role'] == 'assistant' %}"
"{{ " + assistant_message + " }}"
"{% endif %}"
"{% endfor %}"
)
return jinja_template
TEMPLATES: Dict[str, "Template"] = {}
def _register_template(
def register_template(
name: str,
format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None,
......@@ -216,10 +416,12 @@ def _register_template(
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Optional[Sequence[str]] = None,
thought_words: Optional[Tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: Type["Template"] = Template,
) -> None:
r"""
Registers a chat template.
......@@ -234,7 +436,7 @@ def _register_template(
The corresponding code should be:
```
_register_template(
register_template(
name="custom",
format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
......@@ -242,7 +444,9 @@ def _register_template(
)
```
"""
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
if name in TEMPLATES:
raise ValueError(f"Template {name} already exists.")
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=default_slots)
......@@ -259,6 +463,7 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words or [],
thought_words=thought_words or ("<think>", "</think>"),
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
......@@ -266,97 +471,83 @@ def _register_template(
)
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else:
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
return content.replace("'", r"\'")
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
slot_items = []
for slot in slots:
if isinstance(slot, str):
slot_pieces = slot.split("{{content}}")
if slot_pieces[0]:
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
if len(slot_pieces) > 1:
slot_items.append(placeholder)
if slot_pieces[1]:
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
if "bos_token" in slot and tokenizer.bos_token_id is not None:
slot_items.append("'" + tokenizer.bos_token + "'")
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
slot_items.append("'" + tokenizer.eos_token + "'")
elif isinstance(slot, dict):
raise ValueError("Dict is not supported.")
return " + ".join(slot_items)
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
r"""
Returns the jinja template.
Extracts a chat template from the tokenizer.
"""
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
if prefix:
jinja_template += "{{ " + prefix + " }}"
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
jinja_template += (
"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
"{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
def find_diff(short_str: str, long_str: str) -> str:
i, j = 0, 0
diff = ""
while i < len(short_str) and j < len(long_str):
if short_str[i] == long_str[j]:
i += 1
j += 1
else:
diff += long_str[j]
j += 1
return diff
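# e.g. find_diff("ab", "axyb") == "xy". parse_template uses this to recover
# the default system prompt: the extra text that apply_chat_template inserts
# when no system message is supplied.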
prefix = tokenizer.decode(tokenizer.encode(""))
messages = [{"role": "system", "content": "{{content}}"}]
system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]
messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
user_slot_empty_system = user_slot_empty_system[len(prefix) :]
messages = [{"role": "user", "content": "{{content}}"}]
user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
user_slot = user_slot[len(prefix) :]
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
if len(user_slot) > len(user_slot_empty_system):
default_system = find_diff(user_slot_empty_system, user_slot)
sole_system = system_slot.replace("{{content}}", default_system, 1)
user_slot = user_slot[len(sole_system) :]
else: # if default_system is empty, user_slot_empty_system will be longer than user_slot
default_system = ""
return Template(
format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]),
format_function=FunctionFormatter(slots=[assistant_slot], tool_format="default"),
format_observation=StringFormatter(slots=[user_slot]),
format_tools=ToolFormatter(tool_format="default"),
format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
default_system=default_system,
stop_words=[],
thought_words=("<think>", "</think>"),
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="base"),
)
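# A hedged usage sketch (the checkpoint name is illustrative):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   template = parse_template(tokenizer)  # slots probed via apply_chat_template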
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
if not isinstance(template, Llama2Template):
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
jinja_template += "{% for message in loop_messages %}"
jinja_template += "{% set content = message['content'] %}"
if isinstance(template, Llama2Template):
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
jinja_template += "{% endif %}"
jinja_template += "{% if message['role'] == 'user' %}"
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
jinja_template += "{{ " + assistant_message + " }}"
jinja_template += "{% endif %}"
jinja_template += "{% endfor %}"
return jinja_template
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
if isinstance(tokenizer.chat_template, str):
logger.warning_rank0("`template` was not specified, try parsing the chat template from the tokenizer.")
template = parse_template(tokenizer)
else:
logger.warning_rank0("`template` was not specified, use `empty` template.")
template = TEMPLATES["empty"] # placeholder
else:
template = TEMPLATES.get(data_args.template, None)
if template is None:
if data_args.template not in TEMPLATES:
raise ValueError(f"Template {data_args.template} does not exist.")
template = TEMPLATES[data_args.template]
if template.mm_plugin.__class__.__name__ != "BasePlugin":
check_version("transformers>=4.45.0")
......@@ -369,39 +560,12 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
if tokenizer.chat_template is None or template.replace_jinja_template:
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError as e:
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
return template
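# A minimal usage sketch (assumes DataArguments carries the `template` and
# `tool_format` fields referenced above; the checkpoint name is illustrative):
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen"))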
_register_template(
register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
......@@ -412,7 +576,7 @@ _register_template(
)
_register_template(
register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}###"]),
......@@ -425,7 +589,7 @@ _register_template(
)
_register_template(
register_template(
name="atom",
format_user=StringFormatter(
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
......@@ -434,21 +598,31 @@ _register_template(
)
_register_template(
register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
_register_template(
register_template(
name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True,
)
_register_template(
register_template(
name="bailing",
format_user=StringFormatter(slots=["<role>HUMAN</role>{{content}}<role>ASSISTANT</role>"]),
format_system=StringFormatter(slots=["<role>SYSTEM</role>{{content}}"]),
format_observation=StringFormatter(slots=["<role>OBSERVATION</role>{{content}}<role>ASSISTANT</role>"]),
stop_words=["<|endoftext|>"],
efficient_eos=True,
)
register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
......@@ -456,13 +630,13 @@ _register_template(
)
_register_template(
register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
_register_template(
register_template(
name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
......@@ -470,7 +644,7 @@ _register_template(
)
_register_template(
register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
......@@ -478,7 +652,7 @@ _register_template(
)
_register_template(
register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
......@@ -494,7 +668,7 @@ _register_template(
)
_register_template(
register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -507,7 +681,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -520,13 +694,13 @@ _register_template(
)
_register_template(
register_template(
name="codegeex2",
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
_register_template(
register_template(
name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
......@@ -543,7 +717,7 @@ _register_template(
)
_register_template(
register_template(
name="cohere",
format_user=StringFormatter(
slots=[
......@@ -558,7 +732,7 @@ _register_template(
)
_register_template(
register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
......@@ -566,7 +740,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -577,7 +751,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -602,7 +776,7 @@ _register_template(
)
_register_template(
register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
......@@ -610,14 +784,14 @@ _register_template(
)
_register_template(
register_template(
name="deepseek3",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
......@@ -631,7 +805,7 @@ _register_template(
)
_register_template(
register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
......@@ -639,13 +813,13 @@ _register_template(
)
_register_template(
register_template(
name="empty",
efficient_eos=True,
format_assistant=StringFormatter(slots=["{{content}}"]),
)
_register_template(
register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
......@@ -653,7 +827,7 @@ _register_template(
)
_register_template(
register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
......@@ -661,14 +835,14 @@ _register_template(
)
_register_template(
register_template(
name="fewshot",
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
efficient_eos=True,
)
_register_template(
register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
......@@ -679,7 +853,7 @@ _register_template(
)
_register_template(
register_template(
name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
......@@ -693,7 +867,7 @@ _register_template(
)
_register_template(
register_template(
name="granite3",
format_user=StringFormatter(
slots=[
......@@ -705,7 +879,7 @@ _register_template(
)
_register_template(
register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
format_system=StringFormatter(slots=["<unk>{{content}}"]),
......@@ -713,54 +887,59 @@ _register_template(
)
_register_template(
register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<eoa>"],
)
_register_template(
register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<|im_end|>"],
)
# copied from intern2 template
_register_template(
name="intern3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
_register_template(
register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
template_class=Llama2Template,
)
# copied from llama2 template
_register_template(
register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
template_class=Llama2Template,
)
_register_template(
register_template(
name="llama3",
format_user=StringFormatter(
slots=[
......@@ -788,7 +967,7 @@ _register_template(
# copied from llama3 template
_register_template(
register_template(
name="mllama",
format_user=StringFormatter(
slots=[
......@@ -816,8 +995,20 @@ _register_template(
)
register_template(
name="moonlight",
format_user=StringFormatter(
slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
default_system="You are a helpful assistant provided by Moonshot-AI.",
stop_words=["<|im_end|>"],
)
# copied from vicuna template
_register_template(
register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
......@@ -829,7 +1020,7 @@ _register_template(
# copied from vicuna template
_register_template(
register_template(
name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
......@@ -841,7 +1032,7 @@ _register_template(
# copied from llama3 template
_register_template(
register_template(
name="llava_next_llama3",
format_user=StringFormatter(
slots=[
......@@ -870,21 +1061,22 @@ _register_template(
# copied from mistral template
_register_template(
register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
template_class=Llama2Template,
)
# copied from chatml template
_register_template(
# copied from qwen template
register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -901,7 +1093,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -912,7 +1104,7 @@ _register_template(
# copied from vicuna template
_register_template(
register_template(
name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
......@@ -924,21 +1116,22 @@ _register_template(
# copied from mistral template
_register_template(
register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
template_class=Llama2Template,
)
# copied from chatml template
_register_template(
register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -949,7 +1142,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -965,43 +1158,83 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="minicpm_v",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are a helpful assistant.",
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
)
_register_template(
# copied from minicpm_v template
register_template(
name="minicpm_o",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
)
# mistral tokenizer v3 tekken
register_template(
name="ministral",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
)
# mistral tokenizer v3
register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=Llama2Template,
)
# mistral tokenizer v7 tekken (copied from ministral)
register_template(
name="mistral_small",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["[SYSTEM_PROMPT]{{content}}[/SYSTEM_PROMPT]"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
register_template(
name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)
_register_template(
register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
register_template(
name="openchat-3.6",
format_user=StringFormatter(
slots=[
......@@ -1017,7 +1250,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -1028,16 +1261,24 @@ _register_template(
)
_register_template(
register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
# copied from gemma template
_register_template(
register_template(
name="paligemma",
format_user=StringFormatter(slots=["{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
# copied from gemma template
register_template(
name="paligemma_chat",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_observation=StringFormatter(
......@@ -1048,7 +1289,7 @@ _register_template(
)
_register_template(
register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
......@@ -1057,7 +1298,7 @@ _register_template(
)
_register_template(
register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
......@@ -1067,7 +1308,7 @@ _register_template(
)
_register_template(
register_template(
name="phi4",
format_user=StringFormatter(
slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
......@@ -1078,17 +1319,22 @@ _register_template(
)
_register_template(
# copied from ministral template
register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
template_class=Llama2Template,
)
# copied from chatml template
_register_template(
register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -1104,7 +1350,19 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="qwen2_audio",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
)
# copied from qwen template
register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -1120,7 +1378,7 @@ _register_template(
)
_register_template(
register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -1134,7 +1392,7 @@ _register_template(
# copied from llama3 template
_register_template(
register_template(
name="skywork_o1",
format_user=StringFormatter(
slots=[
......@@ -1168,7 +1426,7 @@ _register_template(
)
_register_template(
register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
......@@ -1176,7 +1434,7 @@ _register_template(
)
_register_template(
register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
......@@ -1185,14 +1443,14 @@ _register_template(
)
_register_template(
register_template(
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
)
_register_template(
register_template(
name="telechat2",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}"]),
......@@ -1202,7 +1460,7 @@ _register_template(
)
_register_template(
register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
......@@ -1213,7 +1471,7 @@ _register_template(
)
_register_template(
register_template(
name="video_llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
......@@ -1224,7 +1482,7 @@ _register_template(
)
_register_template(
register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=(
......@@ -1235,13 +1493,13 @@ _register_template(
)
_register_template(
register_template(
name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
)
_register_template(
register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
......@@ -1262,7 +1520,7 @@ _register_template(
# copied from chatml template
_register_template(
register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
......@@ -1271,7 +1529,7 @@ _register_template(
)
_register_template(
register_template(
name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
......@@ -1288,7 +1546,7 @@ _register_template(
)
_register_template(
register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
......@@ -1296,7 +1554,7 @@ _register_template(
)
_register_template(
register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
......@@ -1304,7 +1562,7 @@ _register_template(
)
_register_template(
register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
......