Commit 24534501 authored by mashun1

parallel_tool

parent c4ba4563
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
            raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
...
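For illustration, a minimal standalone sketch of what the new LoRA handling produces; the adapter path and backend value below are placeholders, not part of the commit:

# Sketch only: mirrors the new self.lora_request logic with assumed values.
adapter_name_or_path = ["/path/to/lora_adapter"]  # placeholder adapter path
sglang_lora_backend = "triton"                    # placeholder backend value
launch_cmd = ["python3 -m sglang.launch_server", "--log-level error"]
if adapter_name_or_path is not None:  # -> self.lora_request = True
    launch_cmd.extend(
        [
            "--max-loras-per-batch 1",
            f"--lora-backend {sglang_lora_backend}",
            f"--lora-paths lora0={adapter_name_or_path[0]}",
            "--disable-radix-cache",
        ]
    )
print(" ".join(launch_cmd))
# Generation requests then carry json_data["lora_request"] = ["lora0"] so the server applies the adapter.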
@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
...
@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }
-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
...
@@ -169,11 +169,22 @@ def read_cloud_json(cloud_path):
     try:
         # Try with anonymous access first
         fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
     except Exception:
         # Try again with credentials
         fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+
+    if fs.isdir(cloud_path):
+        files = [x["Key"] for x in fs.listdir(cloud_path)]
+    else:
+        files = [cloud_path]
+    # filter out non-JSON files
+    files = [file for file in files if file.endswith(".json") or file.endswith(".jsonl")]
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}")
+    data = []
+    for file in files:
+        data.extend(_read_json_with_fs(fs, file, lines=file.endswith(".jsonl")))
+    return data
 def _read_json_with_fs(fs, path, lines=True):
...
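For illustration, a self-contained sketch of the new read_cloud_json behavior: a single file or a directory of JSON/JSONL files is accepted, non-JSON entries are filtered out, and everything is concatenated. The fsspec local filesystem and inline readers below are stand-ins for setup_fs and _read_json_with_fs:

import json
import fsspec  # stand-in; the project resolves the filesystem via setup_fs()

def read_json_like(fs, path):
    if fs.isdir(path):
        # s3fs listdir entries expose the object key under "Key"; local ones use "name"
        files = [entry.get("Key") or entry.get("name") for entry in fs.listdir(path)]
    else:
        files = [path]

    files = [f for f in files if f.endswith((".json", ".jsonl"))]
    if not files:
        raise ValueError(f"No JSON/JSONL files found in the specified path: {path}")

    data = []
    for f in files:
        with fs.open(f, "r") as fp:  # stand-in for _read_json_with_fs
            if f.endswith(".jsonl"):
                data.extend(json.loads(line) for line in fp if line.strip())
            else:
                data.extend(json.load(fp))
    return data

# e.g. fs = fsspec.filesystem("file"); read_json_like(fs, "/data/my_datasets/")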
@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
 def _get_dataset_processor(
@@ -303,7 +303,12 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )
     with training_args.main_process_first(desc="pre-process dataset"):
...
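For illustration, a toy sketch of the merge-versus-dict switch that return_dict introduces (the real code delegates to merge_dataset; the stand-in below simply concatenates lists):

def merged_or_dict(datasets: dict[str, list], return_dict: bool = False):
    if return_dict:
        # eval_on_each_dataset=True: keep one evaluation set per dataset name
        return datasets
    # default: fold everything into a single dataset (stand-in for merge_dataset)
    return [sample for dataset in datasets.values() for sample in dataset]

datasets = {"alpaca_en": [{"text": "a"}], "alpaca_zh": [{"text": "b"}]}
print(merged_or_dict(datasets, return_dict=True))   # {'alpaca_en': [...], 'alpaca_zh': [...]}
print(merged_or_dict(datasets, return_dict=False))  # [{'text': 'a'}, {'text': 'b'}]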
@@ -57,7 +57,10 @@ if is_transformers_version_greater_than("4.45.0"):
 )
-if is_transformers_version_greater_than("4.49.0"):
+if is_transformers_version_greater_than("4.52.0"):
+    from transformers.image_utils import make_flat_list_of_images
+    from transformers.video_utils import make_batched_videos
+elif is_transformers_version_greater_than("4.49.0"):
     from transformers.image_utils import make_batched_videos, make_flat_list_of_images
@@ -167,16 +170,45 @@ class MMPluginMixin:
         )
         if self.image_token is not None and processor is None:
-            raise ValueError("Processor was not found, please check and update your processor config.")
+            raise ValueError("Processor was not found, please check and update your model file.")
         if self.image_token is not None and image_processor is None:
-            raise ValueError("Image processor was not found, please check and update your processor config.")
+            raise ValueError("Image processor was not found, please check and update your model file.")
         if self.video_token is not None and video_processor is None:
-            raise ValueError("Video processor was not found, please check and update your processor config.")
+            raise ValueError("Video processor was not found, please check and update your model file.")
         if self.audio_token is not None and feature_extractor is None:
-            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+            raise ValueError("Audio feature extractor was not found, please check and update your model file.")
+
+    def _validate_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ):
+        r"""Validate if the number of images, videos and audios match the number of placeholders in messages."""
+        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
+        for message in messages:
+            num_image_tokens += message["content"].count(IMAGE_PLACEHOLDER)
+            num_video_tokens += message["content"].count(VIDEO_PLACEHOLDER)
+            num_audio_tokens += message["content"].count(AUDIO_PLACEHOLDER)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(
+                f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens in {messages}."
+            )
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(
+                f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens in {messages}."
+            )
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(
+                f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens in {messages}."
+            )
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
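For illustration, a standalone sketch of the check that the new _validate_messages helper centralizes; the placeholder strings are assumed to be "<image>", "<video>" and "<audio>":

IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, AUDIO_PLACEHOLDER = "<image>", "<video>", "<audio>"  # assumed values

def validate_messages(messages, images, videos, audios):
    # count placeholders over all turns, then compare against the media actually provided
    counts = tuple(
        sum(message["content"].count(token) for message in messages)
        for token in (IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, AUDIO_PLACEHOLDER)
    )
    if counts != (len(images), len(videos), len(audios)):
        raise ValueError(
            f"Media/placeholder mismatch: placeholders {counts} vs provided "
            f"({len(images)}, {len(videos)}, {len(audios)}) in {messages}."
        )

validate_messages([{"role": "user", "content": "<image>What is shown here?"}], ["img.png"], [], [])  # passes
# validate_messages([{"role": "user", "content": "What is shown here?"}], ["img.png"], [], [])  # raises ValueError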
@@ -420,6 +452,7 @@ class Gemma3Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         boi_token: str = getattr(processor, "boi_token")
@@ -446,9 +479,6 @@ class Gemma3Plugin(BasePlugin):
             message["content"] = content.replace("{{image}}", image_str)
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -566,8 +596,8 @@ class InternVLPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
         image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -579,9 +609,6 @@ class InternVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
                     f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
@@ -590,9 +617,6 @@ class InternVLPlugin(BasePlugin):
                 num_image_tokens += 1
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                 current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                 end_patch_index = video_patch_indices[num_video_tokens]
                 num_patches = list(video_num_patches[current_patch_index:end_patch_index])
@@ -605,12 +629,6 @@ class InternVLPlugin(BasePlugin):
             message["content"] = content
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -637,10 +655,13 @@ class KimiVLPlugin(BasePlugin):
     @override
     def process_messages(self, messages, images, videos, audios, processor):
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_hws = mm_inputs.get("image_grid_hws", [])
+        else:
+            image_grid_hws = [None] * len(images)
-        image_grid_hws = mm_inputs.get("image_grid_hws", [])
         num_image_tokens = 0
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         merge_length = math.prod(image_processor.merge_kernel_size)
@@ -648,9 +669,6 @@ class KimiVLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
@@ -661,9 +679,6 @@ class KimiVLPlugin(BasePlugin):
             message["content"] = content
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
@@ -679,6 +694,7 @@ class Llama4Plugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:
@@ -701,9 +717,6 @@ class Llama4Plugin(BasePlugin):
             for local_image_index, split_part in enumerate(prompt_splits):
                 new_content.append(split_part)
                 if local_image_index < placeholder_count:
-                    if num_image_tokens >= len(images):
-                        raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                     tokens_for_this_image = processor._prompt_split_image(
                         aspect_ratios[num_image_tokens], num_patches_per_chunk
                     )
@@ -716,9 +729,6 @@ class Llama4Plugin(BasePlugin):
             message["content"] = content
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -751,7 +761,7 @@ class LlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -768,17 +778,10 @@ class LlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1
             message["content"] = content.replace("{{image}}", self.image_token)
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
@@ -794,6 +797,7 @@ class LlavaNextPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
@@ -805,9 +809,6 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -821,9 +822,6 @@ class LlavaNextPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", self.image_token)
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
@@ -839,7 +837,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens, num_video_tokens = 0, 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -850,9 +848,6 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
@@ -862,7 +857,6 @@ class LlavaNextVideoPlugin(BasePlugin):
                     image_seqlen = 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                num_image_tokens += 1
             message["content"] = content.replace("{{image}}", self.image_token)
@@ -879,20 +873,10 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
-                num_video_tokens += 1
             message["content"] = content.replace("{{video}}", self.video_token)
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
         return messages
@@ -978,6 +962,7 @@ class MiniCPMVPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -996,24 +981,15 @@ class MiniCPMVPlugin(BasePlugin):
         for i, message in enumerate(messages):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                 video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                 num_video_tokens += 1
             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1
@@ -1065,15 +1041,6 @@ class MiniCPMVPlugin(BasePlugin):
             final_text += text_chunks[-1]
             messages[index]["content"] = final_text
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -1157,6 +1124,7 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -1164,9 +1132,6 @@ class MllamaPlugin(BasePlugin):
             num_image_tokens += content.count(IMAGE_PLACEHOLDER)
             message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -1214,6 +1179,7 @@ class PaliGemmaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -1224,9 +1190,6 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -1281,7 +1244,7 @@ class PixtralPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
+        self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1291,15 +1254,13 @@ class PixtralPlugin(BasePlugin):
                 image_sizes = iter(mm_inputs["image_sizes"][0])
             else:
                 image_sizes = iter(mm_inputs["image_sizes"].tolist())
         image_break_token: str = getattr(processor, "image_break_token")
         image_end_token: str = getattr(processor, "image_end_token")
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
                     num_height_tokens = height // processor.patch_size
@@ -1312,13 +1273,9 @@ class PixtralPlugin(BasePlugin):
                     replace_str = self.image_token
                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
-                num_image_tokens += 1
             message["content"] = content
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -1355,9 +1312,9 @@ class Qwen2AudioPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
-        num_audio_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1367,9 +1324,6 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
-                if num_audio_tokens >= len(audios):
-                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
                 if self.expand_mm_tokens:
                     audio_length = audio_lengths.pop(0)
                     input_length = (audio_length - 1) // 2 + 1
@@ -1380,13 +1334,9 @@ class Qwen2AudioPlugin(BasePlugin):
                 content = content.replace(
                     AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                 )
-                num_audio_tokens += 1
             message["content"] = content
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
         return messages
     @override
@@ -1494,6 +1444,7 @@ class Qwen2VLPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -1510,9 +1461,6 @@ class Qwen2VLPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
@@ -1520,9 +1468,6 @@ class Qwen2VLPlugin(BasePlugin):
                 num_image_tokens += 1
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                 video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
@@ -1531,12 +1476,6 @@ class Qwen2VLPlugin(BasePlugin):
             message["content"] = content
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
         return messages
@@ -1602,6 +1541,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -1624,9 +1564,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
                     IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
@@ -1642,11 +1579,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 )
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-                    if num_audio_tokens >= len(audios):
-                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
                     video_pos = content.find(VIDEO_PLACEHOLDER)
                     audio_pos = content.find(AUDIO_PLACEHOLDER, video_pos)
                     if audio_pos == -1 or audio_pos < video_pos:
@@ -1688,9 +1620,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                    num_video_tokens += 1
             else:
                 while AUDIO_PLACEHOLDER in content:
-                    if num_audio_tokens >= len(audios):
-                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
                     audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
                     content = content.replace(
                         AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
@@ -1698,9 +1627,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     num_audio_tokens += 1
                 while VIDEO_PLACEHOLDER in content:
-                    if num_video_tokens >= len(videos):
-                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                     video_seqlen = (
                         video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                     )
@@ -1711,15 +1637,6 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             message["content"] = content
-        if len(audios) != num_audio_tokens:
-            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
         return messages
@@ -1735,6 +1652,7 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         num_frames = 0
@@ -1762,28 +1680,16 @@ class VideoLlavaPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(images):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                 num_image_tokens += 1
             while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(videos):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                 num_video_tokens += 1
             content = content.replace("{{image}}", self.image_token)
             message["content"] = content.replace("{{video}}", self.video_token)
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
         return messages
...
@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
@@ -51,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"
     def encode_oneturn(
@@ -61,7 +63,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -77,7 +79,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -92,6 +94,19 @@ class Template:
         return list(stop_token_ids)
+
+    def add_thought(self, content: str = "") -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
         token_ids = []
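For illustration, a quick standalone check of what add_thought and remove_thought do, assuming the default thought words are "<think>" and "</think>":

import re

THOUGHT_WORDS = ("<think>", "</think>")  # assumed default thought_words

def add_thought(content: str = "") -> str:
    return f"{THOUGHT_WORDS[0]}\n\n{THOUGHT_WORDS[1]}\n\n" + content

def remove_thought(content: str) -> str:
    pattern = re.compile(f"{re.escape(THOUGHT_WORDS[0])}(.*?){re.escape(THOUGHT_WORDS[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")

print(repr(add_thought("4")))                                 # '<think>\n\n</think>\n\n4'
print(repr(remove_thought("<think>2 + 2 = 4</think>\n\n4")))  # '4'
# get_thought_word_ids() simply tokenizes add_thought("") to obtain the empty-thought token ids.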
@@ -111,18 +126,12 @@ class Template:
         return token_ids
-    def _remove_thought(self, content: str) -> str:
-        r"""Remove thought from assistant message."""
-        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
-        return re.sub(pattern, "", content).lstrip("\n")
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
@@ -140,18 +149,14 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=content, idx=str(i // 2))
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -162,6 +167,9 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
         r"""Add or replace eos token to the tokenizer."""
+        if tokenizer.eos_token == eos_token:
+            return
+
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@@ -328,7 +336,6 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
-        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
@@ -342,18 +349,14 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]
-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=system_text + content)
+                elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -392,6 +395,64 @@ class Llama2Template(Template):
         return jinja_template
+
+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that add thought to assistant message."""
+
+    @override
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+        return prompt_ids, response_ids
+
+    @override
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        for i in range(0, len(messages), 2):
+            if (
+                self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:  # do compute loss
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 TEMPLATES: dict[str, "Template"] = {}
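For illustration, a toy sketch of where ReasoningTemplate places the empty-thought tokens when the final answer contains no thought; per the comments in the diff above, this is what decides whether loss is computed on them:

def place_empty_thought(prompt_ids, response_ids, thought_ids, enable_thinking):
    # assumed convention: prompt tokens are masked from the loss, response tokens are trained on
    if not enable_thinking:
        return prompt_ids + thought_ids, response_ids  # empty thought is not learned
    return prompt_ids, thought_ids + response_ids      # empty thought is learned

print(place_empty_thought([1, 2], [9], [5, 6], enable_thinking=False))  # ([1, 2, 5, 6], [9])
print(place_empty_thought([1, 2], [9], [5, 6], enable_thinking=True))   # ([1, 2], [5, 6, 9])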
@@ -410,6 +471,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
@@ -456,6 +518,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
...@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": ...@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}] messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False) assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :] assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n") # remove thought tags assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n") # remove thought tags
if len(user_slot) > len(user_slot_empty_system): if len(user_slot) > len(user_slot_empty_system):
...@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": ...@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
else: # if default_system is empty, user_slot_empty_system will be longer than user_slot else: # if default_system is empty, user_slot_empty_system will be longer than user_slot
default_system = "" default_system = ""
return Template(
return template_class(
format_user=StringFormatter(slots=[user_slot]), format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]), format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]), format_system=StringFormatter(slots=[system_slot]),
...@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": ...@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
efficient_eos=False, efficient_eos=False,
replace_eos=False, replace_eos=False,
replace_jinja_template=False, replace_jinja_template=False,
enable_thinking=True,
mm_plugin=get_mm_plugin(name="base"), mm_plugin=get_mm_plugin(name="base"),
) )
...@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: ...@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format) template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
if data_args.default_system is not None:
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
template.default_system = data_args.default_system
template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer) template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer) template.fix_jinja_template(tokenizer)
return template return template
...@@ -756,6 +826,7 @@ register_template( ...@@ -756,6 +826,7 @@ register_template(
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
), ),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
) )
...@@ -774,6 +845,15 @@ register_template( ...@@ -774,6 +845,15 @@ register_template(
) )
# copied from deepseek3 template
register_template(
name="deepseekr1",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=ReasoningTemplate,
)
register_template( register_template(
name="deepseekcoder", name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
...@@ -838,6 +918,7 @@ register_template( ...@@ -838,6 +918,7 @@ register_template(
), ),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"], stop_words=["<end_of_turn>"],
replace_eos=True,
template_class=Llama2Template, template_class=Llama2Template,
) )
...@@ -853,6 +934,7 @@ register_template( ...@@ -853,6 +934,7 @@ register_template(
), ),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"], stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"), mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
template_class=Llama2Template, template_class=Llama2Template,
) )
...@@ -872,6 +954,22 @@ register_template( ...@@ -872,6 +954,22 @@ register_template(
) )
# copied from glm4 template
register_template(
name="glmz1",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
template_class=ReasoningTemplate,
)
register_template( register_template(
name="granite3", name="granite3",
format_user=StringFormatter( format_user=StringFormatter(
...@@ -1018,6 +1116,7 @@ register_template( ...@@ -1018,6 +1116,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"), format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"], stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
) )
...@@ -1037,6 +1136,7 @@ register_template( ...@@ -1037,6 +1136,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"), format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot|>", "<|eom|>"], stop_words=["<|eot|>", "<|eom|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"), mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
) )
...@@ -1066,6 +1166,7 @@ register_template( ...@@ -1066,6 +1166,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"), format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"], stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"), mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
) )
...@@ -1079,6 +1180,7 @@ register_template( ...@@ -1079,6 +1180,7 @@ register_template(
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]), format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
default_system="You are a helpful assistant provided by Moonshot-AI.", default_system="You are a helpful assistant provided by Moonshot-AI.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
) )
...@@ -1131,6 +1233,7 @@ register_template( ...@@ -1131,6 +1233,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"), format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"], stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"), mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
) )
...@@ -1163,6 +1266,7 @@ register_template( ...@@ -1163,6 +1266,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"), mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
) )
...@@ -1233,6 +1337,24 @@ register_template( ...@@ -1233,6 +1337,24 @@ register_template(
) )
# copied from qwen template
register_template(
name="mimo",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
# copied from chatml template # copied from chatml template
register_template( register_template(
name="minicpm_v", name="minicpm_v",
...@@ -1363,6 +1485,7 @@ register_template( ...@@ -1363,6 +1485,7 @@ register_template(
), ),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"], stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"), mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template, template_class=Llama2Template,
) )
...@@ -1374,6 +1497,7 @@ register_template( ...@@ -1374,6 +1497,7 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
stop_words=["<|end|>"], stop_words=["<|end|>"],
replace_eos=True,
) )
...@@ -1384,6 +1508,7 @@ register_template( ...@@ -1384,6 +1508,7 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]), format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"], stop_words=["<|end|>"],
replace_eos=True,
) )
...@@ -1395,6 +1520,7 @@ register_template( ...@@ -1395,6 +1520,7 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]), format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
) )
...@@ -1425,6 +1551,7 @@ register_template( ...@@ -1425,6 +1551,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
) )
...@@ -1440,6 +1567,8 @@ register_template( ...@@ -1440,6 +1567,8 @@ register_template(
), ),
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
template_class=ReasoningTemplate,
) )
...@@ -1451,6 +1580,7 @@ register_template( ...@@ -1451,6 +1580,7 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"), mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
) )
...@@ -1468,6 +1598,7 @@ register_template( ...@@ -1468,6 +1598,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin( mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>" name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
), ),
...@@ -1486,6 +1617,7 @@ register_template( ...@@ -1486,6 +1617,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"), format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"), mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
) )
...@@ -1503,6 +1635,20 @@ register_template( ...@@ -1503,6 +1635,20 @@ register_template(
) )
register_template(
name="seed_coder",
format_user=StringFormatter(
slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
),
format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
"and you only answer questions related to computer science. For politically sensitive questions, "
"security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
),
)
# copied from llama3 template # copied from llama3 template
register_template( register_template(
name="skywork_o1", name="skywork_o1",
......
...@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils): ...@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
tool_text = "" tool_text = ""
tool_names = [] tool_names = []
for tool in tools: for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
param_text = "" param_text = ""
for name, param in tool["parameters"]["properties"].items(): for name, param in tool["parameters"]["properties"].items():
required, enum, items = "", "", "" required, enum, items = "", "", ""
...@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils): ...@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = "" return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return function_text
@override @override
@staticmethod @staticmethod
...@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils): ...@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
) )
...@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils): ...@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
date = datetime.now().strftime("%d %b %Y") date = datetime.now().strftime("%d %b %Y")
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n" tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text) return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
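The conditional wrapping above lets callers pass either a bare tool schema or one already in OpenAI "type": "function" form; a hedged illustration of the invariant (tool name and schema are placeholders):
bare = {"name": "get_weather", "parameters": {"type": "object", "properties": {}}}
wrapped = {"type": "function", "function": bare}
def wrap(tool: dict) -> dict:
    # same normalization as the formatter above
    return tool if tool.get("type") == "function" else {"type": "function", "function": tool}
assert wrap(bare) == wrap(wrapped) == wrapped  # both shapes normalize to the OpenAI-style wrapper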
...@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils): ...@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
@override @override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try: try:
tool = json.loads(content.strip())
tools = json.loads(content.strip())
except json.JSONDecodeError: except json.JSONDecodeError:
return content return content
if "name" not in tool or "parameters" not in tool: tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
class MistralToolUtils(ToolUtils): class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template.""" r"""Mistral v0.3 tool using template."""
...@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils): ...@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = [] wrapped_tools = []
for tool in tools: for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]" return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
@override @override
@staticmethod @staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return "[" + ", ".join(function_texts) + "]"
return json.dumps(
[{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
)
@override @override
@staticmethod @staticmethod
...@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils): ...@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
except json.JSONDecodeError: except json.JSONDecodeError:
return content return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
class QwenToolUtils(ToolUtils): class QwenToolUtils(ToolUtils):
...@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils): ...@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str: def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False) tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text) return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
...@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils): ...@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str: def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return "\n".join(function_texts)
function_texts = [
json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
for name, arguments in functions
]
return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
@override @override
@staticmethod @staticmethod
......
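For the Qwen-style formatter above, a hedged sketch (argument values are placeholders) of the expected output when two tools are called in parallel, one <tool_call> block per function:
import json

calls = [("search", '{"query": "llama"}'), ("open_page", '{"url": "https://example.com"}')]
blocks = [json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False) for name, arguments in calls]
print("\n".join(f"<tool_call>\n{block}\n</tool_call>" for block in blocks))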
...@@ -533,6 +533,17 @@ register_model_group( ...@@ -533,6 +533,17 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
}, },
"DeepSeek-V3-671B-0324-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
},
},
template="deepseek3",
)
register_model_group(
models={
"DeepSeek-R1-1.5B-Distill": { "DeepSeek-R1-1.5B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
...@@ -566,7 +577,7 @@ register_model_group( ...@@ -566,7 +577,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
}, },
}, },
template="deepseek3", template="deepseekr1",
) )
...@@ -737,6 +748,13 @@ register_model_group( ...@@ -737,6 +748,13 @@ register_model_group(
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414", DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414", DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
}, },
},
template="glm4",
)
register_model_group(
models={
"GLM-Z1-9B-0414-Chat": { "GLM-Z1-9B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414", DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414", DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
...@@ -746,7 +764,7 @@ register_model_group( ...@@ -746,7 +764,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414", DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
}, },
}, },
template="glm4", template="glmz1",
) )
...@@ -869,12 +887,13 @@ register_model_group( ...@@ -869,12 +887,13 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Granite-3.2-1B-A400M-Base": { "Granite-Vision-3.2-2B": {
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b", DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b", DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
}, },
}, },
template="granite3_vision", template="granite3_vision",
multimodal=True,
) )
...@@ -1398,6 +1417,29 @@ register_model_group( ...@@ -1398,6 +1417,29 @@ register_model_group(
) )
register_model_group(
models={
"MiMo-7B-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
},
"MiMo-7B-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
},
"MiMo-7B-Instruct-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
},
"MiMo-7B-RL-ZERO": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
},
},
template="mimo",
)
register_model_group( register_model_group(
models={ models={
"MiniCPM-2B-SFT-Chat": { "MiniCPM-2B-SFT-Chat": {
...@@ -2461,6 +2503,38 @@ register_model_group( ...@@ -2461,6 +2503,38 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B", DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B", DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
}, },
"Qwen3-0.6B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
},
"Qwen3-1.7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
},
"Qwen3-4B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
},
"Qwen3-8B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
},
"Qwen3-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
},
"Qwen3-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
},
"Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
},
"Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
},
}, },
template="qwen3", template="qwen3",
) )
...@@ -2484,10 +2558,22 @@ register_model_group( ...@@ -2484,10 +2558,22 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"Qwen2.5-Omni-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
},
"Qwen2.5-Omni-7B": { "Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B", DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
}
},
"Qwen2.5-Omni-7B-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
},
"Qwen2.5-Omni-7B-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
},
}, },
template="qwen2_omni", template="qwen2_omni",
multimodal=True, multimodal=True,
...@@ -2598,15 +2684,17 @@ register_model_group( ...@@ -2598,15 +2684,17 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"SOLAR-10.7B-v1.0": { "Seed-Coder-8B-Base": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
}, },
"SOLAR-10.7B-Instruct-v1.0": { "Seed-Coder-8B-Instruct": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", },
"Seed-Coder-8B-Instruct-Reasoning": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
}, },
}, },
template="solar", template="seed_coder",
) )
...@@ -2631,6 +2719,20 @@ register_model_group( ...@@ -2631,6 +2719,20 @@ register_model_group(
) )
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group( register_model_group(
models={ models={
"StarCoder2-3B": { "StarCoder2-3B": {
......
...@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None: ...@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return return
if "gptmodel" in requirement or "autoawq" in requirement:
pip_command = f"pip install {requirement} --no-build-isolation"
else:
pip_command = f"pip install {requirement}"
if mandatory: if mandatory:
hint = f"To fix: run `pip install {requirement}`." hint = f"To fix: run `{pip_command}`."
else: else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check." hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint) require_version(requirement, hint)
def check_dependencies() -> None: def check_dependencies() -> None:
r"""Check the version of the required packages.""" r"""Check the version of the required packages."""
check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") check_version(
check_version("datasets>=2.16.0,<=3.5.0") "transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
check_version("accelerate>=0.34.0,<=1.6.0") )
check_version("peft>=0.14.0,<=0.15.1") check_version("datasets>=2.16.0,<=3.6.0")
check_version("accelerate>=0.34.0,<=1.7.0")
check_version("peft>=0.14.0,<=0.15.2")
check_version("trl>=0.8.6,<=0.9.6") check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"): if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.") logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
......
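A small illustration of the hint logic added above for packages that need --no-build-isolation (the requirement string here is just an example value):
requirement = "gptqmodel>=2.0.0"
if "gptqmodel" in requirement or "autoawq" in requirement:
    pip_command = f"pip install {requirement} --no-build-isolation"
else:
    pip_command = f"pip install {requirement}"
print(f"To fix: run `{pip_command}`.")  # To fix: run `pip install gptqmodel>=2.0.0 --no-build-isolation`.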
...@@ -99,6 +99,10 @@ class DataArguments: ...@@ -99,6 +99,10 @@ class DataArguments:
default=0.0, default=0.0,
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."}, metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
) )
eval_on_each_dataset: bool = field(
default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."},
)
packing: Optional[bool] = field( packing: Optional[bool] = field(
default=None, default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
...@@ -111,6 +115,14 @@ class DataArguments: ...@@ -111,6 +115,14 @@ class DataArguments:
default=None, default=None,
metadata={"help": "Tool format to use for constructing function calling examples."}, metadata={"help": "Tool format to use for constructing function calling examples."},
) )
default_system: Optional[str] = field(
default=None,
metadata={"help": "Override the default system message in the template."},
)
enable_thinking: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
tokenized_path: Optional[str] = field( tokenized_path: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
......
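A hedged usage sketch of the two new data arguments above (the import path is assumed, values are placeholders): get_template_and_fix_tokenizer copies them onto the template, so the system prompt and the thinking mode can be controlled from the data configuration.
from llamafactory.hparams import DataArguments  # assumption: the package is installed as `llamafactory`

data_args = DataArguments(default_system="You are a concise assistant.", enable_thinking=False)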
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import GenerationConfig from transformers import GenerationConfig
...@@ -62,10 +62,6 @@ class GeneratingArguments: ...@@ -62,10 +62,6 @@ class GeneratingArguments:
default=1.0, default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
) )
default_system: Optional[str] = field(
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field( skip_special_tokens: bool = field(
default=True, default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."}, metadata={"help": "Whether or not to remove special tokens in the decoding."},
......
...@@ -235,10 +235,6 @@ class ProcessorArguments: ...@@ -235,10 +235,6 @@ class ProcessorArguments:
default=False, default=False,
metadata={"help": "Whether to crop the image to patches for internvl."}, metadata={"help": "Whether to crop the image to patches for internvl."},
) )
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field( video_max_pixels: int = field(
default=256 * 256, default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."}, metadata={"help": "The maximum number of pixels of video inputs."},
...@@ -255,6 +251,10 @@ class ProcessorArguments: ...@@ -255,6 +251,10 @@ class ProcessorArguments:
default=128, default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."}, metadata={"help": "The maximum number of sampled frames for video inputs."},
) )
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
audio_sampling_rate: int = field( audio_sampling_rate: int = field(
default=16000, default=16000,
metadata={"help": "The sampling rate of audio inputs."}, metadata={"help": "The sampling rate of audio inputs."},
...@@ -364,6 +364,12 @@ class SGLangArguments: ...@@ -364,6 +364,12 @@ class SGLangArguments:
default=None, default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."}, metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
) )
sglang_lora_backend: Literal["triton", "flashinfer"] = field(
default="triton",
metadata={
"help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
},
)
def __post_init__(self): def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"): if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
......
...@@ -148,10 +148,10 @@ def _check_extra_dependencies( ...@@ -148,10 +148,10 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True) check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM: if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.8.4") check_version("vllm>=0.4.3,<=0.8.6")
check_version("vllm", mandatory=True) check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG: elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4") check_version("sglang>=0.4.5")
check_version("sglang", mandatory=True) check_version("sglang", mandatory=True)
if finetuning_args.use_galore: if finetuning_args.use_galore:
......
...@@ -64,6 +64,7 @@ class RayArguments: ...@@ -64,6 +64,7 @@ class RayArguments:
raise ValueError( raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}" f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
) )
import pyarrow.fs as fs import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3": if self.ray_storage_filesystem == "s3":
......
...@@ -29,10 +29,8 @@ if TYPE_CHECKING: ...@@ -29,10 +29,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2: if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available(): if is_flash_attn_2_available():
if model_args.flash_attn != AttentionFunction.FA2: if model_args.flash_attn != AttentionFunction.FA2:
......
...@@ -45,16 +45,24 @@ def apply_liger_kernel( ...@@ -45,16 +45,24 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text": elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "paligemma": elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama": elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "llava":
from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
elif model_type == "mistral": elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral": elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama": elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "olmo2":
from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "phi3": elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2": elif model_type == "qwen2":
...@@ -63,6 +71,8 @@ def apply_liger_kernel( ...@@ -63,6 +71,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl": elif model_type == "qwen2_5_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
elif model_type == "qwen3":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
else: else:
logger.warning_rank0("Current model does not support liger kernel.") logger.warning_rank0("Current model does not support liger kernel.")
return return
......
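A minimal sketch of the dispatch pattern extended above (the qwen3 entry comes from this diff; the try/except fallback is illustrative, not the project's code):
model_type = "qwen3"  # hypothetical value read from the model config
try:
    from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
except ImportError:
    apply_liger_kernel = None  # liger-kernel missing or too old for this model type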
...@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: ...@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if model_type in [
"dbrx",
"granitemoe",
"jamba",
"jetmoe",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "output_router_logits", is_trainable)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if not is_trainable or not model_args.moe_aux_loss_coef:
return
model_type = getattr(config, "model_type", None)
if model_type in [
"dbrx",
"granitemoe",
"jamba",
"jetmoe",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
...@@ -97,7 +97,7 @@ def configure_quantization( ...@@ -97,7 +97,7 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "") quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ: if quant_method == QuantizationMethod.GPTQ:
check_version("auto_gptq>=0.5.0", mandatory=True) check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama quantization_config["use_exllama"] = False # disable exllama
...@@ -111,12 +111,12 @@ def configure_quantization( ...@@ -111,12 +111,12 @@ def configure_quantization(
quant_bits = quantization_config.get("bits", "?") quant_bits = quantization_config.get("bits", "?")
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.") logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
elif model_args.export_quantization_bit is not None: # gptqmodel
if model_args.export_quantization_bit not in [8, 4, 3, 2]: if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
check_version("optimum>=1.17.0", mandatory=True) check_version("optimum>=1.24.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True) check_version("gptqmodel>=2.0.0", mandatory=True)
from accelerate.utils import get_max_memory from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm": if getattr(config, "model_type", None) == "chatglm":
...@@ -142,7 +142,8 @@ def configure_quantization( ...@@ -142,7 +142,8 @@ def configure_quantization(
) )
init_kwargs["device_map"] = "auto" init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory() init_kwargs["max_memory"] = get_max_memory()
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.") model_args.compute_dtype = torch.float16 # force fp16 for gptqmodel
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
elif model_args.quantization_bit is not None: # on-the-fly elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BNB: if model_args.quantization_method == QuantizationMethod.BNB:
......