Commit 317a82e2 authored by chenych

Add QWQ-32B

parent 37b0ad9f
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
class FunctionCall(NamedTuple):
name: str
......@@ -56,7 +54,7 @@ QWEN_TOOL_PROMPT = (
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
""""arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
""""arguments": <args-json-object>}}\n</tool_call>"""
)
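The trailing `<|im_end|>` is dropped here, presumably so the chat template's own stop-token slot closes the turn instead of the prompt hard-coding it. A quick sketch of how the updated constant renders (schema value illustrative):

```python
from llamafactory.data.tool_utils import QWEN_TOOL_PROMPT  # import path assumed

# Illustrative: substitute one tool schema into the updated QWEN_TOOL_PROMPT.
tool_schema = '\n{"name": "get_weather", "parameters": {"type": "object"}}'
print(QWEN_TOOL_PROMPT.format(tool_text=tool_schema))
# The rendered text now ends with "</tool_call>" rather than "</tool_call><|im_end|>\n".
```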
......@@ -76,7 +74,7 @@ class ToolUtils(ABC):
@staticmethod
@abstractmethod
-def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
......@@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
-def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+def function_formatter(functions: List["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
-return [function_text]
+return function_text
@override
@staticmethod
......@@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
-def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
return [f"{functions[0].name}\n{functions[0].arguments}"]
return f"{functions[0].name}\n{functions[0].arguments}"
@override
@staticmethod
......@@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
-def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
@override
@staticmethod
......@@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
-def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return ["[" + ", ".join(function_texts) + "]"]
return "[" + ", ".join(function_texts) + "]"
@override
@staticmethod
......@@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
-def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return ["\n".join(function_texts)]
return "\n".join(function_texts)
@override
@staticmethod
......
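Every `function_formatter` above now returns the joined text directly instead of a one-element `SLOTS` list. A minimal sketch of the new contract (import path assumed from this file's package):

```python
from llamafactory.data.tool_utils import FunctionCall, QwenToolUtils

calls = [FunctionCall(name="get_weather", arguments='{"city": "Beijing"}')]
text = QwenToolUtils.function_formatter(calls)
assert isinstance(text, str)  # previously a SLOTS list like [text]
print(text)
# <tool_call>
# {"name": "get_weather", "arguments": {"city": "Beijing"}}
# </tool_call>
```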
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from datasets import concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
logger = get_logger(__name__)
@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
OBSERVATION = "observation"
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
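A worked example of the budget split: the target keeps its proportional share of `max_len` (floored at `reserved_label_len`), and the source receives whatever the possibly-truncated target leaves over.

```python
# max_len=1024 over a 300-token source and a 700-token target:
# max_target_len = int(1024 * 700 / 1000) = 716
# max_source_len = 1024 - min(716, 700) = 324
assert infer_max_len(300, 700, max_len=1024, reserved_label_len=1) == (324, 716)
```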
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,
probabilities=data_args.interleave_probs,
seed=training_args.seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError("Unknown mixing strategy.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# This code is inspired by Dan Hendrycks' test library.
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
fix_valuehead_checkpoint(
model=kwargs.pop("model"),
output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
safe_serialization=args.save_safetensors,
)
class LogCallback(TrainerCallback):
def __init__(self, output_dir: str) -> None:
r"""
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
self.start_time = 0
self.cur_steps = 0
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
""" Status """
self.aborted = False
self.do_train = False
""" Web UI """
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir)
logging.root.addHandler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
self.aborted = True
def _reset(self, max_steps: int = 0) -> None:
self.start_time = time.time()
self.cur_steps = 0
self.max_steps = max_steps
self.elapsed_time = ""
self.remaining_time = ""
def _timing(self, cur_steps: int) -> None:
cur_time = time.time()
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
self.cur_steps = cur_steps
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
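The ETA fields are plain proportional arithmetic; for instance:

```python
from datetime import timedelta

# 100 of 1000 steps done after 50 s -> 0.5 s/step on average;
# 900 steps remain, so remaining_time = int(900 * 0.5) = 450 s.
print(str(timedelta(seconds=450)))  # 0:07:30
```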
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
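Each submitted dict is appended to the trainer log as one JSON object per line; a representative record (field values illustrative) looks like:

```
{"current_steps": 100, "total_steps": 1000, "loss": 1.2345, "learning_rate": 5e-05, "epoch": 0.16, "percentage": 10.0, "elapsed_time": "0:00:50", "remaining_time": "0:07:30", "throughput": "1966.08", "total_tokens": 98304}
```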
def _create_thread_pool(self, output_dir: str) -> None:
os.makedirs(output_dir, exist_ok=True)
self.thread_pool = ThreadPoolExecutor(max_workers=1)
def _close_thread_pool(self) -> None:
if self.thread_pool is not None:
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
self.do_train = True
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a substep during gradient accumulation.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
"""
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
"""
if not self.do_train:
self._close_thread_pool()
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
"""
if not self.do_train:
self._close_thread_pool()
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
"""
if not args.should_save:
return
self._timing(cur_steps=state.global_step)
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
total_tokens=state.num_input_tokens_seen,
)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
)
)
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):
r"""
Event called after a prediction step.
"""
if self.do_train:
return
if self.aborted:
sys.exit(0)
if not args.should_save:
return
eval_dataloader = kwargs.pop("eval_dataloader", None)
if has_length(eval_dataloader):
if self.max_steps == 0:
self._reset(max_steps=len(eval_dataloader))
self._create_thread_pool(output_dir=args.output_dir)
self._timing(cur_steps=self.cur_steps + 1)
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,6 +22,8 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
......@@ -58,6 +60,8 @@ METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
+MULTIMODAL_SUPPORTED_MODELS = set()
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
......@@ -83,14 +87,14 @@ STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
SWANLAB_CONFIG = "swanlab_public_config.json"
VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
-VISION_MODELS = set()
class DownloadSource(str, Enum):
DEFAULT = "hf"
......@@ -101,14 +105,16 @@ class DownloadSource(str, Enum):
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
template: Optional[str] = None,
-vision: bool = False,
+multimodal: bool = False,
) -> None:
for name, path in models.items():
SUPPORTED_MODELS[name] = path
-if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision):
+if template is not None and (
+any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
+):
DEFAULT_TEMPLATE[name] = template
-if vision:
-VISION_MODELS.add(name)
+if multimodal:
+MULTIMODAL_SUPPORTED_MODELS.add(name)
register_model_group(
......@@ -485,14 +491,46 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
},
"DeepSeek-V3-685B-Base": {
"DeepSeek-V3-671B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base",
},
"DeepSeek-V3-685B-Chat": {
"DeepSeek-V3-671B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
},
"DeepSeek-R1-1.5B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
},
"DeepSeek-R1-7B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
},
"DeepSeek-R1-8B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
},
"DeepSeek-R1-14B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
},
"DeepSeek-R1-32B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
},
"DeepSeek-R1-70B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
},
"DeepSeek-R1-671B-Chat-Zero": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Zero",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Zero",
},
"DeepSeek-R1-671B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
},
},
template="deepseek3",
)
......@@ -813,20 +851,15 @@ register_model_group(
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
},
},
template="intern2",
)
register_model_group(
models={
"InternLM3-8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct",
},
},
template="intern3",
template="intern2",
)
register_model_group(
models={
"Jamba-v0.1": {
......@@ -1003,7 +1036,7 @@ register_model_group(
},
},
template="mllama",
-vision=True,
+multimodal=True,
)
......@@ -1019,7 +1052,7 @@ register_model_group(
},
},
template="llava",
-vision=True,
+multimodal=True,
)
......@@ -1035,7 +1068,7 @@ register_model_group(
},
},
template="llava_next",
-vision=True,
+multimodal=True,
)
......@@ -1047,7 +1080,7 @@ register_model_group(
},
},
template="llava_next_mistral",
-vision=True,
+multimodal=True,
)
......@@ -1059,7 +1092,7 @@ register_model_group(
},
},
template="llava_next_llama3",
-vision=True,
+multimodal=True,
)
......@@ -1071,7 +1104,7 @@ register_model_group(
},
},
template="llava_next_yi",
-vision=True,
+multimodal=True,
)
......@@ -1087,7 +1120,7 @@ register_model_group(
},
},
template="llava_next_qwen",
-vision=True,
+multimodal=True,
)
......@@ -1103,7 +1136,7 @@ register_model_group(
},
},
template="llava_next_video",
-vision=True,
+multimodal=True,
)
......@@ -1115,7 +1148,7 @@ register_model_group(
},
},
template="llava_next_video_mistral",
-vision=True,
+multimodal=True,
)
......@@ -1130,7 +1163,7 @@ register_model_group(
},
},
template="llava_next_video_yi",
-vision=True,
+multimodal=True,
)
......@@ -1174,23 +1207,44 @@ register_model_group(
register_model_group(
models={
"MiniCPM-o-2_6-Chat": {
"MiniCPM-o-2_6": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
},
},
template="minicpm_v",
template="minicpm_o",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-V-2_6-Chat": {
"MiniCPM-V-2_6": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
},
},
template="minicpm_v",
multimodal=True,
)
register_model_group(
models={
"Ministral-8B-Instruct-2410": {
DownloadSource.DEFAULT: "mistralai/Ministral-8B-Instruct-2410",
DownloadSource.MODELSCOPE: "mistralai/Ministral-8B-Instruct-2410",
},
"Mistral-Nemo-Base-2407": {
DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Base-2407",
DownloadSource.MODELSCOPE: "LLM-Research/Mistral-Nemo-Base-2407",
},
"Mistral-Nemo-Instruct-2407": {
DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
},
},
template="ministral",
)
......@@ -1200,48 +1254,60 @@ register_model_group(
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
},
"Mistral-7B-Instruct-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-v0.2": {
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
},
"Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
DownloadSource.MODELSCOPE: "LLM-Research/mistral-7b-v0.3",
},
"Mistral-7B-Instruct-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
},
"Mistral-7B-Instruct-v0.2": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
},
"Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
},
"Mistral-7B-Instruct-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
},
"Mistral-Nemo-Instruct-2407": {
DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
},
},
template="mistral",
)
register_model_group(
models={
"Mistral-Small-24B-Base-2501": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2501",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2501",
},
"Mistral-Small-24B-Instruct-2501": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2501",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2501",
},
},
template="mistral_small",
)
register_model_group(
models={
"Mixtral-8x7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x7B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
"Mixtral-8x22B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
......@@ -1251,6 +1317,21 @@ register_model_group(
)
register_model_group(
models={
"Moonlight-16B-A3B": {
DownloadSource.DEFAULT: "moonshotai/Moonlight-16B-A3B",
DownloadSource.MODELSCOPE: "moonshotai/Moonlight-16B-A3B",
},
"Moonlight-16B-A3B-Instruct": {
DownloadSource.DEFAULT: "moonshotai/Moonlight-16B-A3B-Instruct",
DownloadSource.MODELSCOPE: "moonshotai/Moonlight-16B-A3B-Instruct",
},
},
template="moonlight",
)
register_model_group(
models={
"OLMo-1B": {
......@@ -1364,7 +1445,7 @@ register_model_group(
},
},
template="paligemma",
-vision=True,
+multimodal=True,
)
......@@ -1406,9 +1487,33 @@ register_model_group(
DownloadSource.DEFAULT: "google/paligemma2-28b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-896",
},
"PaliGemma2-3B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma2-3b-mix-224",
DownloadSource.MODELSCOPE: "mlx-community/paligemma2-3b-mix-224-bf16",
},
"PaliGemma2-3B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma2-3b-mix-448",
DownloadSource.MODELSCOPE: "mlx-community/paligemma2-3b-mix-448-bf16",
},
"PaliGemma2-10B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma2-10b-mix-224",
DownloadSource.MODELSCOPE: "mlx-community/paligemma2-10b-mix-224-bf16",
},
"PaliGemma2-10B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma2-10b-mix-448",
DownloadSource.MODELSCOPE: "mlx-community/paligemma2-10b-mix-448-bf16",
},
"PaliGemma2-28B-mix-224": {
DownloadSource.DEFAULT: "google/paligemma2-28b-mix-224",
DownloadSource.MODELSCOPE: "mlx-community/paligemma2-28b-mix-224-bf16",
},
"PaliGemma2-28B-mix-448": {
DownloadSource.DEFAULT: "google/paligemma2-28b-mix-448",
DownloadSource.MODELSCOPE: "mlx-community/paligemma2-28b-mix-448-bf16",
},
},
template="paligemma",
-vision=True,
+multimodal=True,
)
......@@ -1485,13 +1590,13 @@ register_model_group(
register_model_group(
models={
"Pixtral-12B-Instruct": {
"Pixtral-12B": {
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
}
},
template="pixtral",
-vision=True,
+multimodal=True,
)
......@@ -1901,6 +2006,14 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct",
},
"Qwen2.5-7B-Instruct-1M": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-1M",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-1M",
},
"Qwen2.5-14B-Instruct-1M": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-1M",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-1M",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
......@@ -2061,6 +2174,10 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/QwQ-32B-Preview",
DownloadSource.MODELSCOPE: "Qwen/QwQ-32B-Preview",
},
"QwQ-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/QwQ-32B",
DownloadSource.MODELSCOPE: "Qwen/QwQ-32B",
},
},
template="qwen",
)
......@@ -2068,6 +2185,34 @@ register_model_group(
register_model_group(
models={
"Qwen2-Audio-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B",
},
"Qwen2-Audio-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B-Instruct",
},
},
template="qwen2_audio",
multimodal=True,
)
register_model_group(
models={
"Qwen2-VL-2B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B",
},
"Qwen2-VL-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B",
},
"Qwen2-VL-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B",
},
"Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
......@@ -2122,9 +2267,33 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview",
DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview",
},
"Qwen2.5-VL-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-3B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-3B-Instruct",
},
"Qwen2.5-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct",
},
"Qwen2.5-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct",
},
"Qwen2.5-VL-3B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-3B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-3B-Instruct-AWQ",
},
"Qwen2.5-VL-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct-AWQ",
},
"Qwen2.5-VL-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
},
},
template="qwen2_vl",
-vision=True,
+multimodal=True,
)
......@@ -2249,7 +2418,7 @@ register_model_group(
},
},
template="video_llava",
-vision=True,
+multimodal=True,
)
......@@ -2476,7 +2645,7 @@ register_model_group(
},
},
template="yi_vl",
-vision=True,
+multimodal=True,
)
......
......@@ -45,6 +45,8 @@ def print_env() -> None:
if is_torch_cuda_available():
info["PyTorch version"] += " (GPU)"
info["GPU type"] = torch.cuda.get_device_name()
info["GPU number"] = torch.cuda.device_count()
info["GPU memory"] = f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB"
if is_torch_npu_available():
info["PyTorch version"] += " (NPU)"
......@@ -59,7 +61,7 @@ def print_env() -> None:
pass
try:
-import bitsandbytes
+import bitsandbytes  # type: ignore
info["Bitsandbytes version"] = bitsandbytes.__version__
except Exception:
......
......@@ -34,6 +34,7 @@ from transformers.utils import (
from transformers.utils.versions import require_version
from . import logging
from .packages import is_transformers_version_greater_than
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
......@@ -77,7 +78,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Optionally checks the package version.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
......@@ -93,11 +94,13 @@ def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
check_version("transformers>=4.41.2,<=4.46.1")
check_version("datasets>=2.16.0,<=3.1.0")
check_version("accelerate>=0.34.0,<=1.0.1")
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.2.0")
check_version("accelerate>=0.34.0,<=1.2.1")
check_version("peft>=0.11.1,<=0.12.0")
check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
......@@ -223,6 +226,13 @@ def is_gpu_or_npu_available() -> bool:
return is_torch_npu_available() or is_torch_cuda_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool:
r"""
Checks if the environment variable is enabled.
"""
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
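The new helper centralizes the scattered `os.getenv(...).lower() in [...]` checks replaced throughout this commit; usage is uniform:

```python
import os

os.environ["DISABLE_VERSION_CHECK"] = "True"  # "true", "y" and "1" all count
assert is_env_enabled("DISABLE_VERSION_CHECK")
assert not is_env_enabled("UNSET_FLAG")           # falls back to default="0"
assert is_env_enabled("UNSET_FLAG", default="1")  # caller-supplied default
```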
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
......@@ -241,7 +251,7 @@ def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
if not is_env_enabled("FORCE_CHECK_IMPORTS"):
transformers.dynamic_module_utils.check_imports = get_relative_imports
......@@ -287,12 +297,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
return is_env_enabled("USE_MODELSCOPE_HUB")
def use_openmind() -> bool:
return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
return is_env_enabled("USE_OPENMIND_HUB")
def use_ray() -> bool:
return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
return is_env_enabled("USE_RAY")
......@@ -42,6 +42,10 @@ def is_pyav_available():
return _is_package_available("av")
def is_librosa_available():
return _is_package_available("librosa")
def is_fastapi_available():
return _is_package_available("fastapi")
......@@ -87,11 +91,6 @@ def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
-@lru_cache
-def is_transformers_version_equal_to_4_46():
-return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
def is_uvicorn_available():
return _is_package_available("uvicorn")
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -41,9 +41,9 @@ class DataArguments:
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
-image_dir: Optional[str] = field(
+media_dir: Optional[str] = field(
default=None,
-metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
)
cutoff_len: int = field(
default=2048,
......@@ -133,8 +133,8 @@ class DataArguments:
self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset)
-if self.image_dir is None:
-self.image_dir = self.dataset_dir
+if self.media_dir is None:
+self.media_dir = self.dataset_dir
if self.dataset is None and self.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -238,7 +238,7 @@ class GaloreArguments:
metadata={"help": "Number of steps to update the GaLore projection."},
)
galore_scale: float = field(
-default=0.25,
+default=2.0,
metadata={"help": "GaLore scaling coefficient."},
)
galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
......@@ -279,7 +279,7 @@ class ApolloArguments:
metadata={"help": "Number of steps to update the APOLLO projection."},
)
apollo_scale: float = field(
-default=1.0,
+default=32.0,
metadata={"help": "APOLLO scaling coefficient."},
)
apollo_proj: Literal["svd", "random"] = field(
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -58,20 +58,28 @@ class ProcessorArguments:
Arguments pertaining to the image processor.
"""
-image_resolution: int = field(
-default=512 * 512,
-metadata={"help": "Keeps the number of pixels of image below this resolution."},
+image_max_pixels: int = field(
+default=768 * 768,
+metadata={"help": "The maximum number of pixels of image inputs."},
)
-video_resolution: int = field(
-default=128 * 128,
-metadata={"help": "Keeps the number of pixels of video below this resolution."},
+image_min_pixels: int = field(
+default=32 * 32,
+metadata={"help": "The minimum number of pixels of image inputs."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
)
video_min_pixels: int = field(
default=16 * 16,
metadata={"help": "The minimum number of pixels of video inputs."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},
)
video_maxlen: int = field(
-default=64,
+default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
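The single `*_resolution` caps above become explicit min/max pixel budgets. The resizing rule they imply is roughly the following (a sketch only; the real logic lives in the multimodal plugin, and `fit_pixels` is a hypothetical helper):

```python
import math

def fit_pixels(width: int, height: int,
               max_pixels: int = 768 * 768, min_pixels: int = 32 * 32) -> tuple:
    """Scale (width, height) so that min_pixels <= w*h <= max_pixels, keeping aspect."""
    pixels = width * height
    if pixels > max_pixels:
        scale = math.sqrt(max_pixels / pixels)
    elif pixels < min_pixels:
        scale = math.sqrt(min_pixels / pixels)
    else:
        return width, height
    return int(width * scale), int(height * scale)

print(fit_pixels(4000, 3000))  # (886, 665): 4000x3000 shrunk under the 768*768 budget
```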
......@@ -87,7 +95,7 @@ class ExportArguments:
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
-default=1,
+default=5,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
......@@ -201,7 +209,7 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
......
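The widened `Literal` admits YaRN and Llama-3 style RoPE scaling. A hypothetical way to request it (the constructor call and import path are assumptions; only the field name is taken from the dataclass above):

```python
from llamafactory.hparams import ModelArguments  # import path assumed

args = ModelArguments(model_name_or_path="Qwen/QwQ-32B", rope_scaling="yarn")
```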
......@@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, check_version, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
......@@ -55,6 +55,9 @@ _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
r"""
Gets arguments from the command line or a config file.
"""
if args is not None:
return args
......@@ -80,13 +83,14 @@ def _parse_args(
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
-return (*parsed_args,)
+return tuple(parsed_args)
def _set_transformers_logging() -> None:
-transformers.utils.logging.set_verbosity_info()
-transformers.utils.logging.enable_default_handler()
-transformers.utils.logging.enable_explicit_format()
+if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
+transformers.utils.logging.set_verbosity_info()
+transformers.utils.logging.enable_default_handler()
+transformers.utils.logging.enable_explicit_format()
def _verify_model_args(
......@@ -133,7 +137,7 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm":
check_version("vllm>=0.4.3,<=0.6.5")
check_version("vllm>=0.4.3,<=0.7.3")
check_version("vllm", mandatory=True)
if finetuning_args.use_galore:
......@@ -159,17 +163,20 @@ def _check_extra_dependencies(
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
-return _parse_args(parser, args)
+allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
-return _parse_args(parser, args)
+allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
-return _parse_args(parser, args)
+allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
......@@ -186,9 +193,6 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft":
if training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
......@@ -396,9 +400,6 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
_set_transformers_logging()
-if data_args.template is None:
-raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
......@@ -429,9 +430,6 @@ def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _E
_set_transformers_logging()
-if data_args.template is None:
-raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
......
......@@ -16,7 +16,11 @@ class RayArguments:
ray_run_name: Optional[str] = field(
default=None,
metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
)
ray_storage_path: str = field(
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_num_workers: int = field(
default=1,
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......