Commit 53b3977b authored by dongchy920

Initial commit

# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by Dan Hendrycks' test library.
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2020 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm, trange
from transformers.utils import cached_file
from ..data import get_template_and_fix_tokenizer
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import load_model, load_tokenizer
from .template import get_eval_template
if TYPE_CHECKING:
    from numpy.typing import NDArray


class Evaluator:
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
        self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
        self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
        self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
        self.eval_template = get_eval_template(self.eval_args.lang)
        # take the last token id of each letter, so tokenizers that split a letter
        # into several pieces still map to a single candidate token
        self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

    @torch.inference_mode()
    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
        logits = self.model(**batch_input).logits
        # position of the last non-padding token in each sequence
        lengths = torch.sum(batch_input["attention_mask"], dim=-1)
        word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
        # normalize over the answer-choice tokens only and pick the best letter
        choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
        return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]

    def eval(self) -> None:
        eval_task = self.eval_args.task.split("_")[0]
        eval_split = self.eval_args.task.split("_")[1]
        mapping = cached_file(
            path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task),
            filename="mapping.json",
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.hf_hub_token,
        )
        with open(mapping, encoding="utf-8") as f:
            categories: Dict[str, Dict[str, str]] = json.load(f)

        category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
        pbar = tqdm(categories.keys(), desc="Processing subjects", position=0)
        results = {}
        for subject in pbar:
            dataset = load_dataset(
                path=os.path.join(self.eval_args.task_dir, eval_task),
                name=subject,
                cache_dir=self.model_args.cache_dir,
                download_mode=self.eval_args.download_mode,
                token=self.model_args.hf_hub_token,
                trust_remote_code=self.model_args.trust_remote_code,
            )
            pbar.set_postfix_str(categories[subject]["name"])
            inputs, outputs, labels = [], [], []
            for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False):
                support_set = (
                    dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
                )
                messages = self.eval_template.format_example(
                    target_data=dataset[eval_split][i],
                    support_set=support_set,
                    subject_name=categories[subject]["name"],
                )
                input_ids, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=messages)
                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
                labels.append(messages[-1]["content"])

            for i in trange(
                0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False
            ):
                batch_input = self.tokenizer.pad(
                    inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
                ).to(self.model.device)
                preds = self.batch_inference(batch_input)
                outputs += preds

            corrects = np.array(outputs) == np.array(labels)
            category_name = categories[subject]["category"]
            category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
            category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
            results[subject] = {str(i): outputs[i] for i in range(len(outputs))}

        pbar.close()
        self._save_results(category_corrects, results)

    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[str, str]]) -> None:
        score_info = "\n".join(
            [
                f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
                for category_name, category_correct in category_corrects.items()
                if len(category_correct)
            ]
        )
        print(score_info)
        if self.eval_args.save_dir is not None:
            os.makedirs(self.eval_args.save_dir, exist_ok=False)
            with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
                json.dump(results, f, indent=2)
            with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
                f.write(score_info)


def run_eval() -> None:
    Evaluator().eval()
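
# Usage sketch (added for illustration): the keys below are assumed to match the
# model/data/eval hparams parsed by get_eval_args -- adjust if the dataclasses
# differ. Run it as a module (python -m ...) so the relative imports resolve.
if __name__ == "__main__":
    Evaluator(
        {
            "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # any supported model id
            "template": "fewshot",
            "task": "mmlu_test",  # parsed as task "mmlu", split "test"
            "lang": "en",
            "n_shot": 5,
            "batch_size": 4,
            "save_dir": "saves/mmlu_eval",  # must not exist yet (exist_ok=False above)
        }
    ).eval()

# ------------- next file in this commit: eval prompt templates -------------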
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
@dataclass
class EvalTemplate:
    system: str
    choice: str
    answer: str

    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
        r"""
        input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
        output: a tuple of (prompt, response)
        """
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
    ) -> List[Dict[str, str]]:
        r"""
        Converts dataset examples to messages.
        """
        messages = []
        for k in range(len(support_set)):
            prompt, response = self._parse_example(support_set[k])
            messages.append({"role": Role.USER.value, "content": prompt})
            messages.append({"role": Role.ASSISTANT.value, "content": response})

        prompt, response = self._parse_example(target_data)
        messages.append({"role": Role.USER.value, "content": prompt})
        messages.append({"role": Role.ASSISTANT.value, "content": response})
        # prepend the task description to the first user message
        messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
        return messages


eval_templates: Dict[str, "EvalTemplate"] = {}


def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)


def get_eval_template(name: str) -> "EvalTemplate":
    eval_template = eval_templates.get(name, None)
    assert eval_template is not None, f"Template {name} does not exist."
    return eval_template


_register_eval_template(
    name="en",
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
    choice="\n{choice}. {content}",
    answer="\nAnswer:",
)


_register_eval_template(
    name="zh",
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
    choice="\n{choice}. {content}",
    answer="\n答案:",
)
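
# A minimal sketch of what the "en" template produces (hypothetical question
# data; run as a module so the relative imports resolve):
if __name__ == "__main__":
    demo_messages = get_eval_template("en").format_example(
        target_data={
            "question": "What is the capital of France?",
            "A": "Berlin",
            "B": "Paris",
            "C": "Rome",
            "D": "Madrid",
            "answer": "B",
        },
        support_set=[],  # zero-shot; pass few-shot examples here otherwise
        subject_name="geography",
    )
    # First user turn: system line + question + "\nA. Berlin" ... "\nAnswer:";
    # the final assistant turn carries the gold letter "B".
    for message in demo_messages:
        print(message["role"], message["content"], sep=": ")

# ------- next file in this commit: shared constants and model registry -------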
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
}
CHOICES = ["A", "B", "C", "D"]
DATA_CONFIG = "dataset_info.json"
DEFAULT_TEMPLATE = defaultdict(str)
FILEEXT2TYPE = {
"arrow": "arrow",
"csv": "csv",
"json": "json",
"jsonl": "json",
"parquet": "parquet",
"txt": "text",
}
IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINER_LOG = "trainer_log.jsonl"
TRAINING_ARGS = "training_args.yaml"
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
"PPO": "ppo",
"DPO": "dpo",
"KTO": "kto",
"Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
VISION_MODELS = set()
class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"
    OPENMIND = "om"


def register_model_group(
    models: Dict[str, Dict[DownloadSource, str]],
    template: Optional[str] = None,
    vision: bool = False,
) -> None:
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
        if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision):
            DEFAULT_TEMPLATE[name] = template

        if vision:
            VISION_MODELS.add(name)
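
# Illustration (kept as a comment so the real registry below is not polluted;
# "Demo-8B-Instruct" and its path are hypothetical): one call can fill all
# three tables at once.
#
# register_model_group(
#     models={"Demo-8B-Instruct": {DownloadSource.DEFAULT: "demo-org/demo-8b-instruct"}},
#     template="demo",
#     vision=True,
# )
# # SUPPORTED_MODELS["Demo-8B-Instruct"] == {DownloadSource.DEFAULT: "demo-org/demo-8b-instruct"}
# # DEFAULT_TEMPLATE["Demo-8B-Instruct"] == "demo"  (name ends in "-Instruct")
# # "Demo-8B-Instruct" in VISION_MODELS             (vision=True)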
register_model_group(
models={
"Aya-23-8B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
},
"Aya-23-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
},
},
template="cohere",
)
register_model_group(
models={
"Baichuan-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
},
"Baichuan-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
},
"Baichuan-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
},
},
template="baichuan",
)
register_model_group(
models={
"Baichuan2-7B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
},
"Baichuan2-13B-Base": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_base_pt",
},
"Baichuan2-7B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_7b_chat_pt",
},
"Baichuan2-13B-Chat": {
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
DownloadSource.OPENMIND: "Baichuan/Baichuan2_13b_chat_pt",
},
},
template="baichuan2",
)
register_model_group(
models={
"BLOOM-560M": {
DownloadSource.DEFAULT: "bigscience/bloom-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
},
"BLOOM-3B": {
DownloadSource.DEFAULT: "bigscience/bloom-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
},
"BLOOM-7B1": {
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
},
},
)
register_model_group(
models={
"BLOOMZ-560M": {
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
},
"BLOOMZ-3B": {
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
},
"BLOOMZ-7B1-mt": {
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
},
},
)
register_model_group(
models={
"BlueLM-7B-Base": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
},
"BlueLM-7B-Chat": {
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
},
},
template="bluelm",
)
register_model_group(
models={
"Breeze-7B": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
},
"Breeze-7B-Instruct": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
},
},
template="breeze",
)
register_model_group(
models={
"ChatGLM2-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
}
},
template="chatglm2",
)
register_model_group(
models={
"ChatGLM3-6B-Base": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
},
"ChatGLM3-6B-Chat": {
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
},
},
template="chatglm3",
)
register_model_group(
models={
"Chinese-Llama-2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
},
"Chinese-Llama-2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
},
"Chinese-Llama-2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
},
"Chinese-Alpaca-2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
},
"Chinese-Alpaca-2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
},
"Chinese-Alpaca-2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
},
},
template="llama2_zh",
)
register_model_group(
models={
"CodeGeeX4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
},
},
template="codegeex4",
)
register_model_group(
models={
"CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b",
},
"CodeGemma-7B-Instruct": {
DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
},
"CodeGemma-1.1-2B": {
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
},
"CodeGemma-1.1-7B-Instruct": {
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
},
},
template="gemma",
)
register_model_group(
models={
"Codestral-22B-v0.1-Chat": {
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
DownloadSource.MODELSCOPE: "swift/Codestral-22B-v0.1",
},
},
template="mistral",
)
register_model_group(
models={
"CommandR-35B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
},
"CommandR-Plus-104B-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
},
"CommandR-35B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
},
"CommandR-Plus-104B-4bit-Chat": {
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
},
},
template="cohere",
)
register_model_group(
models={
"DBRX-132B-Base": {
DownloadSource.DEFAULT: "databricks/dbrx-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
},
"DBRX-132B-Instruct": {
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
},
},
template="dbrx",
)
register_model_group(
models={
"DeepSeek-LLM-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
},
"DeepSeek-LLM-67B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
},
"DeepSeek-LLM-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
},
"DeepSeek-LLM-67B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
"DeepSeek-V2-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
},
"DeepSeek-V2-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
},
"DeepSeek-V2-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
},
"DeepSeek-V2-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
},
"DeepSeek-Coder-V2-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
},
"DeepSeek-Coder-V2-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Base",
},
"DeepSeek-Coder-V2-16B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
},
"DeepSeek-Coder-V2-236B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
},
},
template="deepseek",
)
register_model_group(
models={
"DeepSeek-Coder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
},
"DeepSeek-Coder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
},
"DeepSeek-Coder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
},
"DeepSeek-Coder-6.7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
},
"DeepSeek-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
},
"DeepSeek-Coder-33B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
},
},
template="deepseekcoder",
)
register_model_group(
models={
"EXAONE-3.0-7.8B-Instruct": {
DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
},
},
template="exaone",
)
register_model_group(
models={
"Falcon-7B": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
},
"Falcon-11B": {
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
DownloadSource.MODELSCOPE: "tiiuae/falcon-11B",
},
"Falcon-40B": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
},
"Falcon-180B": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
},
"Falcon-7B-Instruct": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
},
"Falcon-40B-Instruct": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
},
"Falcon-180B-Chat": {
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
},
},
template="falcon",
)
register_model_group(
models={
"Gemma-2B": {
DownloadSource.DEFAULT: "google/gemma-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
},
"Gemma-7B": {
DownloadSource.DEFAULT: "google/gemma-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
},
"Gemma-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
},
"Gemma-7B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
},
"Gemma-1.1-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
},
"Gemma-1.1-7B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
},
"Gemma-2-2B": {
DownloadSource.DEFAULT: "google/gemma-2-2b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b",
},
"Gemma-2-9B": {
DownloadSource.DEFAULT: "google/gemma-2-9b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
},
"Gemma-2-27B": {
DownloadSource.DEFAULT: "google/gemma-2-27b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
},
"Gemma-2-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-2b-it",
},
"Gemma-2-9B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
DownloadSource.OPENMIND: "LlamaFactory/gemma-2-9b-it",
},
"Gemma-2-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
},
},
template="gemma",
)
register_model_group(
models={
"GLM-4-9B": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
},
"GLM-4-9B-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
DownloadSource.OPENMIND: "LlamaFactory/glm-4-9b-chat",
},
"GLM-4-9B-1M-Chat": {
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
},
},
template="glm4",
)
register_model_group(
models={
"Granite-3.0-1B-A400M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-1b-a400m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-1b-a400m-base",
},
"Granite-3.0-3B-A800M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-3b-a800m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-3b-a800m-base",
},
"Granite-3.0-2B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-2b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-2b-base",
},
"Granite-3.0-8B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-8b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-8b-base",
},
"Granite-3.0-1B-A400M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-1b-a400m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-1b-a400m-instruct",
},
"Granite-3.0-3B-A800M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-3b-a800m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-3b-a800m-instruct",
},
"Granite-3.0-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-2b-instruct",
},
"Granite-3.0-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.0-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-8b-instruct",
},
"Granite-3.1-1B-A400M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-1b-a400m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-1b-a400m-base",
},
"Granite-3.1-3B-A800M-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-3b-a800m-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-3b-a800m-base",
},
"Granite-3.1-2B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-2b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-2b-base",
},
"Granite-3.1-8B-Base": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-base",
},
"Granite-3.1-1B-A400M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-1b-a400m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-1b-a400m-instruct",
},
"Granite-3.1-3B-A800M-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-3b-a800m-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-3b-a800m-instruct",
},
"Granite-3.1-2B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-2b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-2b-instruct",
},
"Granite-3.1-8B-Instruct": {
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
},
},
template="granite3",
)
register_model_group(
models={
"Index-1.9B-Base": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
},
"Index-1.9B-Base-Pure": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
},
"Index-1.9B-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
},
"Index-1.9B-Character-Chat": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
},
"Index-1.9B-Chat-32K": {
DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
},
},
template="index",
)
register_model_group(
models={
"InternLM-7B": {
DownloadSource.DEFAULT: "internlm/internlm-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
},
"InternLM-20B": {
DownloadSource.DEFAULT: "internlm/internlm-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
},
"InternLM-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
},
"InternLM-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
},
},
template="intern",
)
register_model_group(
models={
"InternLM2-7B": {
DownloadSource.DEFAULT: "internlm/internlm2-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
},
"InternLM2-20B": {
DownloadSource.DEFAULT: "internlm/internlm2-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
},
"InternLM2-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
},
"InternLM2-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
},
"InternLM2.5-1.8B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b",
},
"InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
},
"InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b",
},
"InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-1_8b-chat",
},
"InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat",
},
"InternLM2.5-7B-1M-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
DownloadSource.OPENMIND: "Intern/internlm2_5-7b-chat-1m",
},
"InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
},
},
template="intern2",
)
register_model_group(
models={
"Jamba-v0.1": {
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
}
},
)
register_model_group(
models={
"LingoWhale-8B": {
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
}
},
)
register_model_group(
models={
"Llama-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
},
"Llama-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
},
"Llama-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
},
"Llama-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
},
}
)
register_model_group(
models={
"Llama-2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
},
"Llama-2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
},
"Llama-2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
},
"Llama-2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
},
"Llama-2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
},
"Llama-2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
},
},
template="llama2",
)
register_model_group(
models={
"Llama-3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
},
"Llama-3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
},
"Llama-3-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
},
"Llama-3-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
},
"Llama-3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Llama3-Chinese-8B-Instruct",
},
"Llama-3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
},
"Llama-3.1-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
},
"Llama-3.1-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
},
"Llama-3.1-405B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
},
"Llama-3.1-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
},
"Llama-3.1-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
},
"Llama-3.1-405B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
},
"Llama-3.1-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
},
"Llama-3.1-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
},
"Llama-3.2-1B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
},
"Llama-3.2-3B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
},
"Llama-3.2-1B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
},
"Llama-3.2-3B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
},
"Llama-3.3-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.3-70B-Instruct",
},
},
template="llama3",
)
register_model_group(
models={
"Llama-3.2-11B-Vision": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision",
},
"Llama-3.2-11B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
},
"Llama-3.2-90B-Vision": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision",
},
"Llama-3.2-90B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
},
},
template="mllama",
vision=True,
)
register_model_group(
models={
"LLaVA-1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
},
"LLaVA-1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
},
},
template="llava",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
},
"LLaVA-NeXT-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
},
},
template="llava_next",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Mistral-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
},
},
template="llava_next_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Llama3-8B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
},
},
template="llava_next_llama3",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
},
},
template="llava_next_yi",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-72B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
},
"LLaVA-NeXT-110B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
},
},
template="llava_next_qwen",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
},
"LLaVA-NeXT-Video-7B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
},
},
template="llava_next_video",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-32k-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
},
},
template="llava_next_video_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
},
"LLaVA-NeXT-Video-34B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
},
},
template="llava_next_video_yi",
vision=True,
)
register_model_group(
models={
"Marco-o1-Chat": {
DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
},
},
template="marco",
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
},
"MiniCPM-2B-DPO-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
},
},
template="cpm",
)
register_model_group(
models={
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
DownloadSource.OPENMIND: "LlamaFactory/MiniCPM3-4B",
},
},
template="cpm3",
)
register_model_group(
models={
"Mistral-7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Instruct-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2": {
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
},
"Mistral-7B-Instruct-v0.2": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
"Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
},
"Mistral-7B-Instruct-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
},
"Mistral-Nemo-Instruct-2407": {
DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
},
},
template="mistral",
)
register_model_group(
models={
"Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
},
},
template="mistral",
)
register_model_group(
models={
"OLMo-1B": {
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
},
"OLMo-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
},
"OLMo-7B-Chat": {
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
},
"OLMo-1.7-7B": {
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
},
},
)
register_model_group(
models={
"OpenChat3.5-7B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
}
},
template="openchat",
)
register_model_group(
models={
"OpenChat3.6-8B-Chat": {
DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
}
},
template="openchat-3.6",
)
register_model_group(
models={
"OpenCoder-1.5B-Base": {
DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Base",
DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Base",
},
"OpenCoder-8B-Base": {
DownloadSource.DEFAULT: "infly/OpenCoder-8B-Base",
DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Base",
},
"OpenCoder-1.5B-Instruct": {
DownloadSource.DEFAULT: "infly/OpenCoder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "infly/OpenCoder-1.5B-Instruct",
},
"OpenCoder-8B-Instruct": {
DownloadSource.DEFAULT: "infly/OpenCoder-8B-Instruct",
DownloadSource.MODELSCOPE: "infly/OpenCoder-8B-Instruct",
},
},
template="opencoder",
)
register_model_group(
models={
"Orion-14B-Base": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
},
"Orion-14B-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
},
"Orion-14B-Long-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
},
"Orion-14B-RAG-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
},
"Orion-14B-Plugin-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
},
},
template="orion",
)
register_model_group(
models={
"PaliGemma-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
},
template="paligemma",
vision=True,
)
register_model_group(
models={
"PaliGemma2-3B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma2-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-224",
},
"PaliGemma2-3B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma2-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-448",
},
"PaliGemma2-3B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma2-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-896",
},
"PaliGemma2-10B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma2-10b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-224",
},
"PaliGemma2-10B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma2-10b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-448",
},
"PaliGemma2-10B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma2-10b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-896",
},
"PaliGemma2-28B-pt-224": {
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-224",
},
"PaliGemma2-28B-pt-448": {
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-448",
},
"PaliGemma2-28B-pt-896": {
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-896",
},
},
template="paligemma",
vision=True,
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
},
}
)
register_model_group(
models={
"Phi-3-4B-4k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi-3-4B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
},
"Phi-3-14B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
},
"Phi-3-14B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
},
template="phi",
)
register_model_group(
models={
"Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
},
template="phi_small",
)
register_model_group(
models={
"Pixtral-12B-Instruct": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
},
template="pixtral",
vision=True,
)
register_model_group(
models={
"Qwen-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
},
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
},
"Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
},
"Qwen-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
},
"Qwen-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
},
"Qwen-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen1.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
},
"Qwen1.5-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
},
"Qwen1.5-4B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
},
"Qwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
},
"Qwen1.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
},
"Qwen1.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
},
"Qwen1.5-110B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
},
"Qwen1.5-MoE-A2.7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
},
"Qwen1.5-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
},
"Qwen1.5-4B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
},
"Qwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
},
"Qwen1.5-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-32B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-110B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
},
"Qwen1.5-MoE-A2.7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
},
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
},
"Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
},
"Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
},
"Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
},
"CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
},
"CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B",
},
"Qwen2-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B",
},
"Qwen2-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B",
},
"Qwen2-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B",
},
"Qwen2-MoE-57B-A14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B",
},
"Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
},
"Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
},
"Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
},
"Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct",
},
"Qwen2-MoE-57B-A14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B-Instruct",
},
"Qwen2-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2-0.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
},
"Qwen2-0.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-AWQ",
},
"Qwen2-1.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2-1.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
},
"Qwen2-1.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-AWQ",
},
"Qwen2-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
},
"Qwen2-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
},
"Qwen2-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-AWQ",
},
"Qwen2-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
},
"Qwen2-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
},
"Qwen2-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-AWQ",
},
"Qwen2-57B-A14B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
},
"Qwen2-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-1.5B",
},
"Qwen2-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-7B",
},
"Qwen2-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-72B",
},
"Qwen2-Math-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-1.5B-Instruct",
},
"Qwen2-Math-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-7B-Instruct",
},
"Qwen2-Math-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-72B-Instruct",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B",
},
"Qwen2.5-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B",
},
"Qwen2.5-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B",
},
"Qwen2.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B",
},
"Qwen2.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B",
},
"Qwen2.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B",
},
"Qwen2.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B",
},
"Qwen2.5-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct",
},
"Qwen2.5-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct",
},
"Qwen2.5-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct",
},
"Qwen2.5-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct",
},
"Qwen2.5-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct",
},
"Qwen2.5-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct",
},
"Qwen2.5-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
},
"Qwen2.5-0.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
},
"Qwen2.5-1.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2.5-1.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
},
"Qwen2.5-1.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
},
"Qwen2.5-3B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
},
"Qwen2.5-3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
},
"Qwen2.5-3B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-AWQ",
},
"Qwen2.5-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
},
"Qwen2.5-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
},
"Qwen2.5-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-AWQ",
},
"Qwen2.5-14B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
},
"Qwen2.5-14B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
},
"Qwen2.5-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-AWQ",
},
"Qwen2.5-32B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
},
"Qwen2.5-32B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
},
"Qwen2.5-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-AWQ",
},
"Qwen2.5-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
},
"Qwen2.5-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
},
"Qwen2.5-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-AWQ",
},
"Qwen2.5-Coder-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-0.5B",
},
"Qwen2.5-Coder-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B",
},
"Qwen2.5-Coder-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-3B",
},
"Qwen2.5-Coder-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B",
},
"Qwen2.5-Coder-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-14B",
},
"Qwen2.5-Coder-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-32B",
},
"Qwen2.5-Coder-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
},
"Qwen2.5-Coder-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Coder-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-3B-Instruct",
},
"Qwen2.5-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Coder-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-14B-Instruct",
},
"Qwen2.5-Coder-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-32B-Instruct",
},
"Qwen2.5-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-1.5B",
},
"Qwen2.5-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-7B",
},
"Qwen2.5-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-72B",
},
"Qwen2.5-Math-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Math-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Math-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-72B-Instruct",
},
"QwQ-32B-Preview-Instruct": {
DownloadSource.DEFAULT: "Qwen/QwQ-32B-Preview",
DownloadSource.MODELSCOPE: "Qwen/QwQ-32B-Preview",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
},
"Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
},
"Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct",
},
"Qwen2-VL-2B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-2B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-2B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
},
"Qwen2-VL-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
},
"Qwen2-VL-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
},
"QVQ-72B-Preview": {
DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview",
DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview",
},
},
template="qwen2_vl",
vision=True,
)
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group(
models={
"Skywork-13B-Base": {
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
}
}
)
register_model_group(
models={
"Skywork-o1-Open-Llama-3.1-8B": {
DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
}
},
template="skywork_o1",
)
register_model_group(
models={
"StarCoder2-3B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
},
"StarCoder2-7B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
},
"StarCoder2-15B": {
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
},
}
)
register_model_group(
models={
"TeleChat-1B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
},
"TeleChat-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
},
"TeleChat-12B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
},
"TeleChat-52B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
},
},
template="telechat",
)
register_model_group(
models={
"TeleChat2-3B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat2-3B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-3B",
},
"TeleChat2-7B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat2-7B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-7B",
},
"TeleChat2-35B-Chat": {
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-35B-Nov",
},
"TeleChat2-115B-Chat": {
DownloadSource.DEFAULT: "Tele-AI/TeleChat2-115B",
DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-115B",
},
},
template="telechat2",
)
register_model_group(
models={
"Vicuna-v1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
},
"Vicuna-v1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
},
},
template="vicuna",
)
register_model_group(
models={
"Video-LLaVA-7B-Chat": {
DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
},
},
template="video_llava",
vision=True,
)
register_model_group(
models={
"XuanYuan-6B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
},
"XuanYuan-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
"XuanYuan-6B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
},
"XuanYuan-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan2-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan2-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
},
template="xuanyuan",
)
register_model_group(
models={
"XVERSE-7B": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
},
"XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
},
"XVERSE-7B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
},
"XVERSE-13B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
},
"XVERSE-65B-Chat": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
},
"XVERSE-MoE-A4.2B": {
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
},
template="xverse",
)
register_model_group(
models={
"Yayi-7B": {
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
},
"Yayi-13B": {
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
},
},
template="yayi",
)
register_model_group(
models={
"Yi-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-6B",
},
"Yi-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-9B",
},
"Yi-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-34B",
},
"Yi-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
},
"Yi-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
},
"Yi-6B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
},
"Yi-6B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
},
"Yi-34B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
},
"Yi-34B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
},
"Yi-1.5-6B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
},
"Yi-1.5-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
},
"Yi-1.5-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
},
"Yi-1.5-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
DownloadSource.OPENMIND: "LlamaFactory/Yi-1.5-6B-Chat",
},
"Yi-1.5-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
},
"Yi-1.5-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
},
"Yi-Coder-1.5B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B",
},
"Yi-Coder-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B",
},
"Yi-Coder-1.5B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B-Chat",
},
"Yi-Coder-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B-Chat",
},
},
template="yi",
)
register_model_group(
models={
"Yi-VL-6B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
},
"Yi-VL-34B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
},
},
template="yi_vl",
vision=True,
)
register_model_group(
models={
"Yuan2-2B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
},
"Yuan2-51B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
},
"Yuan2-102B-Chat": {
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
},
},
template="yuan",
)
register_model_group(
models={
"Zephyr-7B-Alpha-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
},
"Zephyr-7B-Beta-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
},
"Zephyr-141B-ORPO-Chat": {
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
},
},
template="zephyr",
)
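# Example (editor's sketch, not part of the library): assuming register_model_group
# fills a global SUPPORTED_MODELS mapping of model name -> {DownloadSource: path}
# (that mapping is not shown here), a hypothetical lookup that prefers a given hub
# and falls back to the default source could look like this:
#
#     def resolve_model_path(name: str, source: "DownloadSource") -> str:
#         paths = SUPPORTED_MODELS[name]
#         return paths.get(source, paths[DownloadSource.DEFAULT])
#
#     resolve_model_path("Qwen2.5-7B-Instruct", DownloadSource.MODELSCOPE)
#     # -> "Qwen/Qwen2.5-7B-Instruct"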
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import accelerate
import datasets
import peft
import torch
import transformers
import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.2.dev0"
def print_env() -> None:
info = {
"`llamafactory` version": VERSION,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"PyTorch version": torch.__version__,
"Transformers version": transformers.__version__,
"Datasets version": datasets.__version__,
"Accelerate version": accelerate.__version__,
"PEFT version": peft.__version__,
"TRL version": trl.__version__,
}
if is_torch_cuda_available():
info["PyTorch version"] += " (GPU)"
info["GPU type"] = torch.cuda.get_device_name()
if is_torch_npu_available():
info["PyTorch version"] += " (NPU)"
info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann
try:
import deepspeed # type: ignore
info["DeepSpeed version"] = deepspeed.__version__
except Exception:
pass
try:
import bitsandbytes
info["Bitsandbytes version"] = bitsandbytes.__version__
except Exception:
pass
try:
import vllm
info["vLLM version"] = vllm.__version__
except Exception:
pass
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Optional
from .constants import RUNNING_LOG
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler):
r"""
Redirects the logging output to the logging file for LLaMA Board.
"""
def __init__(self, output_dir: str) -> None:
super().__init__()
self._formatter = logging.Formatter(
fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
self.setLevel(logging.INFO)
os.makedirs(output_dir, exist_ok=True)
self.running_log = os.path.join(output_dir, RUNNING_LOG)
if os.path.exists(self.running_log):
os.remove(self.running_log)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _write_log(self, log_entry: str) -> None:
with open(self.running_log, "a", encoding="utf-8") as f:
f.write(log_entry + "\n\n")
def emit(self, record) -> None:
if record.name == "httpx":
return
log_entry = self._formatter.format(record)
self.thread_pool.submit(self._write_log, log_entry)
def close(self) -> None:
self.thread_pool.shutdown(wait=True)
return super().close()
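# Example (sketch): attaching the handler mirrors library log records into a file
# under `output_dir` (the file name comes from the RUNNING_LOG constant):
#
#     handler = LoggerHandler(output_dir="outputs")
#     add_handler(handler)  # add_handler is defined later in this module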
class _Logger(logging.Logger):
r"""
A logger that supports info_rank0 and warning_once.
"""
def info_rank0(self, *args, **kwargs) -> None:
self.info(*args, **kwargs)
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def warning_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)
def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
"""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError(f"Unknown logging level: {env_level_str}.")
return _default_log_level
def _get_library_name() -> str:
return __name__.split(".")[0]
def _get_library_root_logger() -> "_Logger":
return logging.getLogger(_get_library_name())
def _configure_library_root_logger() -> None:
r"""
Configures root logger using a stdout stream handler with an explicit format.
"""
global _default_handler
with _thread_lock:
if _default_handler: # already configured
return
formatter = logging.Formatter(
fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger":
r"""
    Returns a logger with the specified name. It is not supposed to be accessed externally.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
def add_handler(handler: "logging.Handler") -> None:
r"""
Adds a handler to the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
r"""
    Removes a handler from the root logger.
"""
_configure_library_root_logger()
_get_library_root_logger().removeHandler(handler)
def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.info(*args, **kwargs)
def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
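# Example (sketch): typical module-level usage of the patched logger. In a
# multi-process run, only the process with LOCAL_RANK == 0 emits rank-zero
# messages, and warning_once prints a given message at most once per process:
#
#     logger = get_logger(__name__)
#     logger.info_rank0("Loading dataset...")
#     logger.warning_once("This message is printed at most once.")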
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union
import torch
import torch.distributed as dist
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
from transformers.utils import (
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.utils.versions import require_version
from . import logging
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())
except Exception:
_is_bf16_available = False
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..hparams import ModelArguments
logger = logging.get_logger(__name__)
class AverageMeter:
r"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
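# Example (sketch): tracking a running mean of the training loss; the values are
# made up for illustration:
#
#     meter = AverageMeter()
#     for loss in (0.9, 0.7, 0.5):
#         meter.update(loss, n=1)
#     meter.val  # 0.5 (last value)
#     meter.avg  # 0.7 (running mean)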
def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
r"""
Calculates effective tokens per second.
"""
effective_token_num = 0
for data in dataset:
if stage == "sft":
effective_token_num += len(data["input_ids"])
elif stage == "rm":
effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
return result / dist.get_world_size() if dist.is_initialized() else result
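# Worked example (sketch): two SFT samples of 100 and 150 tokens give
# effective_token_num = 250; with metrics = {"epoch": 2.0, "train_runtime": 50.0}
# this yields 250 * 2.0 / 50.0 = 10.0 effective tokens per second per process.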
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
if param.__class__.__name__ == "Params4bit":
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
elif hasattr(param, "element_size"): # for older pytorch version
num_bytes = param.element_size()
else:
num_bytes = 1
num_params = num_params * 2 * num_bytes
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return trainable_params, all_param
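# Example (sketch): reporting the trainable parameter ratio of a model:
#
#     trainable, total = count_parameters(model)
#     print(f"trainable params: {trainable:,} || all params: {total:,} "
#           f"|| trainable%: {100 * trainable / total:.4f}")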
def get_current_device() -> "torch.device":
r"""
Gets the current available device.
"""
if is_torch_xpu_available():
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
device = "cpu"
return torch.device(device)
def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
else:
return 0
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
return 0, 0
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
"""
return os.path.isdir(path) and len(os.listdir(path)) > 0
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32
def is_gpu_or_npu_available() -> bool:
r"""
Checks if the GPU or NPU is available.
"""
return is_torch_npu_available() or is_torch_cuda_available()
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
inputs = inputs.to(torch.float32)
inputs = inputs.numpy()
return inputs
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
def torch_gc() -> None:
r"""
Collects GPU or NPU memory.
"""
gc.collect()
if is_torch_xpu_available():
torch.xpu.empty_cache()
elif is_torch_npu_available():
torch.npu.empty_cache()
elif is_torch_mps_available():
torch.mps.empty_cache()
elif is_torch_cuda_available():
torch.cuda.empty_cache()
def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
return model_args.model_name_or_path
if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
return snapshot_download(
model_args.model_name_or_path,
revision=revision,
cache_dir=model_args.cache_dir,
)
if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind.utils.hub import snapshot_download # type: ignore
return snapshot_download(
model_args.model_name_or_path,
revision=model_args.model_revision,
cache_dir=model_args.cache_dir,
)
def use_modelscope() -> bool:
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import importlib.util
from functools import lru_cache
from typing import TYPE_CHECKING
from packaging import version
if TYPE_CHECKING:
from packaging.version import Version
def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def _get_package_version(name: str) -> "Version":
try:
return version.parse(importlib.metadata.version(name))
except Exception:
return version.parse("0.0.0")
def is_pyav_available():
return _is_package_available("av")
def is_fastapi_available():
return _is_package_available("fastapi")
def is_galore_available():
return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_matplotlib_available():
return _is_package_available("matplotlib")
def is_pillow_available():
return _is_package_available("PIL")
def is_requests_available():
return _is_package_available("requests")
def is_rouge_available():
return _is_package_available("rouge_chinese")
def is_starlette_available():
return _is_package_available("sse_starlette")
@lru_cache
def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
@lru_cache
def is_transformers_version_equal_to_4_46():
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
def is_uvicorn_available():
return _is_package_available("uvicorn")
def is_vllm_available():
return _is_package_available("vllm")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME
from . import logging
from .packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
if len(scalars) == 0:
return []
last = scalars[0]
smoothed = []
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
last = smoothed_val
return smoothed
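# Note (sketch): the smoothing weight grows with the series length, so short runs
# are left almost raw while long runs approach an EMA weight of about 0.9:
#
#     smooth([1.0, 0.0, 1.0])          # weight ~= 0.07, close to the raw values
#     smooth([1.0] * 50 + [0.0] * 50)  # weight ~= 0.89, heavily smoothed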
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
for log in trainer_log:
if log.get("loss", None):
steps.append(log["current_steps"])
losses.append(log["loss"])
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig
def plot_loss(save_directory: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
plt.switch_backend("agg")
    with open(os.path.join(save_directory, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
for key in keys:
steps, metrics = [], []
for i in range(len(data["log_history"])):
if key in data["log_history"][i]:
steps.append(data["log_history"][i]["step"])
metrics.append(data["log_history"][i][key])
if len(metrics) == 0:
logger.warning_rank0(f"No metric {key} to plot.")
continue
plt.figure()
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
plt.title(f"training {key} of {save_dictionary}")
plt.xlabel("step")
plt.ylabel(key)
plt.legend()
        figure_path = os.path.join(save_directory, "training_{}.png".format(key.replace("/", "_")))
plt.savefig(figure_path, format="png", dpi=100)
print("Figure saved at:", figure_path)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args
__all__ = [
"DataArguments",
"EvaluationArguments",
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"get_eval_args",
"get_infer_args",
"get_train_args",
]
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
@dataclass
class DataArguments:
r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: Optional[str] = field(
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."},
)
dataset: Optional[str] = field(
default=None,
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
)
eval_dataset: Optional[str] = field(
default=None,
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
)
dataset_dir: str = field(
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
image_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
)
train_on_prompt: bool = field(
default=False,
metadata={"help": "Whether or not to disable the mask on the prompt."},
)
mask_history: bool = field(
default=False,
metadata={"help": "Whether or not to mask the history and train on the last turn only."},
)
streaming: bool = field(
default=False,
metadata={"help": "Enable dataset streaming."},
)
buffer_size: int = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
)
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
)
interleave_probs: Optional[str] = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."},
)
preprocessing_batch_size: int = field(
default=1000,
metadata={"help": "The number of examples in one group in pre-processing."},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
)
max_samples: Optional[int] = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
)
eval_num_beams: Optional[int] = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
)
val_size: float = field(
default=0.0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
)
packing: Optional[bool] = field(
default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
)
neat_packing: bool = field(
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
tool_format: Optional[str] = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
tokenized_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to save or load the tokenized datasets. "
"If tokenized_path not exists, it will save the tokenized datasets. "
"If tokenized_path exists, it will load the tokenized datasets."
)
},
)
def __post_init__(self):
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset)
if self.image_dir is None:
self.image_dir = self.dataset_dir
if self.dataset is None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
if self.eval_dataset is not None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
if self.interleave_probs is not None:
if self.mix_strategy == "concat":
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
raise ValueError("The length of dataset and interleave probs should be identical.")
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
raise ValueError("The length of eval dataset and interleave probs should be identical.")
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
raise ValueError("Streaming mode should have an integer val size.")
if self.streaming and self.max_samples is not None:
raise ValueError("`max_samples` is incompatible with `streaming`.")
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
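# Example (sketch): comma-separated names are split into lists in __post_init__;
# the dataset names below are illustrative:
#
#     args = DataArguments(template="qwen", dataset="alpaca_en, identity")
#     args.dataset  # ["alpaca_en", "identity"]
#     args.to_dict()["cutoff_len"]  # 2048 (the default)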
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
from datasets import DownloadMode
@dataclass
class EvaluationArguments:
r"""
    Arguments pertaining to the evaluation parameters.
"""
task: str = field(
metadata={"help": "Name of the evaluation task."},
)
task_dir: str = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."},
)
batch_size: int = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."},
)
seed: int = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."},
)
lang: Literal["en", "zh"] = field(
default="en",
metadata={"help": "Language used at evaluation."},
)
n_shot: int = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."},
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."},
)
download_mode: DownloadMode = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."},
)
def __post_init__(self):
if self.save_dir is not None and os.path.exists(self.save_dir):
raise ValueError("`save_dir` already exists, use another one.")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
@dataclass
class FreezeArguments:
r"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
freeze_trainable_layers: int = field(
default=2,
metadata={
"help": (
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
"Positive numbers mean the last n layers are set as trainable, "
"negative numbers mean the first n layers are set as trainable."
)
},
)
freeze_trainable_modules: str = field(
default="all",
metadata={
"help": (
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the available modules."
)
},
)
freeze_extra_modules: Optional[str] = field(
default=None,
metadata={
"help": (
"Name(s) of modules apart from hidden layers to be set as trainable "
"for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules."
)
},
)
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
additional_target: Optional[str] = field(
default=None,
metadata={
"help": (
"Name(s) of modules apart from LoRA layers to be set as trainable "
"and saved in the final checkpoint. "
"Use commas to separate multiple modules."
)
},
)
lora_alpha: Optional[int] = field(
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
)
lora_dropout: float = field(
default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
)
lora_rank: int = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
)
lora_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of target modules to apply LoRA. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
)
loraplus_lr_embedding: float = field(
default=1e-6,
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
)
use_rslora: bool = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
)
use_dora: bool = field(
default=False,
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
)
pissa_init: bool = field(
default=False,
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
)
pissa_iter: int = field(
default=16,
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
)
pissa_convert: bool = field(
default=False,
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
)
create_new_adapter: bool = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
)
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
pref_beta: float = field(
default=0.1,
metadata={"help": "The beta parameter in the preference loss."},
)
pref_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
simpo_gamma: float = field(
default=0.5,
metadata={"help": "The target reward margin term in SimPO loss."},
)
ppo_buffer_size: int = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
)
ppo_epochs: int = field(
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
)
ppo_score_norm: bool = field(
default=False,
metadata={"help": "Use score normalization in PPO training."},
)
ppo_target: float = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
)
ppo_whiten_rewards: bool = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
)
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
)
ref_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reference model."},
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."},
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reward model used for the PPO training."},
)
reward_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reward model."},
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."},
)
reward_model_type: Literal["lora", "full", "api"] = field(
default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore algorithm.
"""
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
galore_rank: int = field(
default=16,
metadata={"help": "The rank of GaLore gradients."},
)
galore_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the GaLore projection."},
)
galore_scale: float = field(
default=0.25,
metadata={"help": "GaLore scaling coefficient."},
)
galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
default="std",
metadata={"help": "Type of GaLore projection."},
)
galore_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
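# Usage note (an addition, not original code): a typical GaLore run keeps the
# defaults above, e.g. use_galore=True with galore_target="all" projects the
# gradients of every linear module to rank 16 and refreshes the projection
# every 200 steps; __post_init__ of FinetuningArguments below rejects
# combining GaLore with LoRA or with BAdam.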
@dataclass
class BAdamArgument:
r"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_switch_interval: Optional[int] = field(
default=50,
metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
},
)
badam_update_ratio: float = field(
default=0.05,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": (
"The mode of the mask for BAdam optimizer. "
"`adjacent` means that the trainable parameters are adjacent to each other, "
"`scatter` means that trainable parameters are randomly choosed from the weight."
)
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": (
"The verbosity level of BAdam optimizer. "
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
)
},
)
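# Usage note (an addition, not original code): the two BAdam modes are
# alternatives, not combinable. Layer-wise mode trains one block at a time,
# switching blocks every `badam_switch_interval` steps according to
# `badam_switch_mode`; ratio-wise mode instead trains a fixed fraction
# (`badam_update_ratio`) of each weight, masked per `badam_mask_mode`.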
@dataclass
class SwanLabArguments:
use_swanlab: bool = field(
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: str = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: str = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_mode: Literal["cloud", "local"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
pure_bf16: bool = field(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
finetuning_type: Literal["lora", "freeze", "full"] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
use_llama_pro: bool = field(
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
use_adam_mini: bool = field(
default=False,
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
)
compute_accuracy: bool = field(
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
disable_shuffling: bool = field(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
)
include_effective_tokens_per_second: bool = field(
default=False,
metadata={"help": "Whether or not to compute effective tokens per second."},
)
def __post_init__(self):
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
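# e.g. if lora_rank is 8 and lora_alpha was left unset, lora_alpha becomes 16 (twice the rank)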
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")
if self.finetuning_type != "lora":
if self.loraplus_lr_ratio is not None:
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
if self.use_rslora:
raise ValueError("`use_rslora` is only valid for LoRA training.")
if self.use_dora:
raise ValueError("`use_dora` is only valid for LoRA training.")
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
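# Illustrative behavior (an added note, not original code): to_dict masks any
# field whose name ends with "api_key", so a configured swanlab_api_key is
# serialized as "<SWANLAB_API_KEY>" while every other value passes through
# unchanged.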
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
r"""
Arguments pertaining to the decoding parameters.
"""
do_sample: bool = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
)
temperature: float = field(
default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."},
)
top_p: float = field(
default=0.7,
metadata={
"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
},
)
top_k: int = field(
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
)
num_beams: int = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
)
max_length: int = field(
default=1024,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
)
max_new_tokens: int = field(
default=1024,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
)
repetition_penalty: float = field(
default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
)
length_penalty: float = field(
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
if obey_generation_config:
generation_config = GenerationConfig()
for key in list(args.keys()):
if not hasattr(generation_config, key):
args.pop(key)
return args
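# Worked example (an added note, not original code): with the defaults above,
# max_new_tokens=1024 > 0, so to_dict() drops max_length and keeps
# max_new_tokens. With obey_generation_config=True, keys that a default
# transformers GenerationConfig does not define (here, default_system and
# skip_special_tokens) are removed as well.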
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
@dataclass
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
image_resolution: int = field(
default=512 * 512,
metadata={"help": "Keeps the number of pixels of image below this resolution."},
)
video_resolution: int = field(
default=128 * 128,
metadata={"help": "Keeps the number of pixels of video below this resolution."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},
)
video_maxlen: int = field(
default=64,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
vllm_maxlen: int = field(
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
vllm_config: Optional[Union[dict, str]] = field(
default=None,
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
)
@dataclass
class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to the adapter weight or identifier from huggingface.co/models. "
"Use commas to separate multiple adapters."
)
},
)
adapter_folder: Optional[str] = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
)
split_special_tokens: bool = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
low_cpu_mem_usage: bool = field(
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
use_reentrant_gc: bool = field(
default=True,
metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface",
metadata={"help": "Backend engine used at inference."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
)
use_cache: bool = field(
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
)
block_diag_attn: bool = field(
default=False,
init=False,
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
)
def __post_init__(self):
if self.model_name_or_path is None:
raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
@classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
init_args, lazy_args = {}, {}
for attr in fields(source):
if attr.init:
init_args[attr.name] = getattr(source, attr.name)
else:
lazy_args[attr.name] = getattr(source, attr.name)
init_args.update(kwargs)
result = cls(**init_args)
for name, value in lazy_args.items():
setattr(result, name, value)
return result
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
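# Hypothetical usage of copyfrom (an added sketch; the variable names are
# placeholders): derive a second configuration that reuses all init fields of
# an existing one, overrides a few, and carries over the non-init fields
# (compute_dtype, device_map, model_max_length, block_diag_attn):
#   ref_args = ModelArguments.copyfrom(model_args, adapter_name_or_path=None)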
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import Any, Dict, Optional, Tuple
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if unknown_args:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return (*parsed_args,)
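# Summary of the resolution order above (an added note): an explicit dict wins,
# then a single .yaml/.yml or .json path given as the sole CLI argument, then
# ordinary CLI flags parsed into the dataclasses; unknown flags abort with the
# help text rather than being silently ignored.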
def _set_transformers_logging() -> None:
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
def _verify_model_args(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
) -> None:
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if finetuning_args.pissa_init:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft":
if training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.")
if data_args.train_on_prompt or data_args.mask_history:
raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo":
if not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and data_args.dataset is None:
raise ValueError("Please specify dataset for training.")
if (training_args.do_eval or training_args.do_predict) and (
data_args.eval_dataset is None and data_args.val_size < 1e-6
):
raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled():
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
elif not is_deepspeed_zero3_enabled():
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing:
logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning_rank0(
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
)
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
if (
training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None:
logger.warning_rank0("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and any(
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning_rank0(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
)
# Post-process model arguments
if training_args.bf16 or finetuning_args.pure_bf16:
model_args.compute_dtype = torch.bfloat16
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
model_args.block_diag_attn = data_args.neat_packing
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
# Log on each process the small summary
logger.info(
"Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
training_args.device,
training_args.n_gpu,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
model_args.model_max_length = data_args.cutoff_len
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args
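# Hedged usage sketch (an added note, not original code; values are placeholders):
#   model_args, data_args, eval_args, finetuning_args = get_eval_args({
#       "model_name_or_path": "path/to/model",
#       "template": "default",
#   })
# A missing `template` raises a ValueError, and the vLLM backend is rejected
# for evaluation, as enforced above.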
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp # use absolute import
def launch():
run_exp()
if __name__ == "__main__":
launch()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params
__all__ = [
"QuantizationMethod",
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
logger = logging.get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return
logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
def _setup_freeze_tuning(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return
logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
config = model.config
num_layers = (
getattr(config, "num_hidden_layers", None)
or getattr(config, "num_layers", None)
or getattr(config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
)
stride = num_layers // finetuning_args.freeze_trainable_layers
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.freeze_trainable_layers > 0: # fine-tune the last n layers if freeze_trainable_layers > 0
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
else: # fine-tune the first n layers if freeze_trainable_layers < 0
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
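# Worked example (an added note): with num_layers=32 and
# freeze_trainable_layers=8, use_llama_pro gives stride=4 and trainable ids
# 3, 7, 11, ..., 31; without it, 8 trains the last eight layers (ids 24..31)
# and -8 trains the first eight (ids 0..7).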
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
trainable_layers = []
for module_name in finetuning_args.freeze_trainable_modules:
if module_name != "all" and module_name not in hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
if finetuning_args.freeze_extra_modules:
for module_name in finetuning_args.freeze_extra_modules:
if module_name not in non_hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
)
trainable_layers.append(module_name)
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
if is_trainable:
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
else:
adapter_to_merge = model_args.adapter_name_or_path
init_kwargs = {
"subfolder": model_args.adapter_folder,
"offload_folder": model_args.offload_folder,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if is_trainable and cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
return model
def init_adapter(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Supports full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
# cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
cast_trainable_params_to_fp32 = False
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
else:
raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model
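# Hedged usage sketch (an added note, not original code): init_adapter is
# meant to be called right after the base model is built, e.g.
#   model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
# With finetuning_type="lora" and no adapter to resume, it attaches a fresh
# LoraConfig; with is_trainable=False it only loads (and, where possible,
# merges) the given adapters.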