"src/targets/vscode:/vscode.git/clone" did not exist on "8b649bf91eda6187f48fa593254863068b2957e0"
Commit 317a82e2 authored by chenych

Add QWQ-32B

parent 37b0ad9f
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -25,7 +25,7 @@ from transformers import Trainer
from typing_extensions import override
from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -107,10 +107,6 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0-4.46.1
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
else:
......
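For reference, a minimal standalone sketch of the pairwise reward loss computed in the PairwiseTrainer hunk above (Bradley-Terry style: the chosen response should score higher than the rejected one); the tensor values are made-up toy numbers.

import torch
import torch.nn.functional as F

# Toy scores for two preference pairs; in the trainer they come from the value head.
chosen_scores = torch.tensor([1.2, 0.3])
rejected_scores = torch.tensor([0.4, 0.9])

# The loss shrinks when chosen scores exceed rejected ones and grows otherwise.
loss = -F.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss.item())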
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
......@@ -49,7 +49,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
"""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
......@@ -58,6 +62,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if gen_kwargs is not None:
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
self._gen_kwargs = gen_kwargs
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
......@@ -88,24 +95,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
if return_outputs:
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
loss = loss / self.args.gradient_accumulation_steps
return loss
@override
def prediction_step(
self,
......
......@@ -78,6 +78,12 @@ def run_sft(
metric_module["compute_metrics"] = ComputeAccuracy()
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
model=model,
......@@ -85,17 +91,12 @@ def run_sft(
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
**dataset_module,
**tokenizer_module,
**metric_module,
)
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -17,6 +17,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
......@@ -31,7 +33,7 @@ from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
......@@ -51,7 +53,7 @@ if is_ray_available():
if TYPE_CHECKING:
from transformers import PreTrainedModel, TrainerCallback
from transformers import PreTrainedModel, TrainerCallback, TrainerState
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, RayArguments, TrainingArguments
......@@ -592,7 +594,24 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback(
class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
if not state.is_world_process_zero:
return
super().setup(args, state, model, **kwargs)
try:
if hasattr(self, "_swanlab"):
swanlab_public_config = self._swanlab.get_run().public.json()
else: # swanlab <= 0.4.9
swanlab_public_config = self._experiment.get_run().public.json()
except Exception:
swanlab_public_config = {}
with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
f.write(json.dumps(swanlab_public_config, indent=2))
swanlab_callback = SwanLabCallbackExtension(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_run_name,
......@@ -621,7 +640,7 @@ def get_ray_trainer(
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_path=Path("./saves").absolute().as_posix(),
storage_path=Path(ray_args.ray_storage_path).absolute().as_posix(),
),
)
return trainer
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,11 +17,13 @@ import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
import torch.distributed as dist
from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
......@@ -75,6 +77,12 @@ def _training_function(config: Dict[str, Any]) -> None:
else:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
try:
if dist.is_initialized():
dist.destroy_process_group()
except Exception as e:
logger.warning(f"Failed to destroy process group: {e}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
args = read_args(args)
......@@ -104,7 +112,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
get_template_and_fix_tokenizer(tokenizer, data_args)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
......@@ -117,7 +125,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
setattr(model.config, "torch_dtype", torch.float16)
else:
if model_args.infer_dtype == "auto":
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
output_dtype = getattr(model.config, "torch_dtype", torch.float32)
if output_dtype == torch.float32: # if infer_dtype is auto, try using half precision first
output_dtype = infer_optim_dtype(torch.bfloat16)
else:
output_dtype = getattr(torch, model_args.infer_dtype)
......@@ -171,3 +181,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
except Exception as e:
logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f:
f.write(template.get_ollama_modelfile(tokenizer))
logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.")
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
import torch
from transformers import Trainer
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
if TYPE_CHECKING:
from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
logger = get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
r"""
A dummy optimizer used for the GaLore algorithm.
"""
def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None:
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
def zero_grad(self, set_to_none: bool = True) -> None:
pass
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
pass
def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"tags": ["llama-factory", finetuning_args.finetuning_type],
}
if data_args.dataset is not None:
kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]
if model_args.use_unsloth:
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
if not training_args.do_train:
pass
elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(
dict(
model_name_or_path=finetuning_args.ref_model,
adapter_name_or_path=finetuning_args.ref_model_adapters,
quantization_bit=finetuning_args.ref_model_quantization_bit,
)
)
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
tokenizer = load_tokenizer(model_args)["tokenizer"]
ref_model = load_model(
tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide a full URL."
logger.info("Use reward server {}".format(finetuning_args.reward_model))
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32)  # trainable params should be in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer(
"default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
)
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(
dict(
model_name_or_path=finetuning_args.reward_model,
adapter_name_or_path=finetuning_args.reward_model_adapters,
quantization_bit=finetuning_args.reward_model_quantization_bit,
)
)
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model
@contextmanager
def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
r"""
Gets adapter context for the reference model.
"""
with accelerator.unwrap_model(model).disable_adapter():
model.eval()
yield
model.train()
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
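A standalone illustration of the weight-decay filter above, approximating get_parameter_names by parameter shape: only the Linear weight lands in the decay group, while LayerNorm weights and all biases are excluded.

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
decay_names = [
    name
    for name, param in model.named_parameters()
    if param.ndim > 1 and "bias" not in name  # rough stand-in for the helper above
]
print(decay_names)  # ['0.weight']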
def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model)
else:
galore_targets = finetuning_args.galore_target
galore_params: List["torch.nn.Parameter"] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters():
if param.requires_grad and len(param.shape) > 1:
galore_params.append(param)
galore_kwargs = {
"rank": finetuning_args.galore_rank,
"update_proj_gap": finetuning_args.galore_update_interval,
"scale": finetuning_args.galore_scale,
"proj_type": finetuning_args.galore_proj_type,
}
id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_galore_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
if training_args.optim == "adamw_torch":
optim_class = GaLoreAdamW
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optim_class = GaLoreAdamW8bit
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
if finetuning_args.galore_layerwise:
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in galore_params: # galore params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer
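To make the layer-wise branch above concrete, here is a small self-contained sketch of the per-parameter hook pattern it relies on, with plain SGD standing in for the GaLore optimizers: each parameter owns its own optimizer and is stepped as soon as its gradient is accumulated, which is why a DummyOptimizer suffices at the Trainer level.

import torch

model = torch.nn.Linear(4, 4)
optimizer_dict = {param: torch.optim.SGD([param], lr=0.1) for param in model.parameters()}

def optimizer_hook(param: torch.nn.Parameter) -> None:
    # Step and reset this parameter's private optimizer right after its gradient is ready.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

model(torch.randn(2, 4)).sum().backward()  # the hooks fire here and update the weights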
def _create_loraplus_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
default_lr = training_args.learning_rate
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
"lora_a": [],
"lora_b": [],
"lora_b_nodecay": [],
"embedding": [],
}
for name, param in model.named_parameters():
if param.requires_grad:
if "lora_embedding_B" in name:
param_dict["embedding"].append(param)
elif "lora_B" in name or param.ndim == 1:
if name in decay_param_names:
param_dict["lora_b"].append(param)
else:
param_dict["lora_b_nodecay"].append(param)
else:
param_dict["lora_a"].append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
return optimizer
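A compact sketch of the LoRA+ scheme implemented above: the B matrices train with a learning rate loraplus_lr_ratio times larger than the A matrices, expressed purely through optimizer parameter groups (the tensors below are placeholders, not real LoRA weights).

import torch

lora_A = torch.nn.Parameter(torch.randn(8, 4) * 0.01)
lora_B = torch.nn.Parameter(torch.zeros(4, 8))
base_lr, loraplus_lr_ratio = 1e-4, 16.0

optimizer = torch.optim.AdamW(
    [
        {"params": [lora_A], "lr": base_lr, "weight_decay": 0.01},
        {"params": [lora_B], "lr": base_lr * loraplus_lr_ratio, "weight_decay": 0.01},
    ]
)
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 0.0016]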
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
]
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer
base_optimizer = optim_class(param_groups, **optim_kwargs)
optimizer = BlockOptimizer(
base_optimizer=base_optimizer,
named_parameters_list=list(model.named_parameters()),
block_prefix_list=None,
switch_block_every=finetuning_args.badam_switch_interval,
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
)
logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_interval} steps, "
f"default start block is {finetuning_args.badam_start_block}"
)
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
param_groups=param_groups,
named_parameters_list=list(model.named_parameters()),
update_ratio=finetuning_args.badam_update_ratio,
mask_mode=finetuning_args.badam_mask_mode,
verbose=finetuning_args.badam_verbose,
include_embedding=False,
**optim_kwargs,
)
logger.info(
f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)
return optimizer
def create_custom_optimzer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args)
if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer_dict[param],
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
scheduler_specific_kwargs=training_args.lr_scheduler_kwargs,
)
def scheduler_hook(param: "torch.nn.Parameter"):
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,14 +14,16 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from transformers.utils import is_torch_npu_available
from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import QUANTIZATION_BITS, get_save_dir
from .common import get_save_dir, load_config
from .locales import ALERTS
......@@ -34,6 +36,40 @@ if is_gradio_available():
import gradio as gr
def _escape_html(text: str) -> str:
r"""
Escapes HTML characters.
"""
return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
r"""
Post-processes the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
if thought_words[0] not in text:
return _escape_html(text) if escape_html else text
text = text.replace(thought_words[0], "")
result = text.split(thought_words[1], maxsplit=1)
if len(result) == 1:
summary = ALERTS["info_thinking"][lang]
thought, answer = text, ""
else:
summary = ALERTS["info_thought"][lang]
thought, answer = result
if escape_html:
thought, answer = _escape_html(thought), _escape_html(answer)
return (
f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
)
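A minimal, dependency-free sketch of the thought/answer split performed by _format_response above, assuming "<think>"/"</think>" as the thought words (in the WebUI the actual pair comes from the chat template).

def split_thought(text: str, thought_words=("<think>", "</think>")):
    # Returns (thought, answer); the answer stays empty while the model is still thinking.
    if thought_words[0] not in text:
        return "", text
    text = text.replace(thought_words[0], "")
    parts = text.split(thought_words[1], maxsplit=1)
    return (parts[0], "") if len(parts) == 1 else (parts[0], parts[1])

print(split_thought("<think>check the units</think>The answer is 42."))
# ('check the units', 'The answer is 42.')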
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
......@@ -59,6 +95,8 @@ class WebChatModel(ChatModel):
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
user_config = load_config()
error = ""
if self.loaded:
error = ALERTS["err_exists"][lang]
......@@ -74,26 +112,22 @@ class WebChatModel(ChatModel):
yield error
return
if get("top.quantization_bit") in QUANTIZATION_BITS:
quantization_bit = int(get("top.quantization_bit"))
else:
quantization_bit = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
quantization_bit=quantization_bit,
quantization_method=get("top.quantization_method"),
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
trust_remote_code=True,
)
# checkpoints
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
......@@ -102,6 +136,12 @@ class WebChatModel(ChatModel):
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
# quantization
if get("top.quantization_bit") != "none":
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()
super().__init__(args)
yield ALERTS["info_loaded"][lang]
......@@ -118,28 +158,49 @@ class WebChatModel(ChatModel):
torch_gc()
yield ALERTS["info_unloaded"][lang]
@staticmethod
def append(
self,
chatbot: List[List[Optional[str]]],
messages: Sequence[Dict[str, str]],
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
role: str,
query: str,
) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
r"""
Adds the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
"""
return (
chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
messages + [{"role": role, "content": query}],
"",
)
def stream(
self,
chatbot: List[List[Optional[str]]],
messages: Sequence[Dict[str, str]],
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
lang: str,
system: str,
tools: str,
image: Optional[Any],
video: Optional[Any],
audio: Optional[Any],
max_new_tokens: int,
top_p: float,
temperature: float,
) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
chatbot[-1][1] = ""
skip_special_tokens: bool,
escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
"""
chatbot.append({"role": "assistant", "content": ""})
response = ""
for new_text in self.stream_chat(
messages,
......@@ -147,9 +208,11 @@ class WebChatModel(ChatModel):
tools,
images=[image] if image else None,
videos=[video] if video else None,
audios=[audio] if audio else None,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
skip_special_tokens=skip_special_tokens,
):
response += new_text
if tools:
......@@ -159,12 +222,12 @@ class WebChatModel(ChatModel):
if isinstance(result, list):
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
tool_calls = json.dumps(tool_calls, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result
bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
chatbot[-1][1] = bot_text
chatbot[-1] = {"role": "assistant", "content": bot_text}
yield chatbot, output_messages
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,41 +14,48 @@
import json
import os
import signal
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from typing import Any, Dict, Optional, Union
from psutil import Process
from yaml import safe_dump, safe_load
from ..extras import logging
from ..extras.constants import (
CHECKPOINT_NAMES,
DATA_CONFIG,
DEFAULT_TEMPLATE,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
MULTIMODAL_SUPPORTED_MODELS,
SUPPORTED_MODELS,
TRAINING_STAGES,
VISION_MODELS,
TRAINING_ARGS,
DownloadSource,
)
from ..extras.misc import use_modelscope, use_openmind
from ..extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
logger = logging.get_logger(__name__)
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
GPTQ_BITS = ["8", "4", "3", "2"]
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
try:
children = Process(pid).children()
if children:
for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT)
except Exception:
pass
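A hypothetical POSIX-only usage sketch of abort_process above: spawn a throwaway child process and abort it (plus any of its own children) recursively.

import subprocess
import time

proc = subprocess.Popen(["sleep", "60"])  # stand-in for a launched training run
time.sleep(0.5)
abort_process(proc.pid)                   # sends SIGABRT bottom-up; errors are swallowed
print(proc.wait())                        # a negative return code reflects the signal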
def get_save_dir(*paths: str) -> os.PathLike:
......@@ -63,19 +70,19 @@ def get_save_dir(*paths: str) -> os.PathLike:
return os.path.join(DEFAULT_SAVE_DIR, *paths)
def get_config_path() -> os.PathLike:
def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Any]:
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
try:
with open(get_config_path(), encoding="utf-8") as f:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
......@@ -94,7 +101,7 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
if model_name and model_path:
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
with open(_get_config_path(), "w", encoding="utf-8") as f:
safe_dump(user_config, f)
......@@ -122,49 +129,25 @@ def get_model_path(model_name: str) -> str:
return model_path
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
Returns:
model_path (str)
template (str)
"""
return get_model_path(model_name), get_template(model_name)
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat model.
Gets the template name if the model is a chat/distill/instruct model.
"""
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_visual(model_name: str) -> bool:
def get_time() -> str:
r"""
Judges if the model is a vision language model.
Gets current date and time.
"""
return model_name in VISION_MODELS
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
def is_multimodal(model_name: str) -> bool:
r"""
Lists all available checkpoints.
Checks whether the model is a vision language model.
"""
checkpoints = []
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
):
checkpoints.append(checkpoint)
if finetuning_type in PEFT_METHODS:
return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
else:
return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
......@@ -183,11 +166,135 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads the training configuration from config path.
"""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
r"""
Lists all available datasets in the dataset dir for the training stage.
Saves the training configuration to config path.
"""
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(choices=datasets)
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args whose value is None, False, or an empty string.
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates CLI commands for previewing.
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
elif isinstance(v, list):
cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
else:
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
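A hypothetical usage sketch of gen_cmd above; the values are invented and only show how dicts, lists, and the always-kept "packing" flag are rendered into a copy-pastable command.

preview_args = {
    "stage": "sft",
    "model_name_or_path": "Qwen/QwQ-32B",
    "dataset": ["identity", "alpaca_en_demo"],
    "learning_rate": 1e-4,
    "packing": False,             # kept even though it is falsy (see no_skip_keys)
    "adapter_name_or_path": "",   # dropped by _clean_cmd
}
print(gen_cmd(preview_args))      # multi-line "llamafactory-cli train ..." command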
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves CLI commands to launch training.
"""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(_clean_cmd(args), f)
return os.path.join(output_dir, TRAINING_ARGS)
def load_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return f"```json\n{result}\n```\n"
def create_ds_config() -> None:
r"""
Creates deepspeed config in the current directory.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
}
offload_config = {
"device": "cpu",
"pin_memory": True,
}
ds_config["zero_optimization"] = {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"round_robin_gradients": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"] = {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
ds_config["zero_optimization"]["offload_param"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
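A hypothetical usage sketch: generate the four DeepSpeed presets once, then point a training config's deepspeed field at one of the emitted files (ZeRO-3 with optimizer and parameter offload in this example).

create_ds_config()
ds_config_path = os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json")
train_args = {
    "deepspeed": ds_config_path,  # consumed by transformers TrainingArguments
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
}
print(train_args["deepspeed"])    # cache/ds_z3_offload_config.json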
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
from ...extras.packages import is_gradio_available
from ..utils import check_json_schema
from ..locales import ALERTS
if is_gradio_available():
......@@ -29,11 +30,29 @@ if TYPE_CHECKING:
from ..engine import Engine
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(show_copy_button=True)
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
......@@ -45,29 +64,48 @@ def create_chat_box(
with gr.Column() as mm_box:
with gr.Tab("Image"):
image = gr.Image(sources=["upload"], type="pil")
image = gr.Image(type="pil")
with gr.Tab("Video"):
video = gr.Video(sources=["upload"])
video = gr.Video()
with gr.Tab("Audio"):
audio = gr.Audio(type="filepath")
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
skip_special_tokens = gr.Checkbox(value=True)
escape_html = gr.Checkbox(value=True)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
submit_btn.click(
engine.chatter.append,
[chatbot, messages, role, query],
[chatbot, messages, role, query, escape_html],
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
[
chatbot,
messages,
lang,
system,
tools,
image,
video,
audio,
max_new_tokens,
top_p,
temperature,
skip_special_tokens,
escape_html,
],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
......@@ -83,11 +121,14 @@ def create_chat_box(
mm_box=mm_box,
image=image,
video=video,
audio=audio,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
skip_special_tokens=skip_special_tokens,
escape_html=escape_html,
clear_btn=clear_btn,
),
)
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -40,6 +40,9 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""
Checks if the dataset is a local dataset.
"""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
......@@ -67,6 +70,9 @@ def _load_data_file(file_path: str) -> List[Any]:
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
r"""
Gets the preview samples from the dataset.
"""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,7 +15,8 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_datasets
from ..common import DEFAULT_DATA_DIR
from ..control import list_datasets
from .data import create_preview_box
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,7 +18,7 @@ from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import GPTQ_BITS, get_save_dir
from ..common import get_save_dir, load_config
from ..locales import ALERTS
......@@ -32,6 +32,9 @@ if TYPE_CHECKING:
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
......@@ -54,6 +57,7 @@ def save_model(
export_dir: str,
export_hub_model_id: str,
) -> Generator[str, None, None]:
user_config = load_config()
error = ""
if not model_name:
error = ALERTS["err_no_model"][lang]
......@@ -75,6 +79,7 @@ def save_model(
args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
template=template,
export_dir=export_dir,
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import get_visual
from ..common import is_multimodal
from .chatbot import create_chat_box
......@@ -66,7 +66,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.model_name").change(
lambda model_name: gr.Column(visible=get_visual(model_name)),
lambda model_name: gr.Column(visible=is_multimodal(model_name)),
[engine.manager.get_elem_by_id("top.model_name")],
[chat_elems["mm_box"]],
)
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
from ..utils import can_quantize, can_quantize_to
from ..common import save_config
from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
if is_gradio_available():
......@@ -30,11 +30,10 @@ if TYPE_CHECKING:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko"], scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
......@@ -42,11 +41,11 @@ def create_top() -> Dict[str, "Component"]:
checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")
model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,8 +19,8 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.misc import get_device_count
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
from ..utils import change_stage, list_config_paths, list_output_dirs
from ..common import DEFAULT_DATA_DIR
from ..control import change_stage, list_checkpoints, list_config_paths, list_datasets, list_output_dirs
from .data import create_preview_box
......@@ -39,9 +39,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
elem_dict = dict()
with gr.Row():
training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
)
stages = list(TRAINING_STAGES.keys())
training_stage = gr.Dropdown(choices=stages, value=stages[0], scale=1)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
......@@ -107,8 +106,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
use_llama_pro = gr.Checkbox()
with gr.Column():
shift_attn = gr.Checkbox()
report_to = gr.Checkbox()
report_to = gr.Dropdown(
choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
value=["none"],
allow_custom_value=True,
multiselect=True,
)
input_elems.update(
{
......@@ -123,7 +126,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
mask_history,
resize_vocab,
use_llama_pro,
shift_attn,
report_to,
}
)
......@@ -141,7 +143,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
mask_history=mask_history,
resize_vocab=resize_vocab,
use_llama_pro=use_llama_pro,
shift_attn=shift_attn,
report_to=report_to,
)
)
......@@ -298,9 +299,18 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
swanlab_workspace = gr.Textbox()
swanlab_api_key = gr.Textbox()
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
swanlab_link = gr.Markdown(visible=False)
input_elems.update(
{use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
{
use_swanlab,
swanlab_project,
swanlab_run_name,
swanlab_workspace,
swanlab_api_key,
swanlab_mode,
swanlab_link,
}
)
elem_dict.update(
dict(
......@@ -311,6 +321,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
swanlab_workspace=swanlab_workspace,
swanlab_api_key=swanlab_api_key,
swanlab_mode=swanlab_mode,
swanlab_link=swanlab_link,
)
)
......@@ -363,7 +374,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
loss_viewer=loss_viewer,
)
)
output_elems = [output_box, progress_bar, loss_viewer]
output_elems = [output_box, progress_bar, loss_viewer, swanlab_link]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
......
# Copyright 2024 the LlamaFactory team.
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,19 +14,23 @@
import json
import os
import signal
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import psutil
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.constants import (
CHECKPOINT_NAMES,
PEFT_METHODS,
RUNNING_LOG,
STAGES_USE_PAIR_DATA,
SWANLAB_CONFIG,
TRAINER_LOG,
TRAINING_STAGES,
)
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
from .locales import ALERTS
......@@ -34,24 +38,12 @@ if is_gradio_available():
import gradio as gr
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
try:
children = psutil.Process(pid).children()
if children:
for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT)
except Exception:
pass
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Checks whether quantization is available for this finetuning type.
Inputs: top.finetuning_type
Outputs: top.quantization_bit
"""
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value="none", interactive=False)
......@@ -61,7 +53,10 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Returns the available quantization bits.
Gets the available quantization bits.
Inputs: top.quantization_method
Outputs: top.quantization_bit
"""
if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
available_bits = ["none", "8", "4"]
......@@ -76,93 +71,42 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifies states after changing the training stage.
"""
return [], TRAINING_STAGES[training_stage] == "pt"
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
Inputs: train.training_stage
Outputs: train.dataset, train.packing
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates arguments for previewing.
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in clean_cmd(args).items():
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves arguments to launch training.
"""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINING_ARGS)
return [], TRAINING_STAGES[training_stage] == "pt"
def get_eval_results(path: os.PathLike) -> str:
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets scores after evaluation.
"""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return f"```json\n{result}\n```\n"
Gets the necessary information of this model.
def get_time() -> str:
r"""
Gets current date and time.
Inputs: top.model_name
Outputs: top.model_path, top.template
"""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
return get_model_path(model_name), get_template(model_name)
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r"""
Gets training information for the monitor.
If do_train is True:
Inputs: top.lang, train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer, train.swanlab_link
If do_train is False:
Inputs: top.lang, eval.output_path
Outputs: eval.output_box, eval.progress_bar, None, None
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_info = {}
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, encoding="utf-8") as f:
running_log = f.read()
running_log = f.read()[-20000:] # avoid lengthy log
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
......@@ -183,33 +127,50 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
running_info["loss_viewer"] = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
swanlab_config_path = os.path.join(output_path, SWANLAB_CONFIG)
if os.path.isfile(swanlab_config_path):
with open(swanlab_config_path, encoding="utf-8") as f:
swanlab_public_config = json.load(f)
swanlab_link = swanlab_public_config["cloud"]["experiment_url"]
if swanlab_link is not None:
running_info["swanlab_link"] = gr.Markdown(
ALERTS["info_swanlab_link"][lang] + swanlab_link, visible=True
)
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads saved arguments.
"""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
return running_log, running_progress, running_info
def save_args(config_path: str, config_dict: Dict[str, Any]):
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
r"""
Saves arguments.
Lists all available checkpoints.
Inputs: top.model_name, top.finetuning_type
Outputs: top.checkpoint_path
"""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
checkpoints = []
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
):
checkpoints.append(checkpoint)
if finetuning_type in PEFT_METHODS:
return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
else:
return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
Inputs: train.current_time
Outputs: train.config_path
"""
config_files = [f"{current_time}.yaml"]
if os.path.isdir(DEFAULT_CONFIG_DIR):
......@@ -220,9 +181,25 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
return gr.Dropdown(choices=config_files)
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
r"""
Lists all available datasets in the dataset dir for the training stage.
Inputs: *.dataset_dir, *.training_stage
Outputs: *.dataset
"""
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(choices=datasets)
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that training can resume from.
Inputs: top.model_name, top.finetuning_type, train.current_time
Outputs: train.output_dir
"""
output_dirs = [f"train_{current_time}"]
if model_name:
......@@ -234,66 +211,3 @@ def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_ti
output_dirs.append(folder)
return gr.Dropdown(choices=output_dirs)
def create_ds_config() -> None:
r"""
Creates deepspeed config.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
}
offload_config = {
"device": "cpu",
"pin_memory": True,
}
ds_config["zero_optimization"] = {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"round_robin_gradients": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"] = {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
ds_config["zero_optimization"]["offload_param"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)