"vscode:/vscode.git/clone" did not exist on "b87220dd0a2b73a6b66ec7e139f5d919b3de9aca"
Commit 8293100a authored by luopl's avatar luopl
Browse files

update to 0.9.2.dev0

parent 2778a3d0
......@@ -168,7 +168,7 @@ async def create_chat_completion_response(
if isinstance(result, list):
tool_calls = []
for tool in result:
function = Function(name=tool[0], arguments=tool[1])
function = Function(name=tool.name, arguments=tool.arguments)
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
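Note: the hunk above assumes the tool extractor now returns structured objects rather than `(name, arguments)` tuples. A minimal sketch of that assumed shape, using a stand-in `FunctionCall` NamedTuple rather than the project's actual class:

```
import json
import uuid
from typing import List, NamedTuple


class FunctionCall(NamedTuple):  # stand-in for the project's FunctionCall
    name: str
    arguments: str  # JSON-encoded argument string


def to_tool_calls(result: List[FunctionCall]) -> List[dict]:
    # Map each structured call into an OpenAI-style tool_call payload,
    # mirroring the `tool.name` / `tool.arguments` access in the hunk above.
    return [
        {
            "id": f"call_{uuid.uuid4().hex}",
            "type": "function",
            "function": {"name": fc.name, "arguments": fc.arguments},
        }
        for fc in result
    ]


print(to_tool_calls([FunctionCall("get_weather", json.dumps({"city": "Paris"}))]))
```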
......
......@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
try:
asyncio.get_event_loop()
except RuntimeError:
logger.warning_once("There is no current event loop, creating a new one.")
logger.warning_rank0_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
......@@ -133,7 +133,7 @@ class HuggingfaceEngine(BaseEngine):
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
eos_token_id=template.get_stop_token_ids(tokenizer),
pad_token_id=tokenizer.pad_token_id,
)
)
......@@ -168,11 +168,21 @@ class HuggingfaceEngine(BaseEngine):
for key, value in mm_inputs.items():
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have the same sizes
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
value = torch.stack([torch.stack(v) for v in value])
elif not isinstance(value, torch.Tensor):
value = torch.tensor(value)
if torch.is_floating_point(value): # cast data dtype for paligemma
value = value.to(model.dtype)
gen_kwargs[key] = value.to(model.device)
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
gen_kwargs["input_ids"] = inputs
del gen_kwargs["image_sizes"]
gen_kwargs["tokenizer"] = tokenizer
return gen_kwargs, prompt_length
@staticmethod
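For context, the loop above normalizes whatever the processor returned (lists of tensors for pixtral, nested lists for minicpmv, raw arrays otherwise) into a single tensor with the model's dtype and device. A standalone sketch of the same normalization; the helper name is illustrative:

```
from typing import Any

import torch


def normalize_mm_value(value: Any, model: torch.nn.Module) -> torch.Tensor:
    # Stack lists of tensors (pixtral), stack nested lists (minicpmv),
    # tensorize anything else, then cast floats to the model dtype and
    # move the result onto the model's device.
    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
        value = torch.stack(value)  # assumes identical shapes
    elif isinstance(value, list) and all(isinstance(v, list) for v in value):
        value = torch.stack([torch.stack(v) for v in value])
    elif not isinstance(value, torch.Tensor):
        value = torch.tensor(value)

    param = next(model.parameters())
    if torch.is_floating_point(value):
        value = value.to(param.dtype)
    return value.to(param.device)


model = torch.nn.Linear(2, 2).to(torch.float16)
out = normalize_mm_value([torch.randn(3, 4, 4), torch.randn(3, 4, 4)], model)
print(out.shape, out.dtype)  # torch.Size([2, 3, 4, 4]) torch.float16
```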
......@@ -204,8 +214,13 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
generate_output = model.generate(**gen_kwargs)
if isinstance(generate_output, tuple):
generate_output = generate_output[1][0] # post-process the minicpm_o output
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
response = tokenizer.batch_decode(
response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
......@@ -249,7 +264,9 @@ class HuggingfaceEngine(BaseEngine):
videos,
input_kwargs,
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
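The streaming path relies on the standard `TextIteratorStreamer` pattern from transformers: `generate()` blocks until completion, so it runs on a daemon thread while the caller iterates the streamer. A self-contained example; the model name is a placeholder:

```
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# generate() blocks, so it runs on a daemon thread while the main thread
# consumes decoded text chunks as they become available.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer},
    daemon=True,
)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
```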
......
......@@ -19,7 +19,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_device_count
from ..extras.packages import is_pillow_available, is_vllm_available
from ..model import load_config, load_tokenizer
......@@ -67,11 +67,12 @@ class VllmEngine(BaseEngine):
self.processor = tokenizer_module["processor"]
self.tokenizer.padding_side = "left"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
self.generating_args = generating_args.to_dict()
engine_args = {
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"trust_remote_code": model_args.trust_remote_code,
"download_dir": model_args.cache_dir,
"dtype": model_args.infer_dtype,
"max_model_len": model_args.vllm_maxlen,
......@@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
"enable_lora": model_args.adapter_name_or_path is not None,
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config)
......@@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
**input_kwargs,
) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}"
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if images is not None:
mm_input_dict.update({"images": images, "imglens": [len(images)]})
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
else:
image_str = self.template.mm_plugin.image_token or ""
if videos is not None:
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
paired_messages = [
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
for message in messages
] + [{"role": "assistant", "content": ""}]
messages = self.template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
......@@ -162,13 +168,13 @@ class VllmEngine(BaseEngine):
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
max_tokens=max_tokens,
skip_special_tokens=True,
skip_special_tokens=self.generating_args["skip_special_tokens"],
)
if images is not None: # add image features
image_data = []
multi_modal_data = {"image": []}
for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
......@@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image_data.append(image)
multi_modal_data = {"image": image_data}
multi_modal_data["image"].append(image)
else:
multi_modal_data = None
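The rewritten branch builds `multi_modal_data` incrementally instead of collecting `image_data` first. A standalone sketch of the normalization it performs (file paths or PIL images in, RGB PIL images grouped under the `"image"` key out):

```
from typing import List, Optional, Union

from PIL import Image
from PIL.Image import Image as ImageObject


def build_multi_modal_data(images: Optional[List[Union[str, ImageObject]]]) -> Optional[dict]:
    if images is None:
        return None
    multi_modal_data = {"image": []}
    for image in images:
        if not isinstance(image, (str, ImageObject)):
            raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")  # load from path
        multi_modal_data["image"].append(image)
    return multi_modal_data


print(build_multi_modal_data([Image.new("RGB", (32, 32))]))
```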
......
......@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import get_device_count
from .extras.misc import get_device_count, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
......@@ -87,7 +87,7 @@ def main():
export_model()
elif command == Command.TRAIN:
force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
if force_torchrun or (get_device_count() > 1 and not use_ray()):
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
......@@ -120,3 +120,7 @@ def main():
print(USAGE)
else:
raise NotImplementedError(f"Unknown command: {command}.")
if __name__ == "__main__":
main()
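The launcher change only falls back to `torchrun` when Ray is not managing the workers. A sketch of the gating logic; the subprocess command and script name are illustrative, not the project's exact invocation:

```
import os
import random
import subprocess


def maybe_launch_torchrun(device_count: int, use_ray: bool, script: str = "train.py") -> None:
    force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
    if force_torchrun or (device_count > 1 and not use_ray):
        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
        master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
        # Ray schedules its own workers, so torchrun is skipped in that case.
        subprocess.run(
            [
                "torchrun",
                f"--nproc_per_node={device_count}",
                f"--master_addr={master_addr}",
                f"--master_port={master_port}",
                script,
            ],
            check=True,
        )
```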
......@@ -19,8 +19,16 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from transformers import ProcessorMixin
......@@ -72,12 +80,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r"""
Data collator that supports VLMs.
Features should contain input_ids, attention_mask, labels and images.
Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
"""
template: Optional["Template"] = None
processor: Optional["ProcessorMixin"] = None
def __post_init__(self):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
......@@ -89,6 +101,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])
if (
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
): # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
)
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
else:
features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
batch_images = fake_images
batch_imglens[0] = 1
batch_input_ids[0] = features[0]["input_ids"]
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
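Why the fake sample: under ZeRO-3/FSDP every rank must trace the same forward graph, so a text-only batch that skips the vision tower on some ranks would stall collective ops. A minimal sketch of the masking idea behind the padding-side branches above (the helper name is illustrative):

```
from typing import Any, Dict, List

IGNORE_INDEX = -100  # value used to exclude tokens from the loss


def append_masked_tokens(feature: Dict[str, List[Any]], fake_input_ids: List[int]) -> Dict[str, List[Any]]:
    # The dummy tokens keep every rank's forward graph identical, while
    # attention_mask=0 and labels=IGNORE_INDEX make them inert: they are
    # never attended to and contribute nothing to the loss.
    feature["input_ids"] = feature["input_ids"] + fake_input_ids
    feature["attention_mask"] = feature["attention_mask"] + [0] * len(fake_input_ids)
    feature["labels"] = feature["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
    return feature
```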
......@@ -98,10 +133,30 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),
attention_mask=features["attention_mask"],
)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)
orig_len = cross_attention_mask.size(1)
mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()
if "image_bound" in features: # for minicpmv inputs
bsz, seq_length = features["input_ids"].shape
features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
return features
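The `F.pad` call above is easy to misread: pad pairs are consumed from the last dimension backwards, so the final `(0, seq_len - orig_len)` pair extends dim 1, the sequence axis of the `(batch, seq, max_num_images, max_num_tiles)` mask. A runnable check:

```
import torch
import torch.nn.functional as F

# cross_attention_mask: (batch, seq, max_num_images, max_num_tiles)
mask = torch.ones(2, 5, 1, 4)
seq_len, orig_len = 8, mask.size(1)
# F.pad consumes (left, right) pairs from the LAST dimension backwards,
# so the third pair (0, seq_len - orig_len) appends zeros along dim 1.
padded = F.pad(mask, (0, 0, 0, 0, 0, seq_len - orig_len))
print(padded.shape)  # torch.Size([2, 8, 1, 4])
```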
......@@ -120,6 +175,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
for key, value in features.items(): # cast data dtype for paligemma
if torch.is_tensor(value) and torch.is_floating_point(value):
features[key] = value.to(self.compute_dtype)
return features
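For reference, `prepare_4d_attention_mask` turns the segment-id style mask produced by neat packing (1, 2, 3, ... per packed sample, 0 for padding) into an additive 4D mask. A simplified re-implementation under that assumption, not the project's exact function:

```
import torch


def block_diag_4d_mask(segment_ids: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # segment_ids: (batch, seq); tokens may only attend within their own
    # non-padding segment, causally.
    bsz, seq_len = segment_ids.size()
    keys = segment_ids[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    queries = keys.transpose(-1, -2)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = keys.eq(queries) & keys.ne(0) & causal  # same non-pad segment, causal
    zero = torch.tensor(0.0, dtype=dtype)
    neg_inf = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    return torch.where(allowed, zero, neg_inf)


print(block_diag_4d_mask(torch.tensor([[1, 1, 2, 2, 0]]), torch.float32)[0, 0])
```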
......
......@@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
......
......@@ -16,16 +16,12 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Union
from typing_extensions import override
from .data_utils import SLOTS
from .tool_utils import get_tool_utils
if TYPE_CHECKING:
from .tool_utils import FunctionCall
from .tool_utils import FunctionCall, get_tool_utils
@dataclass
......@@ -98,33 +94,31 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
functions: List["FunctionCall"] = []
try:
tool_calls = json.loads(content)
if not isinstance(tool_calls, list): # parallel function call
tool_calls = [tool_calls]
for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
functions.append(
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = []
for name, arguments in functions:
for slot in self.slots:
if isinstance(slot, str):
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot)
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
for slot in self.slots:
if slot == "{{content}}":
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)
return elements
......
......@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import has_tokenized_data
from ..extras.misc import check_version, has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
......@@ -84,7 +83,7 @@ def _load_single_dataset(
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
......@@ -103,7 +102,7 @@ def _load_single_dataset(
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
......@@ -128,7 +127,8 @@ def _load_single_dataset(
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=data_args.streaming,
trust_remote_code=True,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
if dataset_attr.num_samples is not None and not data_args.streaming:
......@@ -238,15 +238,19 @@ def get_dataset(
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if isinstance(tokenized_data, DatasetDict):
if "train" in tokenized_data:
dataset_module["train_dataset"] = tokenized_data["train"]
if "validation" in tokenized_data:
dataset_module["eval_dataset"] = tokenized_data["validation"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
else: # Dataset
dataset_module["train_dataset"] = tokenized_data
if data_args.streaming:
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
......
import math
import re
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
......@@ -62,6 +63,7 @@ class BasePlugin:
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
self.image_token = image_token
self.video_token = video_token
self.expand_mm_tokens = True
def _validate_input(
self,
......@@ -72,10 +74,14 @@ class BasePlugin:
Validates if this model accepts the input modalities.
"""
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
)
if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
raise ValueError(
"This model does not support video input. Please check whether the correct `template` is used."
)
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
r"""
......@@ -241,7 +247,7 @@ class BasePlugin:
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
batch_ids: input ids of samples, shape (batch_size, seq_len)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
......@@ -259,13 +265,13 @@ class LlavaPlugin(BasePlugin):
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen")
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
......@@ -310,14 +316,16 @@ class LlavaNextPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
......@@ -359,14 +367,16 @@ class LlavaNextVideoPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
......@@ -376,6 +386,7 @@ class LlavaNextVideoPlugin(BasePlugin):
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
......@@ -406,7 +417,7 @@ class LlavaNextVideoPlugin(BasePlugin):
return self._get_mm_inputs(images, videos, processor)
class PaliGemmaPlugin(BasePlugin):
class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
self,
......@@ -417,12 +428,241 @@ class PaliGemmaPlugin(BasePlugin):
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V does not support image and video inputs at the same time.")
if len(videos) != 0:
max_slice_nums = 2
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, processor)
else:
max_slice_nums = image_processor.max_slice_nums
use_image_id = image_processor.use_image_id
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
if num_image_tokens > 0:
mm_inputs = self._get_mm_inputs(images, [], processor)
if mm_inputs:
pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"]
for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
final_text = (
final_text
+ text_chunks[i]
+ image_processor.get_slice_image_placeholder(
image_sizes[0][i], i, max_slice_nums, use_image_id
)
)
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = []
idx = 0
for valid_image_nums in valid_image_nums_ls:
new_images.append(images[idx : idx + valid_image_nums])
idx += valid_image_nums
images = new_images
image_inputs = image_processor(
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
)
mm_inputs.update(image_inputs)
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)
return mm_inputs
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
image_bounds_list = []
valid_image_nums_ls = []
for input_ids in batch_ids:
input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id
)
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
valid_image_nums_ls.append(valid_image_nums)
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
]
)
image_bounds_list.append(image_bounds)
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs.update({"image_bound": image_bounds_list})
return mm_inputs
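The `image_bound` computation above pairs the position right after each image/slice start marker with the matching end marker. A toy example with hypothetical marker token ids:

```
import torch

IM_START_ID, IM_END_ID = 101, 102  # hypothetical marker token ids

input_ids = torch.tensor([7, IM_START_ID, 5, 5, 5, IM_END_ID, 9])
starts = torch.where(input_ids == IM_START_ID)[0] + 1  # first position inside the span
ends = torch.where(input_ids == IM_END_ID)[0]
image_bounds = torch.hstack([starts.unsqueeze(-1), ends.unsqueeze(-1)])
print(image_bounds)  # tensor([[2, 5]]): image features occupy positions 2..4
```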
class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
imglens: List[int] = kwargs["imglens"]
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
images = images[image_length:]
return image_processor(batch_images, return_tensors="pt")
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = torch.from_numpy(
convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
) # shape: (batch_size, length, max_num_images, max_num_tiles)
return mm_inputs
class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", "")
......@@ -443,7 +683,7 @@ class PaliGemmaPlugin(BasePlugin):
) -> Tuple[List[int], Optional[List[int]]]:
self._validate_input(images, videos)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen")
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
if labels is not None:
......@@ -493,14 +733,18 @@ class PixtralPlugin(BasePlugin):
if image_input_sizes is None:
raise ValueError("Cannot get image input sizes.")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
if self.expand_mm_tokens:
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
else:
replace_str = image_token
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
......@@ -549,10 +793,27 @@ class Qwen2vlPlugin(BasePlugin):
return image
@override
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
sample_frames = sample_frames // 2 * 2
return sample_frames
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
results = []
for video in videos:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
total_frames = video_stream.frames
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames: List["ImageObject"] = []
container.seek(0)
for frame_idx, frame in enumerate(container.decode(video_stream)):
if frame_idx in sample_indices:
frames.append(frame.to_image())
if len(frames) % 2 != 0: # qwen2-vl requires an even number of frames
frames.append(frames[-1])
frames = self._regularize_images(frames, **kwargs)
results.append(frames)
return results
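The override samples frame indices evenly with `np.linspace`, then pads to an even count by repeating the last frame, since Qwen2-VL's temporal patch merging consumes frames in pairs. A minimal sketch of just the sampling arithmetic:

```
import numpy as np

total_frames, sample_frames = 31, 7
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
frames = list(sample_indices)  # stand-ins for decoded PIL frames
if len(frames) % 2 != 0:  # pad to an even count by repeating the last frame
    frames.append(frames[-1])
print(sample_indices.tolist(), len(frames))  # [0, 5, 10, 15, 20, 25, 30] 8
```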
@override
def process_messages(
......@@ -577,12 +838,9 @@ class Qwen2vlPlugin(BasePlugin):
if num_image_tokens >= len(image_grid_thw):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
),
1,
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
)
num_image_tokens += 1
......@@ -590,12 +848,9 @@ class Qwen2vlPlugin(BasePlugin):
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
VIDEO_PLACEHOLDER,
"<|vision_start|>{}<|vision_end|>".format(
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
),
1,
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
)
num_video_tokens += 1
......@@ -640,29 +895,32 @@ class VideoLlavaPlugin(BasePlugin):
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
......@@ -689,89 +947,17 @@ class VideoLlavaPlugin(BasePlugin):
return self._get_mm_inputs(images, videos, processor)
class MllamaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
return image_processor([[image] for image in images], return_tensors="pt")
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
if len(images) != len(batch_ids):
raise ValueError("Mllama only supports one image per sample.")
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
return mm_inputs
PLUGINS = {
"base": BasePlugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"minicpm_v": MiniCPMVPlugin,
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin,
"video_llava": VideoLlavaPlugin,
"mllama": MllamaPlugin,
}
......
......@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
......@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
tokenizer=tokenizer,
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing: # hack datasets to have int32 attention mask
......
......@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
result["input_ids"][i][0] = tokenizer.bos_token_id
return result
def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
......@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
......@@ -15,10 +15,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras import logging
from ..extras.misc import check_version
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
......@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments
from .formatter import SLOTS, Formatter
from .mm_plugin import BasePlugin
from .tool_utils import FunctionCall
logger = logging.get_logger(__name__)
......@@ -43,7 +44,6 @@ class Template:
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
......@@ -83,12 +83,22 @@ class Template:
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts tool message.
"""
return self.format_tools.extract(content)
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
r"""
Returns stop token ids.
"""
stop_token_ids = {tokenizer.eos_token_id}
for token in self.stop_words:
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
return list(stop_token_ids)
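Usage note: both engines above now pass `template.get_stop_token_ids(tokenizer)` instead of hand-building `[eos_token_id] + additional_special_tokens_ids`; the set dedupes stop words whose ids collide with eos. A tiny sketch of the dedup:

```
def collect_stop_token_ids(eos_token_id: int, stop_word_ids: list) -> list:
    # The set removes duplicates when a stop word's id coincides with eos.
    return list({eos_token_id, *stop_word_ids})


print(sorted(collect_stop_token_ids(2, [2, 32007])))  # [2, 32007]
```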
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
......@@ -112,9 +122,6 @@ class Template:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
......@@ -179,9 +186,6 @@ class Llama2Template(Template):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT.value:
......@@ -209,13 +213,12 @@ def _register_template(
format_function: Optional["Formatter"] = None,
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: Sequence[str] = [],
stop_words: Optional[Sequence[str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = True,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None:
r"""
......@@ -223,34 +226,28 @@ def _register_template(
To add the following chat template:
```
[HUMAN]:
user prompt here
[AI]:
model response here
[HUMAN]:
user prompt here
[AI]:
model response here
<s><user>user prompt here
<model>model response here</s>
<user>user prompt here
<model>model response here</s>
```
The corresponding code should be:
```
_register_template(
name="custom",
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
format_prefix=EmptyFormatter("<s>"),
)
```
"""
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
default_assistant_formatter = StringFormatter(slots=default_slots)
default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
TEMPLATES[name] = template_class(
format_user=format_user or default_user_formatter,
......@@ -259,10 +256,9 @@ def _register_template(
format_function=format_function or default_function_formatter,
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
stop_words=stop_words or [],
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
......@@ -343,9 +339,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
jinja_template += "{{ " + user_message + " }}"
jinja_template += "{% elif message['role'] == 'assistant' %}"
assistant_message = _convert_slots_to_jinja(
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
)
assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
jinja_template += "{{ " + assistant_message + " }}"
jinja_template += "{% endif %}"
jinja_template += "{% endfor %}"
......@@ -364,15 +358,15 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
check_version("transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words
......@@ -410,24 +404,24 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
default_system=(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
),
replace_jinja_template=True,
)
_register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_separator=EmptyFormatter(slots=["###"]),
format_assistant=StringFormatter(slots=["{{content}}###"]),
format_system=StringFormatter(slots=["System: {{content}}###"]),
default_system=(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words=["</s>"],
efficient_eos=True,
)
......@@ -457,7 +451,7 @@ _register_template(
_register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_separator=EmptyFormatter(slots=["\n\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
......@@ -479,7 +473,6 @@ _register_template(
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
efficient_eos=True,
)
......@@ -490,7 +483,7 @@ _register_template(
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
......@@ -504,23 +497,26 @@ _register_template(
_register_template(
name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
)
# copied from chatml template
_register_template(
name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
)
......@@ -534,7 +530,7 @@ _register_template(
name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
......@@ -569,21 +565,24 @@ _register_template(
)
# copied from chatml template
_register_template(
name="cpm3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
# copied from chatml template
_register_template(
name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.\n"
......@@ -600,7 +599,6 @@ _register_template(
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
......@@ -612,11 +610,17 @@ _register_template(
)
_register_template(
name="deepseek3",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the DeepSeek Coder model, "
......@@ -630,8 +634,8 @@ _register_template(
_register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["System: {{content}}\n"]),
)
......@@ -644,22 +648,22 @@ _register_template(
_register_template(
name="exaone",
format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
_register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
efficient_eos=True,
)
_register_template(
name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]),
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
efficient_eos=True,
)
......@@ -667,13 +671,11 @@ _register_template(
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
replace_jinja_template=False,
)
......@@ -682,7 +684,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
......@@ -691,6 +693,18 @@ _register_template(
)
_register_template(
name="granite3",
format_user=StringFormatter(
slots=[
"<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
]
),
format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
)
_register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
......@@ -702,22 +716,31 @@ _register_template(
_register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_separator=EmptyFormatter(slots=["<eoa>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<eoa>"],
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
)
_register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
# copied from intern2 template
_register_template(
name="intern3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
)
......@@ -728,6 +751,7 @@ _register_template(
)
# copied from llama2 template
_register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
......@@ -746,22 +770,24 @@ _register_template(
)
]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
stop_words=["<|eot_id|>", "<|eom_id|>"],
)
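# Hedged illustration of the Llama-3 layout the slots above produce for one
# system + user turn (bos_token written out; the tokenizer supplies the real string):
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n
# Generation stops on <|eot_id|>, or on <|eom_id|> when a tool call is emitted.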
# copied from llama3 template
_register_template(
name="mllama",
format_user=StringFormatter(
......@@ -772,23 +798,25 @@ _register_template(
)
]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
stop_words=["<|eot_id|>", "<|eom_id|>"],
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
# copied from vicuna template
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
......@@ -800,6 +828,7 @@ _register_template(
)
# copied from vicuna template
_register_template(
name="llava_next",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
......@@ -811,6 +840,7 @@ _register_template(
)
# copied from llama3 template
_register_template(
name="llava_next_llama3",
format_user=StringFormatter(
......@@ -821,56 +851,67 @@ _register_template(
)
]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
stop_words=["<|eot_id|>", "<|eom_id|>"],
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from mistral template
_register_template(
name="llava_next_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from chatml template
_register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from chatml template
_register_template(
name="llava_next_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
# copied from vicuna template
_register_template(
name="llava_next_video",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
......@@ -882,28 +923,66 @@ _register_template(
)
# copied from mistral template
_register_template(
name="llava_next_video_mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
# copied from chatml template
_register_template(
name="llava_next_video_yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)
# copied from chatml template
_register_template(
name="marco",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
default_system=(
"你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
"当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
"<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
),
stop_words=["<|im_end|>"],
)
# copied from chatml template
_register_template(
name="minicpm_v",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
stop_words=["<|im_end|>"],
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
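# Minimal sketch, assuming "<s>"/"</s>" as the bos/eos strings: the mistral
# slots above put no space before [/INST] and a single space before the reply.
bos, eos = "<s>", "</s>"
prompt = f"{bos}[INST] Hello![/INST] Hi there!{eos}"
assert prompt == "<s>[INST] Hello![/INST] Hi there!</s>"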
......@@ -934,20 +1013,18 @@ _register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
)
# copied from chatml template
_register_template(
name="opencoder",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
......@@ -958,15 +1035,15 @@ _register_template(
)
# copied from gemma template
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
......@@ -974,56 +1051,71 @@ _register_template(
_register_template(
name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="phi_small",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
_register_template(
name="phi4",
format_user=StringFormatter(
slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
stop_words=["<|im_end|>"],
)
_register_template(
name="pixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)
# copied from chatml template
_register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
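# Illustrative single-turn render of the "qwen" (ChatML) slots above, using
# the registered default system prompt:
system, user = "You are a helpful assistant.", "Hi!"
prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n"
)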
# copied from chatml template
_register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
......@@ -1031,14 +1123,48 @@ _register_template(
_register_template(
name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
default_system=(
"You are an AI assistant named Sailor created by Sea AI Lab. "
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
# copied from llama3 template
_register_template(
name="skywork_o1",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
"involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
"you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
"After completing your thoughts, you then provide a detailed explanation of the solution process "
"in your response."
),
stop_words=["<|eot_id|>", "<|eom_id|>"],
)
......@@ -1053,10 +1179,9 @@ _register_template(
_register_template(
name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|end|>"],
replace_eos=True,
)
......@@ -1064,8 +1189,16 @@ _register_template(
name="telechat",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
stop_words=["<_end>"],
replace_eos=True,
)
_register_template(
name="telechat2",
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
format_system=StringFormatter(slots=["<_system>{{content}}"]),
default_system=(
"你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。"
),
)
......@@ -1076,6 +1209,7 @@ _register_template(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
replace_jinja_template=True,
)
......@@ -1110,8 +1244,8 @@ _register_template(
_register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
......@@ -1127,20 +1261,20 @@ _register_template(
)
# copied from chatml template
_register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="yi_vl",
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
default_system=(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. Read all the images carefully, "
......@@ -1157,9 +1291,8 @@ _register_template(
_register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
stop_words=["<eod>"],
replace_eos=True,
)
......@@ -1174,5 +1307,5 @@ _register_template(
_register_template(
name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]),
format_assistant=StringFormatter(slots=["{{content}}\n"]),
)
......@@ -15,15 +15,20 @@
import json
import re
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
class FunctionCall(NamedTuple):
name: str
arguments: str
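# e.g. FunctionCall(name="get_weather", arguments='{"city": "Paris"}');
# `arguments` stays a JSON-encoded string so the formatters below can embed
# it verbatim in prompt text.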
DEFAULT_TOOL_PROMPT = (
"You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n"
......@@ -34,14 +39,25 @@ DEFAULT_TOOL_PROMPT = (
"```\n"
)
GLM4_TOOL_PROMPT = (
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
LLAMA3_TOOL_PROMPT = (
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
"Do not use variables.\n\n{tool_text}"
)
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
QWEN_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
""""arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
)
@dataclass
......@@ -52,17 +68,17 @@ class ToolUtils(ABC):
@staticmethod
@abstractmethod
def get_function_slots() -> SLOTS:
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
r"""
Gets a list of slots corresponding to a single function call.
Generates the system message describing all the available tools.
"""
...
@staticmethod
@abstractmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
r"""
Generates the system message describing all the available tools.
Generates the assistant message including all the tool calls.
"""
...
......@@ -70,16 +86,17 @@ class ToolUtils(ABC):
@abstractmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the response message.
Extracts all the function calls from the assistant message.
It should be the inverse of `function_formatter`.
"""
...
class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
r"""
Default tool-using template.
"""
@override
@staticmethod
......@@ -115,6 +132,15 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return [function_text]
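# Illustrative output for one call (assumed input):
#   function_formatter([FunctionCall("search", '{"q": "llama"}')])
#   -> ['Action: search\nAction Input: {"q": "llama"}\n']
# which is exactly the text shape tool_extractor below parses back.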
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
......@@ -129,7 +155,7 @@ class DefaultToolUtils(ToolUtils):
tool_input = match[1].strip().strip('"').strip("```")
try:
arguments = json.loads(tool_input)
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
except json.JSONDecodeError:
return content
......@@ -137,10 +163,9 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"]
r"""
GLM-4 tool-using template.
"""
@override
@staticmethod
......@@ -153,6 +178,14 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
return [f"{functions[0].name}\n{functions[0].arguments}"]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
......@@ -161,16 +194,152 @@ class GLM4ToolUtils(ToolUtils):
tool_name, tool_input = content.split("\n", maxsplit=1)
try:
arguments = json.loads(tool_input)
arguments = json.loads(tool_input.strip())
except json.JSONDecodeError:
return content
return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
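# Hedged example of the two-line GLM-4 convention handled above: the reply
#   'get_weather\n{"city": "Beijing"}'
# extracts to [FunctionCall("get_weather", '{"city": "Beijing"}')]; replies
# without a JSON body fall through and are returned unchanged.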
class Llama3ToolUtils(ToolUtils):
r"""
Llama 3.x tool-using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
try:
tool = json.loads(content.strip())
except json.JSONDecodeError:
return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
if "name" not in tool or "parameters" not in tool:
return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool-using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return ["[" + ", ".join(function_texts) + "]"]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
try:
tools = json.loads(content.strip())
except json.JSONDecodeError:
return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
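# Round-trip sanity check (illustrative; relies on json.dumps spacing):
_calls = [FunctionCall("get_weather", '{"city": "Paris"}')]
_text = MistralToolUtils.function_formatter(_calls)[0]
assert MistralToolUtils.tool_extractor(_text) == _calls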
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool-using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return ["\n".join(function_texts)]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
if not tool_match:
return content
results = []
for tool in tool_match:
try:
tool = json.loads(tool.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
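# Illustrative extraction for the Qwen 2.5 wrapper above:
_reply = '<tool_call>\n{"name": "get_time", "arguments": {"tz": "UTC"}}\n</tool_call>'
assert QwenToolUtils.tool_extractor(_reply) == [FunctionCall("get_time", '{"tz": "UTC"}')]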
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
}
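# Hedged sketch of resolving these utils by the `tool_format` name used in
# the templates above (the project's actual accessor may differ):
def get_tool_utils(tool_format: str) -> "ToolUtils":
    if tool_format not in TOOLS:
        raise ValueError(f"Unsupported tool format: {tool_format}.")
    return TOOLS[tool_format]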
......
......@@ -100,7 +100,7 @@ class Evaluator:
cache_dir=self.model_args.cache_dir,
download_mode=self.eval_args.download_mode,
token=self.model_args.hf_hub_token,
trust_remote_code=True,
trust_remote_code=self.model_args.trust_remote_code,
)
pbar.set_postfix_str(categorys[subject]["name"])
inputs, outputs, labels = [], [], []
......
......@@ -81,19 +81,6 @@ TRAINING_STAGES = {
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
"cohere",
"falcon",
"gemma",
"gemma2",
"llama",
"mistral",
"phi",
"phi3",
"qwen2",
"starcoder2",
}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
......@@ -118,7 +105,7 @@ def register_model_group(
) -> None:
for name, path in models.items():
SUPPORTED_MODELS[name] = path
if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision):
DEFAULT_TEMPLATE[name] = template
if vision:
VISION_MODELS.add(name)
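# Illustrative effect of the widened condition above: a vision model now gets
# a default template even without a "-Chat"/"-Instruct" suffix. Hypothetical call:
#   register_model_group(models={"My-VL-7B": {...}}, template="qwen2_vl", vision=True)
# sets DEFAULT_TEMPLATE["My-VL-7B"] = "qwen2_vl" and adds the name to VISION_MODELS.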
......@@ -338,6 +325,7 @@ register_model_group(
models={
"Codestral-22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
DownloadSource.MODELSCOPE: "swift/Codestral-22B-v0.1",
},
},
template="mistral",
......@@ -433,15 +421,19 @@ register_model_group(
},
"DeepSeek-Coder-V2-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
},
"DeepSeek-Coder-V2-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Base",
},
"DeepSeek-Coder-V2-16B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
},
"DeepSeek-Coder-V2-236B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
},
},
template="deepseek",
......@@ -456,6 +448,7 @@ register_model_group(
},
"DeepSeek-Coder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-7b-base-v1.5",
},
"DeepSeek-Coder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
......@@ -467,6 +460,7 @@ register_model_group(
},
"DeepSeek-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
},
"DeepSeek-Coder-33B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
......@@ -477,6 +471,33 @@ register_model_group(
)
register_model_group(
models={
"DeepSeek-V2-236B-Chat-0628": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
},
"DeepSeek-V2.5-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
},
"DeepSeek-V2.5-236B-Chat-1210": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
},
"DeepSeek-V3-685B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base",
},
"DeepSeek-V3-685B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
},
},
template="deepseek3",
)
register_model_group(
models={
"EXAONE-3.0-7.8B-Instruct": {
......@@ -495,6 +516,7 @@ register_model_group(
},
"Falcon-11B": {
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
DownloadSource.MODELSCOPE: "tiiuae/falcon-11B",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
......@@ -598,14 +620,99 @@ register_model_group(
register_model_group(
models={
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
"GPT-2-Small": {
DownloadSource.DEFAULT: "openai-community/gpt2",
DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
"GPT-2-Medium": {
DownloadSource.DEFAULT: "openai-community/gpt2-medium",
DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-medium",
},
"GPT-2-Large": {
DownloadSource.DEFAULT: "openai-community/gpt2-large",
DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-large",
},
"GPT-2-XL": {
DownloadSource.DEFAULT: "openai-community/gpt2-xl",
DownloadSource.MODELSCOPE: "goodbai95/GPT2-xl",
},
},
)
register_model_group(
models={
"Granite-3.0-1B-A400M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-1b-a400m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-1b-a400m-base",
},
"Granite-3.0-3B-A800M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-3b-a800m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-3b-a800m-base",
},
"Granite-3.0-2B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-2b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-2b-base",
},
"Granite-3.0-8B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-8b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-8b-base",
},
"Granite-3.0-1B-A400M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-1b-a400m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-1b-a400m-instruct",
},
"Granite-3.0-3B-A800M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-3b-a800m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-3b-a800m-instruct",
},
"Granite-3.0-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-2b-instruct",
},
"Granite-3.0-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-8b-instruct",
},
"Granite-3.1-1B-A400M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-1b-a400m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-1b-a400m-base",
},
"Granite-3.1-3B-A800M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-3b-a800m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-3b-a800m-base",
},
"Granite-3.1-2B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-2b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-2b-base",
},
"Granite-3.1-8B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-base",
},
"Granite-3.1-1B-A400M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-1b-a400m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-1b-a400m-instruct",
},
"Granite-3.1-3B-A800M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-3b-a800m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-3b-a800m-instruct",
},
"Granite-3.1-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-2b-instruct",
},
"Granite-3.1-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
},
},
template="granite3",
)
register_model_group(
models={
"Index-1.9B-Base": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
......@@ -614,6 +721,14 @@ register_model_group(
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
},
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
},
"Index-1.9B-Chat-32K": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
......@@ -702,6 +817,15 @@ register_model_group(
template="intern2",
)
register_model_group(
models={
"InternLM3-8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct",
},
},
template="intern3",
)
register_model_group(
models={
......@@ -850,6 +974,10 @@ register_model_group(
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
},
"Llama-3.3-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.3-70B-Instruct",
},
},
template="llama3",
)
......@@ -857,10 +985,18 @@ register_model_group(
register_model_group(
models={
"Llama-3.2-11B-Vision": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision",
},
"Llama-3.2-11B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
},
"Llama-3.2-90B-Vision": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision",
},
"Llama-3.2-90B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
......@@ -998,6 +1134,17 @@ register_model_group(
)
register_model_group(
models={
"Marco-o1-Chat": {
DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
},
},
template="marco",
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
......@@ -1025,6 +1172,28 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-o-2_6-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
},
},
template="minicpm_v",
)
register_model_group(
models={
"MiniCPM-V-2_6-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
},
},
template="minicpm_v",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
......@@ -1173,23 +1342,23 @@ register_model_group(
register_model_group(
models={
"PaliGemma-3B-pt-224-Chat": {
"PaliGemma-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448-Chat": {
"PaliGemma-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896-Chat": {
"PaliGemma-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224-Chat": {
"PaliGemma-3B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448-Chat": {
"PaliGemma-3B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
......@@ -1199,6 +1368,50 @@ register_model_group(
)
register_model_group(
models={
"PaliGemma2-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma2-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-224",
},
"PaliGemma2-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma2-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-448",
},
"PaliGemma2-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma2-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-896",
},
"PaliGemma2-10B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma2-10b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-224",
},
"PaliGemma2-10B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma2-10b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-448",
},
"PaliGemma2-10B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma2-10b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-896",
},
"PaliGemma2-28B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-224",
},
"PaliGemma2-28B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-448",
},
"PaliGemma2-28B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-896",
},
},
template="paligemma",
vision=True,
)
register_model_group(
models={
"Phi-1.5-1.3B": {
......@@ -1231,6 +1444,14 @@ register_model_group(
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
"Phi-3.5-4B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
},
"Phi-3.5-MoE-42B-A6.6B-instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
},
},
template="phi",
)
......@@ -1253,7 +1474,18 @@ register_model_group(
register_model_group(
models={
"Pixtral-12B-Chat": {
"Phi-4-14B-Instruct": {
DownloadSource.DEFAULT: "microsoft/phi-4",
DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
},
},
template="phi4",
)
register_model_group(
models={
"Pixtral-12B-Instruct": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
......@@ -1267,67 +1499,67 @@ register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
},
},
template="qwen",
......@@ -1338,147 +1570,147 @@ register_model_group(
models={
"Qwen1.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
},
"Qwen1.5-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
},
"Qwen1.5-4B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
},
"Qwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
},
"Qwen1.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
},
"Qwen1.5-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
},
"Qwen1.5-4B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
},
"Qwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
},
"Qwen1.5-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
},
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
},
"Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
},
"Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
},
"Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
},
"CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
},
"CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
......@@ -1489,122 +1721,122 @@ register_model_group(
models={
"Qwen2-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B",
},
"Qwen2-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B",
},
"Qwen2-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B",
},
"Qwen2-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B",
},
"Qwen2-MoE-57B-A14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B",
},
"Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
},
"Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
},
"Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
},
"Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct",
},
"Qwen2-MoE-57B-A14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B-Instruct",
},
"Qwen2-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2-0.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
},
"Qwen2-0.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-AWQ",
},
"Qwen2-1.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2-1.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
},
"Qwen2-1.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-AWQ",
},
"Qwen2-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
},
"Qwen2-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
},
"Qwen2-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-AWQ",
},
"Qwen2-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
},
"Qwen2-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
},
"Qwen2-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-AWQ",
},
"Qwen2-57B-A14B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
},
"Qwen2-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-1.5B",
},
"Qwen2-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-7B",
},
"Qwen2-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-72B",
},
"Qwen2-Math-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-1.5B-Instruct",
},
"Qwen2-Math-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-7B-Instruct",
},
"Qwen2-Math-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-72B-Instruct",
},
},
template="qwen",
......@@ -1615,215 +1847,219 @@ register_model_group(
models={
"Qwen2.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B",
},
"Qwen2.5-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B",
},
"Qwen2.5-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B",
},
"Qwen2.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B",
},
"Qwen2.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B",
},
"Qwen2.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B",
},
"Qwen2.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B",
},
"Qwen2.5-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct",
},
"Qwen2.5-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct",
},
"Qwen2.5-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct",
},
"Qwen2.5-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct",
},
"Qwen2.5-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct",
},
"Qwen2.5-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct",
},
"Qwen2.5-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
},
"Qwen2.5-0.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
},
"Qwen2.5-1.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2.5-1.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
},
"Qwen2.5-1.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
},
"Qwen2.5-3B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
},
"Qwen2.5-3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
},
"Qwen2.5-3B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-AWQ",
},
"Qwen2.5-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
},
"Qwen2.5-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
},
"Qwen2.5-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-AWQ",
},
"Qwen2.5-14B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
},
"Qwen2.5-14B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
},
"Qwen2.5-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-AWQ",
},
"Qwen2.5-32B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
},
"Qwen2.5-32B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
},
"Qwen2.5-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-AWQ",
},
"Qwen2.5-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
},
"Qwen2.5-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
},
"Qwen2.5-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-AWQ",
},
"Qwen2.5-Coder-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-0.5B",
},
"Qwen2.5-Coder-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B",
},
"Qwen2.5-Coder-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-3B",
},
"Qwen2.5-Coder-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B",
},
"Qwen2.5-Coder-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-14B",
},
"Qwen2.5-Coder-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-32B",
},
"Qwen2.5-Coder-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
},
"Qwen2.5-Coder-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Coder-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-3B-Instruct",
},
"Qwen2.5-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Coder-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-14B-Instruct",
},
"Qwen2.5-Coder-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-32B-Instruct",
},
"Qwen2.5-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-1.5B",
},
"Qwen2.5-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-7B",
},
"Qwen2.5-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-72B",
},
"Qwen2.5-Math-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Math-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Math-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-72B-Instruct",
},
"QwQ-32B-Preview-Instruct": {
DownloadSource.DEFAULT: "Qwen/QwQ-32B-Preview",
DownloadSource.MODELSCOPE: "Qwen/QwQ-32B-Preview",
},
},
template="qwen",
......@@ -1834,53 +2070,57 @@ register_model_group(
models={
"Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
},
"Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
},
"Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct",
},
"Qwen2-VL-2B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-2B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-2B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
},
"Qwen2-VL-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
},
"Qwen2-VL-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
},
"QVQ-72B-Preview": {
DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview",
DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview",
},
},
template="qwen2_vl",
......@@ -1912,6 +2152,17 @@ register_model_group(
)
register_model_group(
models={
"Skywork-o1-Open-Llama-3.1-8B": {
DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
}
},
template="skywork_o1",
)
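Each register_model_group entry above maps one display name to a repository id per download source, and absent sources fall back to the default hub. A minimal, self-contained sketch of that lookup (SUPPORTED_MODELS and get_model_path are illustrative names, and the enum values are assumptions; the registry internals are not shown in this diff):

from enum import Enum, unique

@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"      # assumed value: Hugging Face Hub
    MODELSCOPE = "ms"   # assumed value: ModelScope Hub
    OPENMIND = "om"     # assumed value: openMind Hub

# Hypothetical registry holding the group registered above.
SUPPORTED_MODELS = {
    "Skywork-o1-Open-Llama-3.1-8B": {
        DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
        DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
    },
}

def get_model_path(name: str, source: DownloadSource) -> str:
    # Fall back to the default hub when a model has no id for the requested source.
    paths = SUPPORTED_MODELS[name]
    return paths.get(source, paths[DownloadSource.DEFAULT])

print(get_model_path("Skywork-o1-Open-Llama-3.1-8B", DownloadSource.OPENMIND))
# -> Skywork/Skywork-o1-Open-Llama-3.1-8B (no OPENMIND id registered)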
register_model_group(
models={
"StarCoder2-3B": {
......@@ -1942,19 +2193,40 @@ register_model_group(
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
},
"TeleChat-12B-v2-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
},
"TeleChat-52B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
},
},
template="telechat",
)
register_model_group(
models={
"TeleChat2-3B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat2-3B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-3B",
},
"TeleChat2-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat2-7B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-7B",
},
"TeleChat2-35B-Chat": {
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-35B-Nov",
},
"TeleChat2-115B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat2-115B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-115B",
},
},
template="telechat2",
)
register_model_group(
models={
"Vicuna-v1.5-7B-Chat": {
......
......@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.1"
VERSION = "0.9.2.dev0"
def print_env() -> None:
......
......@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
r"""
A logger that supports info_rank0 and warning_once.
A logger that supports rank0 logging.
"""
def info_rank0(self, *args, **kwargs) -> None:
......@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_once(self, *args, **kwargs) -> None:
def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
......@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
logging.Logger.warning_rank0_once = warning_rank0_once
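The rename from warning_once to warning_rank0_once also makes the behavior explicit: the lru_cache memoization fires each distinct message at most once per process, and the LOCAL_RANK guard silences every rank but the main one. A standalone illustration of the same monkey-patch pattern (re-declared here so it runs outside the project):

import logging
import os
from functools import lru_cache

@lru_cache(None)
def warning_rank0_once(self: logging.Logger, *args, **kwargs) -> None:
    # Emit only on the main process; the cache suppresses repeats of the
    # same (logger, message) pair.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)

logging.Logger.warning_rank0_once = warning_rank0_once

logger = logging.getLogger("demo")
logger.warning_rank0_once("There is no current event loop, creating a new one.")  # printed once
logger.warning_rank0_once("There is no current event loop, creating a new one.")  # suppressed by the cache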
......@@ -17,7 +17,7 @@
import gc
import os
from typing import TYPE_CHECKING, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
import torch
import torch.distributed as dist
......@@ -73,18 +73,46 @@ class AverageMeter:
self.avg = self.sum / self.count
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)
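check_version folds the repeated require_version calls into one helper: optional pins can be bypassed by exporting DISABLE_VERSION_CHECK=1, while mandatory pins always run. Usage as it appears elsewhere in this commit:

# Optional pin: skipped (with a one-time warning) when DISABLE_VERSION_CHECK=1.
check_version("transformers>=4.41.2,<=4.46.1")

# Mandatory pin: checked regardless of the environment flag, e.g. before
# importing an alternative hub client.
check_version("modelscope>=1.11.0", mandatory=True)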
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
check_version("transformers>=4.41.2,<=4.46.1")
check_version("datasets>=2.16.0,<=3.1.0")
check_version("accelerate>=0.34.0,<=1.0.1")
check_version("peft>=0.11.1,<=0.12.0")
check_version("trl>=0.8.6,<=0.9.6")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""
Calculates effective tokens per second.
"""
effective_token_num = 0
for data in dataset:
if stage == "sft":
effective_token_num += len(data["input_ids"])
elif stage == "rm":
effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
return result / dist.get_world_size() if dist.is_initialized() else result
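As a quick sanity check of calculate_tps for the sft stage: two samples of 100 and 50 tokens trained for 2 epochs in 10 seconds give (100 + 50) * 2 / 10 = 30 effective tokens per second, divided further by the world size once torch.distributed is initialized:

dataset = [
    {"input_ids": [0] * 100},
    {"input_ids": [0] * 50},
]
metrics = {"epoch": 2.0, "train_runtime": 10.0}
assert calculate_tps(dataset, metrics, stage="sft") == 30.0  # single process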
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
......@@ -213,7 +241,7 @@ def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
......@@ -237,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
return model_args.model_name_or_path
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
......@@ -248,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
......@@ -259,16 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
r"""
calculate effective tokens.
"""
result = effective_token_num * epoch / train_runtime
return result / dist.get_world_size() if dist.is_initialized() else result
def use_ray() -> bool:
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
......@@ -50,6 +50,10 @@ def is_galore_available():
return _is_package_available("galore_torch")
def is_apollo_available():
return _is_package_available("apollo_torch")
def is_gradio_available():
return _is_package_available("gradio")
......@@ -62,6 +66,10 @@ def is_pillow_available():
return _is_package_available("PIL")
def is_ray_available():
return _is_package_available("ray")
def is_requests_available():
return _is_package_available("requests")
......