Commit 0722acf1 authored by chenych

Update 0604

parent c4ba4563
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
 
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
......
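A quick illustration of the new LoRA path (values hypothetical, not from this commit): with adapter_name_or_path=["/ckpts/my_lora"] and sglang_lora_backend="triton", the launch command gains

    --max-loras-per-batch 1 --lora-backend triton --lora-paths lora0=/ckpts/my_lora --disable-radix-cache

and each streaming request then selects that adapter by sending json_data["lora_request"] = ["lora0"].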
@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
......
@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }
 
-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
......
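Why `> 1` is the correct guard (a sketch, not from this commit): sys.argv always holds at least the program name.

    # sys.argv == ["llamafactory-cli"] when no subcommand is given
    # old: len(sys.argv) >= 1 is True, so sys.argv.pop(1) raises IndexError
    # new: len(sys.argv) > 1 is False, so command falls back to "help"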
@@ -51,12 +51,27 @@ class DatasetConverter:
         else:
             medias = medias[:]
 
-        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
-            for i in range(len(medias)):
-                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
-                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
-                else:
-                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
+        if self.dataset_attr.load_from in ["script", "file"]:
+            if isinstance(medias[0], str):
+                for i in range(len(medias)):
+                    media_path = os.path.join(self.data_args.media_dir, medias[i])
+                    if os.path.isfile(media_path):
+                        medias[i] = media_path
+                    else:
+                        logger.warning_rank0_once(
+                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+                        )
+            elif isinstance(medias[0], list):  # for processed video frames
+                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+                for i in range(len(medias)):
+                    for j in range(len(medias[i])):
+                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+                        if os.path.isfile(media_path):
+                            medias[i][j] = media_path
+                        else:
+                            logger.warning_rank0_once(
+                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+                            )
 
         return medias
......
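A sketch of the input the new elif branch handles (paths hypothetical): pre-extracted video frames arrive as a list of lists, and each frame is resolved against media_dir individually.

    medias = [["frames/v1/0.jpg", "frames/v1/1.jpg"]]
    # with media_dir="/data", a frame that exists on disk becomes "/data/frames/v1/0.jpg"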
@@ -14,7 +14,7 @@
 import json
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
 
 import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@@ -142,48 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
     return dataset_module
 
 
-def setup_fs(path, anon=False):
-    """Set up a filesystem object based on the path protocol."""
+def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
+    r"""Set up a filesystem object based on the path protocol."""
     storage_options = {"anon": anon} if anon else {}
     if path.startswith("s3://"):
         fs = fsspec.filesystem("s3", **storage_options)
     elif path.startswith(("gs://", "gcs://")):
         fs = fsspec.filesystem("gcs", **storage_options)
     else:
-        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")
+
+    if not fs.exists(path):
+        raise ValueError(f"Path does not exist: {path}.")
 
     return fs
 
 
-def read_cloud_json(cloud_path):
-    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
+    r"""Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if path.endswith(".jsonl"):
+            return [json.loads(line) for line in f if line.strip()]
+        else:
+            return json.load(f)
+
+
+def read_cloud_json(cloud_path: str) -> list[Any]:
+    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).
 
     Args:
-        cloud_path : str
+        cloud_path: str
             Cloud path in the format:
             - 's3://bucket-name/file.json' for AWS S3
             - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
-        lines : bool, default=True
-            If True, read the file as JSON Lines format (one JSON object per line)
     """
     try:
-        # Try with anonymous access first
-        fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+        fs = setup_fs(cloud_path, anon=True)  # try with anonymous access first
     except Exception:
-        # Try again with credentials
-        fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
-
-
-def _read_json_with_fs(fs, path, lines=True):
-    """Helper function to read JSON/JSONL files using fsspec."""
-    with fs.open(path, "r") as f:
-        if lines:
-            # Read JSONL (JSON Lines) format - one JSON object per line
-            data = [json.loads(line) for line in f if line.strip()]
-        else:
-            # Read regular JSON format
-            data = json.load(f)
-
-    return data
+        fs = setup_fs(cloud_path)  # try again with credentials
+
+    # filter out non-JSON files
+    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
+    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
+
+    return sum([_read_json_with_fs(fs, file) for file in files], [])
@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
 
-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
 
 
 def _get_dataset_processor(
@@ -300,13 +300,18 @@ def get_dataset(
         raise ValueError("Turn off `streaming` when saving dataset to disk.")
 
     # Load and preprocess dataset
-    with training_args.main_process_first(desc="load dataset"):
+    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )
 
-    with training_args.main_process_first(desc="pre-process dataset"):
+    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
......
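Usage sketch for the reworked reader (bucket name hypothetical): read_cloud_json now also accepts a directory and concatenates the records of every .json/.jsonl file inside it.

    records = read_cloud_json("s3://my-bucket/data/train.jsonl")  # single file, as before
    records = read_cloud_json("s3://my-bucket/data/")  # directory: merges all JSON/JSONL files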
@@ -17,6 +17,7 @@
 import inspect
 import math
+import os
 import re
 from copy import deepcopy
 from dataclasses import dataclass
@@ -25,7 +26,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import get_image_size, to_numpy_array
+from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
 from typing_extensions import override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -57,7 +58,10 @@ if is_transformers_version_greater_than("4.45.0"):
     )
 
-if is_transformers_version_greater_than("4.49.0"):
+if is_transformers_version_greater_than("4.52.0"):
+    from transformers.image_utils import make_flat_list_of_images
+    from transformers.video_utils import make_batched_videos
+elif is_transformers_version_greater_than("4.49.0"):
     from transformers.image_utils import make_batched_videos, make_flat_list_of_images
@@ -73,7 +77,7 @@ if TYPE_CHECKING:
         bytes: Optional[bytes]
 
     ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
-    VideoInput = Union[str, BinaryIO]
+    VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]
 
 class MMProcessor(ProcessorMixin):
@@ -131,6 +135,11 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
     return batch_images
 
 
+def _check_video_is_nested_images(video: "VideoInput") -> bool:
+    r"""Check if the video is nested images."""
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+
+
 @dataclass
 class MMPluginMixin:
     image_token: Optional[str]
@@ -167,16 +176,45 @@ class MMPluginMixin:
         )
 
         if self.image_token is not None and processor is None:
-            raise ValueError("Processor was not found, please check and update your processor config.")
+            raise ValueError("Processor was not found, please check and update your model file.")
 
         if self.image_token is not None and image_processor is None:
-            raise ValueError("Image processor was not found, please check and update your processor config.")
+            raise ValueError("Image processor was not found, please check and update your model file.")
 
         if self.video_token is not None and video_processor is None:
-            raise ValueError("Video processor was not found, please check and update your processor config.")
+            raise ValueError("Video processor was not found, please check and update your model file.")
 
         if self.audio_token is not None and feature_extractor is None:
-            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+            raise ValueError("Audio feature extractor was not found, please check and update your model file.")
+
+    def _validate_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ):
+        r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+        for message in messages:
+            num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
+            num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
+            num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(
+                f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
+            )
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(
+                f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
+            )
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(
+                f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
+            )
 
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
@@ -234,14 +272,20 @@
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+
+                frames = video
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())
 
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
@@ -420,6 +464,7 @@ class Gemma3Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         boi_token: str = getattr(processor, "boi_token")
@@ -446,9 +491,6 @@ class Gemma3Plugin(BasePlugin):
             message["content"] = content.replace("{{image}}", image_str)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -495,14 +537,14 @@ class InternVLPlugin(BasePlugin):
         mm_inputs = {}
         image_video_patches = []
 
-        if len(images) != 0 and isinstance(images[0], str):
+        if len(images) != 0:
             images = self._regularize_images(
                 images,
                 image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
             )["images"]
 
-        if len(videos) != 0 and isinstance(videos[0], str):
+        if len(videos) != 0:
             videos = self._regularize_videos(
                 videos,
                 image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
@@ -566,8 +608,8 @@ class InternVLPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
         image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -579,9 +621,6 @@ class InternVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
                     f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
@@ -590,9 +629,6 @@ class InternVLPlugin(BasePlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                 end_patch_index = video_patch_indices[num_video_tokens]
                 num_patches = list(video_num_patches[current_patch_index:end_patch_index])
@@ -605,12 +641,6 @@ class InternVLPlugin(BasePlugin):
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -637,10 +667,13 @@ class KimiVLPlugin(BasePlugin):
     @override
     def process_messages(self, messages, images, videos, audios, processor):
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_hws = mm_inputs.get("image_grid_hws", [])
+        else:
+            image_grid_hws = [None] * len(images)
 
-        image_grid_hws = mm_inputs.get("image_grid_hws", [])
         num_image_tokens = 0
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         merge_length = math.prod(image_processor.merge_kernel_size)
@@ -648,9 +681,6 @@ class KimiVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
@@ -661,9 +691,6 @@ class KimiVLPlugin(BasePlugin):
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -679,6 +706,7 @@ class Llama4Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:
@@ -701,9 +729,6 @@ class Llama4Plugin(BasePlugin):
                 for local_image_index, split_part in enumerate(prompt_splits):
                     new_content.append(split_part)
                     if local_image_index < placeholder_count:
-                        if num_image_tokens >= len(images):
-                            raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                         tokens_for_this_image = processor._prompt_split_image(
                             aspect_ratios[num_image_tokens], num_patches_per_chunk
                         )
@@ -716,9 +741,6 @@ class Llama4Plugin(BasePlugin):
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -751,7 +773,7 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -768,17 +790,10 @@ class LlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1
 
             message["content"] = content.replace("{{image}}", self.image_token)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -794,6 +809,7 @@ class LlavaNextPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
@@ -805,9 +821,6 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -821,9 +834,6 @@ class LlavaNextPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", self.image_token)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
@@ -839,7 +849,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens, num_video_tokens = 0, 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -850,9 +860,6 @@ class LlavaNextVideoPlugin(BasePlugin):
            for message in messages:
                 content = message["content"]
                 while IMAGE_PLACEHOLDER in content:
-                    if num_image_tokens >= len(images):
-                        raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                     if self.expand_mm_tokens:
                         orig_height, orig_width = next(image_sizes)
                         image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -862,7 +869,6 @@ class LlavaNextVideoPlugin(BasePlugin):
                         image_seqlen = 1
 
                     content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                    num_image_tokens += 1
 
                 message["content"] = content.replace("{{image}}", self.image_token)
@@ -879,20 +885,10 @@ class LlavaNextVideoPlugin(BasePlugin):
            for message in messages:
                 content = message["content"]
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                     content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
-                    num_video_tokens += 1
 
                 message["content"] = content.replace("{{video}}", self.video_token)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
@@ -978,6 +974,7 @@ class MiniCPMVPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -996,24 +993,15 @@ class MiniCPMVPlugin(BasePlugin):
         for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1
@@ -1065,15 +1053,6 @@ class MiniCPMVPlugin(BasePlugin):
             final_text += text_chunks[-1]
             messages[index]["content"] = final_text
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -1157,6 +1136,7 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -1164,9 +1144,6 @@ class MllamaPlugin(BasePlugin):
             num_image_tokens += content.count(IMAGE_PLACEHOLDER)
             message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -1214,6 +1191,7 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -1224,9 +1202,6 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -1281,7 +1256,7 @@ class PixtralPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1291,15 +1266,13 @@ class PixtralPlugin(BasePlugin):
                 image_sizes = iter(mm_inputs["image_sizes"][0])
             else:
                 image_sizes = iter(mm_inputs["image_sizes"].tolist())
 
             image_break_token: str = getattr(processor, "image_break_token")
             image_end_token: str = getattr(processor, "image_end_token")
 
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
                     num_height_tokens = height // processor.patch_size
@@ -1312,13 +1285,9 @@ class PixtralPlugin(BasePlugin):
                     replace_str = self.image_token
 
                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
-                num_image_tokens += 1
 
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -1355,9 +1324,9 @@ class Qwen2AudioPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
-        num_audio_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1367,9 +1336,6 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 if self.expand_mm_tokens:
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1
@@ -1380,13 +1346,9 @@ class Qwen2AudioPlugin(BasePlugin):
                 content = content.replace(
                     AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                 )
-                num_audio_tokens += 1
 
             message["content"] = content
 
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
         return messages
 
     @override
@@ -1430,24 +1392,33 @@ class Qwen2VLPlugin(BasePlugin):
     ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
         results, fps_per_video = [], []
         for video in videos:
-            container = av.open(video, "r")
-            video_stream = next(stream for stream in container.streams if stream.type == "video")
-            sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
             frames: list[ImageObject] = []
-            container.seek(0)
-            for frame_idx, frame in enumerate(container.decode(video_stream)):
-                if frame_idx in sample_indices:
-                    frames.append(frame.to_image())
+            if _check_video_is_nested_images(video):
+                for frame in video:
+                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
+                        raise ValueError("Invalid image found in video frames.")
+
+                frames = video
+                fps_per_video.append(kwargs.get("video_fps", 2.0))
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())
+
+                if video_stream.duration is None:
+                    fps_per_video.append(kwargs.get("video_fps", 2.0))
+                else:
+                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
 
-            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
+            if len(frames) % 2 != 0:
                 frames.append(frames[-1])
 
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
-            if video_stream.duration is None:
-                fps_per_video.append(2.0)
-            else:
-                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
 
         return {"videos": results, "fps_per_video": fps_per_video}
@@ -1494,6 +1465,7 @@ class Qwen2VLPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -1510,9 +1482,6 @@ class Qwen2VLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
@@ -1520,9 +1489,6 @@ class Qwen2VLPlugin(BasePlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
@@ -1531,12 +1497,6 @@ class Qwen2VLPlugin(BasePlugin):
             message["content"] = content
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
@@ -1602,6 +1562,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -1624,9 +1585,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
@@ -1642,11 +1600,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 )
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                 video_pos = content.find(VIDEO_PLACEHOLDER)
                 audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
                 if audio_pos == -1 or audio_pos < video_pos:
@@ -1688,9 +1641,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     num_video_tokens += 1
             else:
                 while AUDIO_PLACEHOLDER in content:
-                    if num_audio_tokens >= len(audios):
-                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
-
                     audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
                     content = content.replace(
                         AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
@@ -1698,9 +1648,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     num_audio_tokens += 1
 
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                     video_seqlen = (
                         video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                     )
@@ -1711,15 +1658,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             message["content"] = content
 
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
@@ -1735,6 +1673,7 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         num_frames = 0
@@ -1762,28 +1701,16 @@ class VideoLlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
 
             content = content.replace("{{image}}", self.image_token)
             message["content"] = content.replace("{{video}}", self.video_token)
 
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
         return messages
......
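A sketch of the nested-frames video input enabled above (paths hypothetical): a video may now be a list of frame images, which bypasses av decoding entirely.

    videos = [["f0.jpg", "f1.jpg", "f2.jpg", "f3.jpg"]]  # one video supplied as four frames
    # _check_video_is_nested_images(videos[0]) is True, so the frames are used as-is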
@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
......
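For readability: the collapsed conditional above is equivalent to

    load_from = "ms_hub" if use_modelscope() else ("om_hub" if use_openmind() else "hf_hub")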
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import re import re
from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
...@@ -51,6 +52,7 @@ class Template: ...@@ -51,6 +52,7 @@ class Template:
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
replace_jinja_template: bool replace_jinja_template: bool
enable_thinking: Optional[bool]
mm_plugin: "BasePlugin" mm_plugin: "BasePlugin"
def encode_oneturn( def encode_oneturn(
...@@ -61,7 +63,7 @@ class Template: ...@@ -61,7 +63,7 @@ class Template:
tools: Optional[str] = None, tools: Optional[str] = None,
) -> tuple[list[int], list[int]]: ) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively.""" r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True) encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = [] prompt_ids = []
for encoded_ids in encoded_messages[:-1]: for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids prompt_ids += encoded_ids
...@@ -77,7 +79,7 @@ class Template: ...@@ -77,7 +79,7 @@ class Template:
tools: Optional[str] = None, tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]: ) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively.""" r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False) encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]: def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
...@@ -92,6 +94,19 @@ class Template: ...@@ -92,6 +94,19 @@ class Template:
return list(stop_token_ids) return list(stop_token_ids)
def add_thought(self, content: str = "") -> str:
r"""Add empty thought to assistant message."""
return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
def remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
return re.sub(pattern, "", content).lstrip("\n")
def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Get the token ids of thought words."""
return tokenizer.encode(self.add_thought(), add_special_tokens=False)
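For reference, a minimal standalone sketch of how these thought helpers behave, assuming the common ("<think>", "</think>") pair for thought_words (other templates use different tags):

import re

thought_words = ("<think>", "</think>")  # assumed delimiters for this sketch

def add_thought(content: str = "") -> str:
    # Prepend an empty thought block, mirroring Template.add_thought above.
    return f"{thought_words[0]}\n\n{thought_words[1]}\n\n" + content

def remove_thought(content: str) -> str:
    # Drop everything between the thought tags, mirroring Template.remove_thought.
    pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")

print(add_thought("4"))                               # '<think>\n\n</think>\n\n4'
print(remove_thought("<think>2 + 2 = 4</think>\n4"))  # '4'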
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
...@@ -111,18 +126,12 @@ class Template:
return token_ids
def _remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
return re.sub(pattern, "", content).lstrip("\n")
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
remove_thought: bool,
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
...@@ -140,18 +149,14 @@ class Template:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER: if message["role"] == Role.USER:
elements += self.format_user.apply(content=content, idx=str(i // 2)) elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT: elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content) elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION: elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content) elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION: elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content) elements += self.format_function.apply(content=message["content"])
else: else:
raise NotImplementedError("Unexpected role: {}".format(message["role"])) raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...@@ -162,6 +167,9 @@ class Template:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""Add or replace eos token to the tokenizer."""
if tokenizer.eos_token == eos_token:
return
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
...@@ -328,7 +336,6 @@ class Llama2Template(Template):
messages: list[dict[str, str]],
system: str,
tools: str,
remove_thought: bool,
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
...@@ -342,18 +349,14 @@ class Llama2Template(Template):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER: if message["role"] == Role.USER:
elements += self.format_user.apply(content=system_text + content) elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT: elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content) elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION: elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content) elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION: elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content) elements += self.format_function.apply(content=message["content"])
else: else:
raise NotImplementedError("Unexpected role: {}".format(message["role"])) raise NotImplementedError("Unexpected role: {}".format(message["role"]))
...@@ -392,6 +395,64 @@ class Llama2Template(Template): ...@@ -392,6 +395,64 @@ class Llama2Template(Template):
return jinja_template return jinja_template
@dataclass
class ReasoningTemplate(Template):
r"""A template that add thought to assistant message."""
@override
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> tuple[list[int], list[int]]:
messages = deepcopy(messages)
for i in range(1, len(messages) - 2, 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
if self.enable_thinking is False: # remove all cot
messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
if (
self.thought_words[0] not in messages[-1]["content"]
and self.thought_words[1] not in messages[-1]["content"]
): # add empty cot
if not self.enable_thinking: # do not compute loss
prompt_ids += self.get_thought_word_ids(tokenizer)
else: # do compute loss
response_ids = self.get_thought_word_ids(tokenizer) + response_ids
return prompt_ids, response_ids
@override
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
messages = deepcopy(messages)
if self.enable_thinking is False: # remove all cot
for i in range(1, len(messages), 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2):
if (
self.thought_words[0] not in messages[i + 1]["content"]
and self.thought_words[1] not in messages[i + 1]["content"]
): # add empty cot
if not self.enable_thinking: # do not compute loss
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
else: # do compute loss
encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
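In plain terms (a behavior sketch, not part of the diff): encode_oneturn always strips thoughts from intermediate assistant turns, and the final turn keeps its chain of thought unless enable_thinking is False. When the final assistant message carries no thought block at all, an empty one is injected, and its placement decides whether loss is computed on it:

# Illustrative routing of an empty CoT block, assuming ("<think>", "</think>").
empty_cot = "<think>\n\n</think>\n\n"
target = "The answer is 4."   # assistant message without a thought block
enable_thinking = False       # hypothetical setting for this sketch

if not enable_thinking:
    # appended to the prompt ids: the empty CoT is never trained on
    prompt_tail, response = empty_cot, target
else:
    # prepended to the response ids: the empty CoT is trained on
    prompt_tail, response = "", empty_cot + target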
TEMPLATES: dict[str, "Template"] = {} TEMPLATES: dict[str, "Template"] = {}
...@@ -410,6 +471,7 @@ def register_template( ...@@ -410,6 +471,7 @@ def register_template(
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
replace_jinja_template: bool = False, replace_jinja_template: bool = False,
enable_thinking: Optional[bool] = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: type["Template"] = Template, template_class: type["Template"] = Template,
) -> None: ) -> None:
...@@ -456,6 +518,7 @@ def register_template( ...@@ -456,6 +518,7 @@ def register_template(
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template, replace_jinja_template=replace_jinja_template,
enable_thinking=enable_thinking,
mm_plugin=mm_plugin,
)
...@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
if len(user_slot) > len(user_slot_empty_system):
...@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
default_system = ""
return Template(
return template_class(
format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]),
...@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
enable_thinking=True,
mm_plugin=get_mm_plugin(name="base"), mm_plugin=get_mm_plugin(name="base"),
) )
...@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ...@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format) template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
if data_args.default_system is not None:
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
template.default_system = data_args.default_system
template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
return template
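Note that the system-prompt override now happens at template-setup time; the default_system field of GeneratingArguments is removed later in this commit, so data_args is its only home. A minimal sketch of the effect, with hypothetical values:

# Minimal sketch; `template` is whatever get_template_and_fix_tokenizer resolved.
data_args.default_system = "You are a concise assistant."  # hypothetical override
data_args.enable_thinking = False

if data_args.default_system is not None:
    template.default_system = data_args.default_system  # user override wins
template.enable_thinking = data_args.enable_thinking    # consumed by ReasoningTemplate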
...@@ -756,6 +826,7 @@ register_template(
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
...@@ -774,6 +845,15 @@ register_template(
)
# copied from deepseek3 template
register_template(
name="deepseekr1",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=ReasoningTemplate,
)
register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
...@@ -838,6 +918,7 @@ register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
template_class=Llama2Template,
)
...@@ -853,6 +934,7 @@ register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"), mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
template_class=Llama2Template, template_class=Llama2Template,
) )
...@@ -872,6 +954,22 @@ register_template( ...@@ -872,6 +954,22 @@ register_template(
) )
# copied from glm4 template
register_template(
name="glmz1",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
template_class=ReasoningTemplate,
)
register_template(
name="granite3",
format_user=StringFormatter(
...@@ -973,6 +1071,7 @@ register_template(
stop_words=["<|im_end|>"],
thought_words=("◁think▷", "◁/think▷"),
mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
template_class=ReasoningTemplate,
)
...@@ -1018,6 +1117,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
)
...@@ -1037,6 +1137,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot|>", "<|eom|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"), mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
) )
...@@ -1066,6 +1167,7 @@ register_template( ...@@ -1066,6 +1167,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"), format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"], stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"), mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
) )
...@@ -1079,6 +1181,7 @@ register_template( ...@@ -1079,6 +1181,7 @@ register_template(
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]), format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
default_system="You are a helpful assistant provided by Moonshot-AI.", default_system="You are a helpful assistant provided by Moonshot-AI.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
)
...@@ -1131,6 +1234,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"), mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
) )
...@@ -1163,6 +1267,7 @@ register_template( ...@@ -1163,6 +1267,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"), mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
) )
...@@ -1233,6 +1338,42 @@ register_template( ...@@ -1233,6 +1338,42 @@ register_template(
) )
# copied from qwen template
register_template(
name="mimo",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
# copied from qwen2vl
register_template(
name="mimo_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are MiMo, an AI assistant developed by Xiaomi.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
template_class=ReasoningTemplate,
)
# copied from chatml template
register_template(
name="minicpm_v",
...@@ -1363,6 +1504,7 @@ register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"), mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template, template_class=Llama2Template,
) )
...@@ -1374,6 +1516,7 @@ register_template( ...@@ -1374,6 +1516,7 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
stop_words=["<|end|>"], stop_words=["<|end|>"],
replace_eos=True,
)
...@@ -1384,6 +1527,7 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
...@@ -1395,6 +1539,7 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
...@@ -1425,6 +1570,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
...@@ -1440,6 +1586,8 @@ register_template(
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
...@@ -1451,6 +1599,7 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"), mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
) )
...@@ -1468,6 +1617,7 @@ register_template( ...@@ -1468,6 +1617,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
),
...@@ -1486,6 +1636,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
) )
...@@ -1503,6 +1654,20 @@ register_template( ...@@ -1503,6 +1654,20 @@ register_template(
) )
register_template(
name="seed_coder",
format_user=StringFormatter(
slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
),
format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
"and you only answer questions related to computer science. For politically sensitive questions, "
"security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
),
)
# copied from llama3 template
register_template(
name="skywork_o1",
...@@ -1538,6 +1703,25 @@ register_template(
)
register_template(
name="smollm",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
)
register_template(
name="smollm2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
default_system="You are a helpful AI assistant named SmolLM, trained by Hugging Face.",
)
register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
...
...@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
tool_text = ""
tool_names = []
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
param_text = "" param_text = ""
for name, param in tool["parameters"]["properties"].items(): for name, param in tool["parameters"]["properties"].items():
required, enum, items = "", "", "" required, enum, items = "", "", ""
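The unwrapping line added above lets the formatters accept both bare tool schemas and OpenAI-style wrapped ones; a minimal sketch of the idiom:

# Both shapes normalize to the bare schema before formatting.
bare = {"name": "get_weather", "parameters": {"properties": {}}}  # hypothetical tool
wrapped = {"type": "function", "function": bare}

for tool in (bare, wrapped):
    tool = tool.get("function", "") if tool.get("type") == "function" else tool
    assert tool["name"] == "get_weather"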
...@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = "" return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return function_text
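The rewritten formatter emits the same Action/Action Input layout but joins parallel calls with single newlines instead of accumulating a trailing one; for example (FunctionCall unpacks as a (name, arguments) pair, shown here with plain tuples):

functions = [("search", '{"query": "llama"}'), ("open_page", '{"id": 1}')]  # hypothetical calls
print("\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions]))
# Action: search
# Action Input: {"query": "llama"}
# Action: open_page
# Action Input: {"id": 1}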
@override
@staticmethod
...@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
) )
...@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils): ...@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
date = datetime.now().strftime("%d %b %Y") date = datetime.now().strftime("%d %b %Y")
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
wrapped_tool = {"type": "function", "function": tool} wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
...@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
tools = json.loads(content.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "parameters" not in tool:
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template."""
...@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool}) wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]" return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
@override @override
@staticmethod @staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return "[" + ", ".join(function_texts) + "]"
return json.dumps(
[{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
)
@override
@staticmethod
...@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
except json.JSONDecodeError:
return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
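Same pattern for Mistral: the explicit key checks are replaced by a try/except KeyError around a comprehension, so one malformed entry returns the original content unchanged. A sketch with a hypothetical malformed payload:

import json

content = '[{"name": "a"}]'  # "arguments" missing on purpose
tools = json.loads(content.strip())
tools = [tools] if not isinstance(tools, list) else tools
try:
    result = [(t["name"], json.dumps(t["arguments"], ensure_ascii=False)) for t in tools]
except KeyError:
    result = content  # fall back to the raw string
print(result)  # '[{"name": "a"}]'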
class QwenToolUtils(ToolUtils):
...@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool} wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False) tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text) return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
...@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils): ...@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return "\n".join(function_texts)
function_texts = [
json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
for name, arguments in functions
]
return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
@override
@staticmethod
...
...@@ -513,7 +513,7 @@ register_model_group(
register_model_group(
models={
"DeepSeek-V2-236B-Chat-0628": {
"DeepSeek-V2-236B-0628-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
},
...@@ -521,7 +521,7 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
},
"DeepSeek-V2.5-236B-Chat-1210": {
"DeepSeek-V2.5-236B-1210-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
},
...@@ -533,6 +533,17 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
},
"DeepSeek-V3-671B-0324-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
},
},
template="deepseek3",
)
register_model_group(
models={
"DeepSeek-R1-1.5B-Distill": { "DeepSeek-R1-1.5B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
...@@ -545,6 +556,10 @@ register_model_group( ...@@ -545,6 +556,10 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}, },
"DeepSeek-R1-8B-0528-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
},
"DeepSeek-R1-14B-Distill": { "DeepSeek-R1-14B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
...@@ -565,8 +580,12 @@ register_model_group( ...@@ -565,8 +580,12 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
}, },
"DeepSeek-R1-671B-0528-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528",
},
},
template="deepseek3",
template="deepseekr1",
)
...@@ -673,6 +692,10 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-3-1b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
},
"MedGemma-27B-Instruct": {
DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
},
},
template="gemma",
)
...@@ -704,6 +727,14 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-3-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
},
"MedGemma-4B": {
DownloadSource.DEFAULT: "google/medgemma-4b-pt",
DownloadSource.MODELSCOPE: "google/medgemma-4b-pt",
},
"MedGemma-4B-Instruct": {
DownloadSource.DEFAULT: "google/medgemma-4b-it",
DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
},
},
template="gemma3",
multimodal=True,
...@@ -737,6 +768,13 @@ register_model_group(
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
},
},
template="glm4",
)
register_model_group(
models={
"GLM-Z1-9B-0414-Chat": { "GLM-Z1-9B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414", DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414", DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
...@@ -746,7 +784,7 @@ register_model_group( ...@@ -746,7 +784,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414", DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
}, },
}, },
template="glm4", template="glmz1",
) )
...@@ -869,12 +907,13 @@ register_model_group( ...@@ -869,12 +907,13 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Granite-3.2-1B-A400M-Base": { "Granite-Vision-3.2-2B": {
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b", DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b", DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
}, },
}, },
template="granite3_vision", template="granite3_vision",
multimodal=True,
)
...@@ -1398,6 +1437,45 @@ register_model_group(
)
register_model_group(
models={
"MiMo-7B-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
},
"MiMo-7B-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
},
"MiMo-7B-Instruct-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
},
"MiMo-7B-RL-ZERO": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
},
},
template="mimo",
)
register_model_group(
models={
"MiMo-7B-VL-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
},
"MiMo-7B-VL-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
},
},
template="mimo_vl",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
...@@ -2461,6 +2539,38 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
},
"Qwen3-0.6B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
},
"Qwen3-1.7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
},
"Qwen3-4B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
},
"Qwen3-8B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
},
"Qwen3-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
},
"Qwen3-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
},
"Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
},
"Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
},
},
template="qwen3",
)
...@@ -2484,10 +2594,22 @@ register_model_group(
register_model_group(
models={
"Qwen2.5-Omni-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
},
"Qwen2.5-Omni-7B": { "Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
} },
"Qwen2.5-Omni-7B-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
},
"Qwen2.5-Omni-7B-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
},
},
template="qwen2_omni",
multimodal=True,
...@@ -2598,15 +2720,17 @@ register_model_group(
register_model_group(
models={
"SOLAR-10.7B-v1.0": { "Seed-Coder-8B-Base": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
}, },
"SOLAR-10.7B-Instruct-v1.0": { "Seed-Coder-8B-Instruct": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", },
"Seed-Coder-8B-Instruct-Reasoning": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
}, },
}, },
template="solar", template="seed_coder",
) )
...@@ -2631,6 +2755,82 @@ register_model_group(
)
register_model_group(
models={
"SmolLM-135M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M",
},
"SmolLM-360M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M",
},
"SmolLM-1.7B": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B",
},
"SmolLM-135M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M-Instruct",
},
"SmolLM-360M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M-Instruct",
},
"SmolLM-1.7B-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B-Instruct",
},
},
template="smollm",
)
register_model_group(
models={
"SmolLM2-135M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
},
"SmolLM2-360M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M",
},
"SmolLM2-1.7B": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B",
},
"SmolLM2-135M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M-Instruct",
},
"SmolLM2-360M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M-Instruct",
},
"SmolLM2-1.7B-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
},
},
template="smollm2",
)
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group(
models={
"StarCoder2-3B": {
...
...@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import accelerate
...@@ -83,4 +84,9 @@ def print_env() -> None:
except Exception:
pass
if os.path.exists("data"):
info["Default data directory"] = "detected"
else:
info["Default data directory"] = "not detected"
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n") print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
...@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None: ...@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return return
if "gptmodel" in requirement or "autoawq" in requirement:
pip_command = f"pip install {requirement} --no-build-isolation"
else:
pip_command = f"pip install {requirement}"
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
hint = f"To fix: run `{pip_command}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)
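Behavior sketch of the new hint: requirements that must be built without build isolation get a tailored pip command (the requirement string below is hypothetical):

requirement = "autoawq>=0.2.0"
if "gptqmodel" in requirement or "autoawq" in requirement:
    pip_command = f"pip install {requirement} --no-build-isolation"
else:
    pip_command = f"pip install {requirement}"
print(f"To fix: run `{pip_command}`.")
# To fix: run `pip install autoawq>=0.2.0 --no-build-isolation`.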
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.5.0")
check_version("accelerate>=0.34.0,<=1.6.0")
check_version("peft>=0.14.0,<=0.15.1")
check_version(
"transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
)
check_version("datasets>=2.16.0,<=3.6.0")
check_version("accelerate>=0.34.0,<=1.7.0")
check_version("peft>=0.14.0,<=0.15.2")
check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
...
...@@ -99,6 +99,10 @@ class DataArguments:
default=0.0,
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
)
eval_on_each_dataset: bool = field(
default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."},
)
packing: Optional[bool] = field(
default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
...@@ -111,6 +115,14 @@ class DataArguments:
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Override the default system message in the template."},
)
enable_thinking: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
tokenized_path: Optional[str] = field(
default=None,
metadata={
...@@ -121,6 +133,10 @@ class DataArguments:
)
},
)
data_shared_file_system: bool = field(
default=False,
metadata={"help": "Whether or not to use a shared file system for the datasets."},
)
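A hypothetical illustration of the new data-argument knobs introduced in this file (values invented for the example; the help strings above describe the semantics):

# Hypothetical override set; field names come from the DataArguments diff above.
overrides = {
    "eval_on_each_dataset": True,                     # evaluate each dataset separately
    "default_system": "You are a terse assistant.",   # replaces the template default
    "enable_thinking": False,                         # drop <think> spans from reasoning targets
    "data_shared_file_system": True,                  # datasets live on a shared filesystem
}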
def __post_init__(self):
def split_arg(arg):
...
...@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import GenerationConfig
...@@ -62,10 +62,6 @@ class GeneratingArguments:
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
...
...@@ -235,10 +235,6 @@ class ProcessorArguments:
default=False,
metadata={"help": "Whether to crop the image to patches for internvl."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
...@@ -255,6 +251,10 @@ class ProcessorArguments:
default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
audio_sampling_rate: int = field(
default=16000,
metadata={"help": "The sampling rate of audio inputs."},
...@@ -364,6 +364,12 @@ class SGLangArguments:
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
sglang_lora_backend: Literal["triton", "flashinfer"] = field(
default="triton",
metadata={
"help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
},
)
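A hypothetical illustration of wiring the new LoRA backend option (values invented; "triton" is the declared default and the help text above recommends it):

# Hypothetical argument set for SGLang inference with a LoRA adapter attached.
sglang_args = {
    "infer_backend": "sglang",
    "adapter_name_or_path": ["saves/lora-adapter"],  # hypothetical adapter path
    "sglang_lora_backend": "triton",                 # alternative: "flashinfer"
}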
def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
...
...@@ -148,10 +148,10 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.8.4")
check_version("vllm>=0.4.3,<=0.8.6")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
check_version("sglang>=0.4.5")
check_version("sglang", mandatory=True)
if finetuning_args.use_galore:
...
...@@ -64,6 +64,7 @@ class RayArguments:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
)
import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3":
...
...@@ -29,10 +29,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
if model_args.flash_attn != AttentionFunction.FA2:
...
...@@ -45,16 +45,24 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama": elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "llava":
from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
elif model_type == "mistral": elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral": elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama": elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "olmo2":
from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "phi3": elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2": elif model_type == "qwen2":
...@@ -63,6 +71,8 @@ def apply_liger_kernel( ...@@ -63,6 +71,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl": elif model_type == "qwen2_5_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
elif model_type == "qwen3":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
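Each branch above binds one patch function; calling it monkey-patches the corresponding transformers modeling code in place. For instance, for a Qwen3 model (the import is the one added by this diff):

from liger_kernel.transformers import apply_liger_kernel_to_qwen3

# Patches the Qwen3 modeling classes globally; subsequent model loads pick up
# the fused Liger kernels. Keyword options vary by liger-kernel version.
apply_liger_kernel_to_qwen3()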
...