"template/solar-instruct.gotmpl" did not exist on "9b6c2e6eb62c234f8a44556984bbb680d7065e01"
Commit c7c477c7 authored by chenych's avatar chenych
Browse files

add grpo

parents
Pipeline #2942 failed with stages
in 0 seconds
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
from ...data import get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .trainer import CustomTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
trainer = CustomTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**tokenizer_module,
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if isinstance(dataset_module.get("eval_dataset"), dict):
for key in dataset_module["eval_dataset"].keys():
try:
perplexity = math.exp(metrics[f"eval_{key}_loss"])
except OverflowError:
perplexity = float("inf")
metrics[f"eval_{key}_perplexity"] = perplexity
else:
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["eval_perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_rm
__all__ = ["run_rm"]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import numpy as np
from ...extras.misc import numpify
if TYPE_CHECKING:
from transformers import EvalPrediction
@dataclass
class ComputeAccuracy:
r"""Compute reward accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
self.score_dict = {"accuracy": []}
return result
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
else:
for i in range(len(chosen_scores)):
self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
if compute_result:
return self._dump()
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Optional, Union
import torch
from transformers import Trainer
from typing_extensions import override
from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer):
r"""Inherits Trainer to compute pairwise loss."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
super().__init__(**kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
self.add_callback(FixValueHeadModelCallback)
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler(*args, **kwargs)
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
Note that the first element will be removed from the output tuple.
See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
"""
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
batch_size = inputs["input_ids"].size(0) // 2
chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
else:
return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""Save model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: list[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
writer.write("\n".join(res))
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy
from .trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
def run_rm(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(
template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
)
# Initialize our Trainer
trainer = PairwiseTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeAccuracy(),
**dataset_module,
**tokenizer_module,
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if training_args.should_save:
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
)
else:
keys += ["eval_loss", "eval_accuracy"]
plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft
__all__ = ["run_sft"]
# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available
if TYPE_CHECKING:
from transformers import EvalPrediction, PreTrainedTokenizer
if is_jieba_available():
import jieba # type: ignore
if is_nltk_available():
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu # type: ignore
if is_rouge_available():
from rouge_chinese import Rouge # type: ignore
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
r"""Compute the token with the largest likelihood to reduce memory footprint."""
if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
logits = logits[0]
else: # moe models have aux loss
logits = logits[1]
if logits.dim() != 3:
raise ValueError("Cannot process the logits.")
return torch.argmax(logits, dim=-1)
@dataclass
class ComputeAccuracy:
r"""Compute accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
self.score_dict = {"accuracy": []}
return result
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
label_mask = label != IGNORE_INDEX
self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))
if compute_result:
return self._dump()
@dataclass
class ComputeSimilarity:
r"""Compute text similarity scores and support `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
tokenizer: "PreTrainedTokenizer"
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
return result
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
for pred, label in zip(decoded_preds, decoded_labels):
hypothesis = list(jieba.cut(pred))
reference = list(jieba.cut(label))
if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
else:
rouge = Rouge()
scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
result = scores[0]
for k, v in result.items():
self.score_dict[k].append(round(v["f"] * 100, 4))
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))
if compute_result:
return self._dump()
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
def __init__(
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
super().__init__(**kwargs)
if processor is not None:
# avoid wrong loss under gradient accumulation
# https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
self.model_accepts_loss_kwargs = False
self.finetuning_args = finetuning_args
if gen_kwargs is not None:
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
self._gen_kwargs = gen_kwargs
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler(*args, **kwargs)
@override
def compute_loss(self, model, inputs, *args, **kwargs):
return super().compute_loss(model, inputs, *args, **kwargs)
@override
def prediction_step(
self,
model: "torch.nn.Module",
inputs: dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
**gen_kwargs,
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Remove the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
if self.args.predict_with_generate: # do not pass labels to model when generate
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""Save model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX,
predict_results.predictions,
self.processing_class.pad_token_id,
)
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
if len(pad_len): # move pad token to last
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
with open(output_prediction_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
# Metric utils
metric_module = {}
if training_args.predict_with_generate:
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
elif finetuning_args.compute_accuracy:
metric_module["compute_metrics"] = ComputeAccuracy()
metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
**dataset_module,
**tokenizer_module,
**metric_module,
)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="sft"
)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
)
else:
keys += ["eval_loss", "eval_accuracy"]
plot_loss(training_args.output_dir, keys=keys)
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Union
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
if TYPE_CHECKING:
from peft import LoraModel
from transformers import PreTrainedModel
from ..data.data_utils import DatasetModule
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
state_dict_a = model_a.state_dict()
state_dict_b = model_b.state_dict()
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
linear_modules, extra_modules = set(), set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
assert param.requires_grad is True
assert param.dtype == torch.float32
elif "modules_to_save" in name:
extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
assert param.requires_grad is True
assert param.dtype == torch.float32
else:
assert param.requires_grad is False
assert param.dtype == torch.float16
return linear_modules, extra_modules
def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
tokenizer = load_tokenizer(model_args)["tokenizer"]
return load_model(tokenizer, model_args, finetuning_args, is_trainable=True, add_valuehead=add_valuehead)
def load_infer_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
model_args, _, finetuning_args, _ = get_infer_args(kwargs)
tokenizer = load_tokenizer(model_args)["tokenizer"]
return load_model(tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead)
def load_reference_model(
model_path: str,
lora_path: Optional[str] = None,
use_lora: bool = False,
use_pissa: bool = False,
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
if use_lora or use_pissa:
model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
)
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
return model
def load_dataset_module(**kwargs) -> "DatasetModule":
model_args, data_args, training_args, _, _ = get_train_args(kwargs)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
return dataset_module
def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False)
del state_dict
AutoModelForCausalLMWithValueHead.post_init = post_init
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
# and the original BAdam's implementation: https://github.com/Ledzy/BAdam
# and the HuggingFace's TRL library: https://github.com/huggingface/trl
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import Trainer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX, SWANLAB_CONFIG
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
if is_apollo_available():
from apollo_torch import APOLLOAdamW # type: ignore
if is_ray_available():
import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel, TrainerCallback, TrainerState
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None:
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
@override
def zero_grad(self, set_to_none: bool = True) -> None:
pass
@override
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
pass
def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"tags": ["llama-factory", finetuning_args.finetuning_type],
}
if data_args.dataset is not None:
kwargs["dataset"] = data_args.dataset
if model_args.use_unsloth:
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
if not training_args.do_train:
pass
elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args = ModelArguments.copyfrom(
model_args,
model_name_or_path=finetuning_args.ref_model,
adapter_name_or_path=finetuning_args.ref_model_adapters,
quantization_bit=finetuning_args.ref_model_quantization_bit,
)
ref_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info_rank0(f"Created reference model from {finetuning_args.ref_model}")
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model_args = ModelArguments.copyfrom(model_args)
ref_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info_rank0("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
r"""Create reward model for PPO training."""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
return finetuning_args.reward_model
elif finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer(
"default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False
)
model.register_buffer(
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
)
logger.info_rank0(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
return None
else:
reward_model_args = ModelArguments.copyfrom(
model_args,
model_name_or_path=finetuning_args.reward_model,
adapter_name_or_path=finetuning_args.reward_model_adapters,
quantization_bit=finetuning_args.reward_model_quantization_bit,
)
reward_finetuning_args = FinetuningArguments()
tokenizer = load_tokenizer(reward_model_args)["tokenizer"]
reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info_rank0(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
logger.warning_rank0("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model
def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
galore_targets = finetuning_args.galore_target
galore_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters():
if param.requires_grad and len(param.shape) > 1:
galore_params.append(param)
galore_kwargs = {
"rank": finetuning_args.galore_rank,
"update_proj_gap": finetuning_args.galore_update_interval,
"scale": finetuning_args.galore_scale,
"proj_type": finetuning_args.galore_proj_type,
}
id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_galore_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
if training_args.optim == "adamw_torch":
optim_class = GaLoreAdamW
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optim_class = GaLoreAdamW8bit
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
if finetuning_args.galore_layerwise:
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in galore_params: # galore params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0(
f"Using GaLore optimizer with args: {galore_kwargs}. "
"It may cause hanging at the start of training, wait patiently."
)
return optimizer
def _create_apollo_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
apollo_targets = finetuning_args.apollo_target
apollo_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters():
if param.requires_grad and len(param.shape) > 1:
apollo_params.append(param)
apollo_kwargs = {
"rank": finetuning_args.apollo_rank,
"proj": finetuning_args.apollo_proj,
"proj_type": finetuning_args.apollo_proj_type,
"update_proj_gap": finetuning_args.apollo_update_interval,
"scale": finetuning_args.apollo_scale,
"scale_type": finetuning_args.apollo_scale_type,
"scale_front": finetuning_args.apollo_scale_front,
}
id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_apollo_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
if training_args.optim == "adamw_torch":
optim_class = APOLLOAdamW
else:
raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
if finetuning_args.apollo_layerwise:
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in apollo_params: # apollo params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
return optimizer
def _create_loraplus_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
default_lr = training_args.learning_rate
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model)
param_dict: dict[str, list[torch.nn.Parameter]] = {
"lora_a": [],
"lora_b": [],
"lora_b_nodecay": [],
"embedding": [],
}
for name, param in model.named_parameters():
if param.requires_grad:
if "lora_embedding_B" in name:
param_dict["embedding"].append(param)
elif "lora_B" in name or param.ndim == 1:
if name in decay_param_names:
param_dict["lora_b"].append(param)
else:
param_dict["lora_b_nodecay"].append(param)
else:
param_dict["lora_a"].append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
return optimizer
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
]
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer # type: ignore
base_optimizer = optim_class(param_groups, **optim_kwargs)
optimizer = BlockOptimizer(
base_optimizer=base_optimizer,
named_parameters_list=list(model.named_parameters()),
block_prefix_list=None,
switch_block_every=finetuning_args.badam_switch_interval,
start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose,
ds_zero3_enabled=is_deepspeed_zero3_enabled(),
)
logger.info_rank0(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_interval} steps, "
f"default start block is {finetuning_args.badam_start_block}"
)
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio # type: ignore
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
param_groups=param_groups,
named_parameters_list=list(model.named_parameters()),
update_ratio=finetuning_args.badam_update_ratio,
mask_mode=finetuning_args.badam_mask_mode,
verbose=finetuning_args.badam_verbose,
include_embedding=False,
**optim_kwargs,
)
logger.info_rank0(
f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)
return optimizer
def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini # type: ignore
hidden_size = getattr(model.config, "hidden_size", None)
num_q_head = getattr(model.config, "num_attention_heads", None)
num_kv_head = getattr(model.config, "num_key_value_heads", None)
optimizer = Adam_mini(
named_parameters=model.named_parameters(),
lr=training_args.learning_rate,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
dim=hidden_size,
n_heads=num_q_head,
n_kv_heads=num_kv_head,
)
logger.info_rank0("Using Adam-mini optimizer.")
return optimizer
def _create_muon_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from ..third_party.muon import Muon
muon_params, adamw_params = [], []
for name, param in model.named_parameters():
if param.requires_grad:
# Use Muon for 2D parameters that aren't embeddings or heads
if param.ndim == 2 and "embed" not in name and "lm_head" not in name:
muon_params.append(param)
else:
adamw_params.append(param)
optimizer = Muon(
lr=training_args.learning_rate,
wd=training_args.weight_decay,
muon_params=muon_params,
adamw_params=adamw_params,
adamw_betas=(training_args.adam_beta1, training_args.adam_beta2),
adamw_eps=training_args.adam_epsilon,
)
logger.info_rank0(
f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params."
)
return optimizer
def create_custom_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_apollo:
return _create_apollo_optimizer(model, training_args, finetuning_args)
if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_adam_mini:
return _create_adam_mini_optimizer(model, training_args)
if finetuning_args.use_muon:
return _create_muon_optimizer(model, training_args)
def create_custom_scheduler(
training_args: "TrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
if training_args.lr_scheduler_type == "warmup_stable_decay":
num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
remaining_steps = num_training_steps - num_warmup_steps
num_stable_steps = remaining_steps // 3 # use 1/3 for stable by default
num_decay_steps = remaining_steps - num_stable_steps
scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
default_kwargs = {
"num_stable_steps": num_stable_steps,
"num_decay_steps": num_decay_steps,
}
for key, value in default_kwargs.items():
if key not in scheduler_kwargs:
scheduler_kwargs[key] = value
training_args.lr_scheduler_kwargs = scheduler_kwargs
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer_dict[param],
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
scheduler_specific_kwargs=training_args.lr_scheduler_kwargs,
)
def scheduler_hook(param: "torch.nn.Parameter"):
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)
def get_batch_logps(
logits: "torch.Tensor",
labels: "torch.Tensor",
label_pad_token_id: int = IGNORE_INDEX,
ld_alpha: Optional[float] = None,
) -> tuple["torch.Tensor", "torch.Tensor"]:
r"""Compute the log probabilities of the given labels under the given logits.
Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != label_pad_token_id
labels[labels == label_pad_token_id] = 0 # dummy token
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
valid_length = loss_mask.sum(-1)
if ld_alpha is not None:
num_examples = labels.shape[0] // 2
chosen_lengths = valid_length[:num_examples]
rejected_lengths = valid_length[num_examples:]
min_lengths = torch.min(chosen_lengths, rejected_lengths)
start_positions = torch.argmax(loss_mask.int(), dim=1)
public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
seq_len = labels.shape[-1]
position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)
front_mask = (ld_mask * loss_mask).float()
rear_mask = (~ld_mask * loss_mask).float()
front_logps = (per_token_logps * front_mask).sum(-1)
rear_logps = (per_token_logps * rear_mask).sum(-1)
logps = front_logps + ld_alpha * rear_logps
else:
logps = (per_token_logps * loss_mask).sum(-1)
return logps, valid_length
def nested_detach(
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
if isinstance(tensors, torch.Tensor):
if clone:
return tensors.detach().clone()
else:
return tensors.detach()
else:
return tensors
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""Get the callback for logging to SwanLab."""
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
if finetuning_args.swanlab_lark_webhook_url is not None:
from swanlab.plugin.notification import LarkCallback # type: ignore
lark_callback = LarkCallback(
webhook_url=finetuning_args.swanlab_lark_webhook_url,
secret=finetuning_args.swanlab_lark_secret,
)
swanlab.register_callbacks([lark_callback])
class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
if not state.is_world_process_zero:
return
super().setup(args, state, model, **kwargs)
try:
if hasattr(self, "_swanlab"):
swanlab_public_config = self._swanlab.get_run().public.json()
else: # swanlab <= 0.4.9
swanlab_public_config = self._experiment.get_run().public.json()
except Exception:
swanlab_public_config = {}
with open(os.path.join(args.output_dir, SWANLAB_CONFIG), "w") as f:
f.write(json.dumps(swanlab_public_config, indent=2))
swanlab_callback = SwanLabCallbackExtension(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_run_name,
mode=finetuning_args.swanlab_mode,
config={"Framework": "🦙LlamaFactory"},
logdir=finetuning_args.swanlab_logdir,
tags=["🦙LlamaFactory"],
)
return swanlab_callback
def get_ray_trainer(
training_function: Callable,
train_loop_config: dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
if ray_args.ray_init_kwargs is not None:
ray.init(**ray_args.ray_init_kwargs)
if ray_args.ray_storage_filesystem is not None:
# this means we are using s3/gcs
storage_path = ray_args.ray_storage_path
else:
storage_path = Path(ray_args.ray_storage_path).absolute().as_posix()
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.ray_num_workers,
resources_per_worker=ray_args.resources_per_worker,
placement_strategy=ray_args.placement_strategy,
use_gpu=True,
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_filesystem=ray_args.ray_storage_filesystem,
storage_path=storage_path,
),
)
return trainer
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from typing import TYPE_CHECKING, Any, Optional
import torch
import torch.distributed as dist
from transformers import EarlyStoppingCallback, PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .grpo import run_grpo
from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
import ray
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
from transformers import TrainerCallback
logger = logging.get_logger(__name__)
def _training_function(config: dict[str, Any]) -> None:
args = config.get("args")
callbacks: list[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback())
if finetuning_args.pissa_convert:
callbacks.append(PissaConvertCallback())
if finetuning_args.use_swanlab:
callbacks.append(get_swanlab_callback(finetuning_args))
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "grpo":
run_grpo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
if is_ray_available() and ray.is_initialized():
return # if ray is intialized it will destroy the process group on return
try:
if dist.is_initialized():
dist.destroy_process_group()
except Exception as e:
logger.warning(f"Failed to destroy process group: {e}.")
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
args = read_args(args)
if "-h" in args or "--help" in args:
get_train_args(args)
ray_args = get_ray_args(args)
callbacks = callbacks or []
if ray_args.use_ray:
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=_training_function,
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
else:
_training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)
if model_args.export_dir is None:
raise ValueError("Please specify `export_dir` to save model.")
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.")
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
raise ValueError("Cannot merge adapters to a quantized model.")
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type
setattr(model.config, "torch_dtype", torch.float16)
else:
if model_args.infer_dtype == "auto":
output_dtype = getattr(model.config, "torch_dtype", torch.float32)
if output_dtype == torch.float32: # if infer_dtype is auto, try using half precision first
output_dtype = infer_optim_dtype(torch.bfloat16)
else:
output_dtype = getattr(torch, model_args.infer_dtype)
setattr(model.config, "torch_dtype", output_dtype)
model = model.to(output_dtype)
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
model.save_pretrained(
save_directory=model_args.export_dir,
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
if model_args.export_hub_model_id is not None:
model.push_to_hub(
model_args.export_hub_model_id,
token=model_args.hf_hub_token,
max_shard_size=f"{model_args.export_size}GB",
safe_serialization=(not model_args.export_legacy_format),
)
if finetuning_args.stage == "rm":
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
shutil.copy(
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
)
logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
shutil.copy(
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
)
logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
if processor is not None:
processor.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except Exception as e:
logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
ollama_modelfile = os.path.join(model_args.export_dir, "Modelfile")
with open(ollama_modelfile, "w", encoding="utf-8") as f:
f.write(template.get_ollama_modelfile(tokenizer))
logger.info_rank0(f"Ollama modelfile saved in {ollama_modelfile}")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available
from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir, load_config
from .locales import ALERTS
if TYPE_CHECKING:
from ..chat import BaseEngine
from .manager import Manager
if is_gradio_available():
import gradio as gr
def _escape_html(text: str) -> str:
r"""Escape HTML characters."""
return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
r"""Post-process the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
if thought_words[0] not in text:
return _escape_html(text) if escape_html else text
text = text.replace(thought_words[0], "")
result = text.split(thought_words[1], maxsplit=1)
if len(result) == 1:
summary = ALERTS["info_thinking"][lang]
thought, answer = text, ""
else:
summary = ALERTS["info_thought"][lang]
thought, answer = result
if escape_html:
thought, answer = _escape_html(thought), _escape_html(answer)
return (
f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
)
@contextmanager
def update_attr(obj: Any, name: str, value: Any):
old_value = getattr(obj, name, None)
setattr(obj, name, value)
yield
setattr(obj, name, old_value)
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional[BaseEngine] = None
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode and os.getenv("DEMO_MODEL") and os.getenv("DEMO_TEMPLATE"): # load demo model
model_name_or_path = os.getenv("DEMO_MODEL")
template = os.getenv("DEMO_TEMPLATE")
infer_backend = os.getenv("DEMO_BACKEND", "huggingface")
super().__init__(
dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
)
@property
def loaded(self) -> bool:
return self.engine is not None
def load_model(self, data) -> Generator[str, None, None]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
user_config = load_config()
error = ""
if self.loaded:
error = ALERTS["err_exists"][lang]
elif not model_name:
error = ALERTS["err_no_model"][lang]
elif not model_path:
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
try:
json.loads(get("infer.extra_args"))
except json.JSONDecodeError:
error = ALERTS["err_json_schema"][lang]
if error:
gr.Warning(error)
yield error
return
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
trust_remote_code=True,
)
args.update(json.loads(get("infer.extra_args")))
# checkpoints
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
# quantization
if get("top.quantization_bit") != "none":
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()
super().__init__(args)
yield ALERTS["info_loaded"][lang]
def unload_model(self, data) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_id("top.lang")]
if self.demo_mode:
gr.Warning(ALERTS["err_demo"][lang])
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.engine = None
torch_gc()
yield ALERTS["info_unloaded"][lang]
@staticmethod
def append(
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
role: str,
query: str,
escape_html: bool,
) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
r"""Add the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
"""
return (
chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
messages + [{"role": role, "content": query}],
"",
)
def stream(
self,
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
lang: str,
system: str,
tools: str,
image: Optional[Any],
video: Optional[Any],
audio: Optional[Any],
max_new_tokens: int,
top_p: float,
temperature: float,
skip_special_tokens: bool,
escape_html: bool,
enable_thinking: bool,
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r"""Generate output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
"""
with update_attr(self.engine.template, "enable_thinking", enable_thinking):
chatbot.append({"role": "assistant", "content": ""})
response = ""
for new_text in self.stream_chat(
messages,
system,
tools,
images=[image] if image else None,
videos=[video] if video else None,
audios=[audio] if audio else None,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
skip_special_tokens=skip_special_tokens,
):
response += new_text
if tools:
result = self.engine.template.extract_tool(response)
else:
result = response
if isinstance(result, list):
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
tool_calls = json.dumps(tool_calls, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
chatbot[-1] = {"role": "assistant", "content": bot_text}
yield chatbot, output_messages
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Union
from psutil import Process
from yaml import safe_dump, safe_load
from ..extras import logging
from ..extras.constants import (
DATA_CONFIG,
DEFAULT_TEMPLATE,
MULTIMODAL_SUPPORTED_MODELS,
SUPPORTED_MODELS,
TRAINING_ARGS,
DownloadSource,
)
from ..extras.misc import use_modelscope, use_openmind
logger = logging.get_logger(__name__)
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None:
r"""Abort the processes recursively in a bottom-up way."""
try:
children = Process(pid).children()
if children:
for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT)
except Exception:
pass
def get_save_dir(*paths: str) -> os.PathLike:
r"""Get the path to saved model checkpoints."""
if os.path.sep in paths[-1]:
logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1]
paths = (path.replace(" ", "").strip() for path in paths)
return os.path.join(DEFAULT_SAVE_DIR, *paths)
def _get_config_path() -> os.PathLike:
r"""Get the path to user config."""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> dict[str, Union[str, dict[str, Any]]]:
r"""Load user config if exists."""
try:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}
def save_config(
lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
) -> None:
r"""Save user config."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
if hub_name:
user_config["hub_name"] = hub_name
if model_name:
user_config["last_model"] = model_name
if model_name and model_path:
user_config["path_dict"][model_name] = model_path
with open(_get_config_path(), "w", encoding="utf-8") as f:
safe_dump(user_config, f)
def get_model_path(model_name: str) -> str:
r"""Get the model path according to the model name."""
user_config = load_config()
path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
if (
use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace hf path with ms path
model_path = path_dict.get(DownloadSource.MODELSCOPE)
if (
use_openmind()
and path_dict.get(DownloadSource.OPENMIND)
and model_path == path_dict.get(DownloadSource.DEFAULT)
): # replace hf path with om path
model_path = path_dict.get(DownloadSource.OPENMIND)
return model_path
def get_template(model_name: str) -> str:
r"""Get the template name if the model is a chat/distill/instruct model."""
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_time() -> str:
r"""Get current date and time."""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def is_multimodal(model_name: str) -> bool:
r"""Judge if the model is a vision language model."""
return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r"""Load dataset_info.json."""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
return json.load(f)
except Exception as err:
logger.warning_rank0(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
return {}
def load_args(config_path: str) -> Optional[dict[str, Any]]:
r"""Load the training configuration from config path."""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
r"""Save the training configuration to config path."""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r"""Remove args with NoneType or False or empty string value."""
no_skip_keys = [
"packing",
"enable_thinking",
"use_reentrant_gc",
"double_quantization",
"freeze_vision_tower",
"freeze_multi_modal_projector",
]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: dict[str, Any]) -> str:
r"""Generate CLI commands for previewing."""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
elif isinstance(v, list):
cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
else:
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
def save_cmd(args: dict[str, Any]) -> str:
r"""Save CLI commands to launch training."""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(_clean_cmd(args), f)
return os.path.join(output_dir, TRAINING_ARGS)
def load_eval_results(path: os.PathLike) -> str:
r"""Get scores after evaluation."""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return f"```json\n{result}\n```\n"
def calculate_pixels(pixels: str) -> int:
r"""Calculate the number of pixels from the expression."""
if "*" in pixels:
return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
else:
return int(pixels)
def create_ds_config() -> None:
r"""Create deepspeed config in the current directory."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
}
offload_config = {
"device": "cpu",
"pin_memory": True,
}
ds_config["zero_optimization"] = {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": False,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"round_robin_gradients": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"] = {
"stage": 3,
"overlap_comm": False,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
ds_config["zero_optimization"]["offload_param"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .chatbot import create_chat_box
from .eval import create_eval_tab
from .export import create_export_tab
from .footer import create_footer
from .infer import create_infer_tab
from .top import create_top
from .train import create_train_tab
__all__ = [
"create_chat_box",
"create_eval_tab",
"create_export_tab",
"create_footer",
"create_infer_tab",
"create_top",
"create_train_tab",
]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from typing import TYPE_CHECKING
from ...data import Role
from ...extras.packages import is_gradio_available
from ..locales import ALERTS
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def check_json_schema(text: str, lang: str) -> None:
r"""Check if the json schema is valid."""
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def create_chat_box(
engine: "Engine", visible: bool = False
) -> tuple["Component", "Component", dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
kwargs = {}
if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
kwargs["show_copy_button"] = True
if "resizable" in inspect.signature(gr.Chatbot.__init__).parameters:
kwargs["resizable"] = True
chatbot = gr.Chatbot(type="messages", **kwargs)
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
with gr.Column():
role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=3)
with gr.Column() as mm_box:
with gr.Tab("Image"):
image = gr.Image(type="pil")
with gr.Tab("Video"):
video = gr.Video()
with gr.Tab("Audio"):
audio = gr.Audio(type="filepath")
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
skip_special_tokens = gr.Checkbox(value=True)
escape_html = gr.Checkbox(value=True)
enable_thinking = gr.Checkbox(value=True)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
submit_btn.click(
engine.chatter.append,
[chatbot, messages, role, query, escape_html],
[chatbot, messages, query],
).then(
engine.chatter.stream,
[
chatbot,
messages,
lang,
system,
tools,
image,
video,
audio,
max_new_tokens,
top_p,
temperature,
skip_special_tokens,
escape_html,
enable_thinking,
],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
return (
chatbot,
messages,
dict(
chat_box=chat_box,
role=role,
system=system,
tools=tools,
mm_box=mm_box,
image=image,
video=video,
audio=audio,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
skip_special_tokens=skip_special_tokens,
escape_html=escape_html,
enable_thinking=enable_thinking,
clear_btn=clear_btn,
),
)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import TYPE_CHECKING, Any
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
PAGE_SIZE = 2
def prev_page(page_index: int) -> int:
return page_index - 1 if page_index > 0 else page_index
def next_page(page_index: int, total_num: int) -> int:
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""Check if the dataset is a local dataset."""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.Button(interactive=False)
if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
return gr.Button(interactive=False)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path) or (os.path.isdir(data_path) and os.listdir(data_path)):
return gr.Button(interactive=True)
else:
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> list[Any]:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):
return [json.loads(line) for line in f]
else:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
r"""Get the preview samples from the dataset."""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
if os.path.isfile(data_path):
data = _load_data_file(data_path)
else:
data = []
for file_name in os.listdir(data_path):
data.extend(_load_data_file(os.path.join(data_path, file_name)))
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(value=0, interactive=False, precision=0)
page_index = gr.Number(value=0, interactive=False, precision=0)
with gr.Row():
prev_btn = gr.Button()
next_btn = gr.Button()
close_btn = gr.Button()
with gr.Row():
preview_samples = gr.JSON()
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
lambda: 0, outputs=[page_index], queue=False
)
data_preview_btn.click(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
)
close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
return dict(
data_preview_btn=data_preview_btn,
preview_count=preview_count,
page_index=page_index,
prev_btn=prev_btn,
next_btn=next_btn,
close_btn=close_btn,
preview_samples=preview_samples,
)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR
from ..control import list_datasets
from .data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
with gr.Row():
cutoff_len = gr.Slider(minimum=4, maximum=131072, value=1024, step=1)
max_samples = gr.Textbox(value="100000")
batch_size = gr.Slider(minimum=1, maximum=1024, value=2, step=1)
predict = gr.Checkbox(value=True)
input_elems.update({cutoff_len, max_samples, batch_size, predict})
elem_dict.update(dict(cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict))
with gr.Row():
max_new_tokens = gr.Slider(minimum=8, maximum=4096, value=512, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
output_dir = gr.Textbox()
input_elems.update({max_new_tokens, top_p, temperature, output_dir})
elem_dict.update(dict(max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir))
with gr.Row():
cmd_preview_btn = gr.Button()
start_btn = gr.Button(variant="primary")
stop_btn = gr.Button(variant="stop")
with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False)
progress_bar = gr.Slider(visible=False, interactive=False)
with gr.Row():
output_box = gr.Markdown()
elem_dict.update(
dict(
cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
resume_btn=resume_btn,
progress_bar=progress_bar,
output_box=output_box,
)
)
output_elems = [output_box, progress_bar]
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
return elem_dict
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
from ...extras.packages import is_gradio_available
from ...train.tuner import export_model
from ..common import get_save_dir, load_config
from ..locales import ALERTS
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING:
from gradio.components import Component
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
def save_model(
lang: str,
model_name: str,
model_path: str,
finetuning_type: str,
checkpoint_path: Union[str, list[str]],
template: str,
export_size: int,
export_quantization_bit: str,
export_quantization_dataset: str,
export_device: str,
export_legacy_format: bool,
export_dir: str,
export_hub_model_id: str,
extra_args: str,
) -> Generator[str, None, None]:
user_config = load_config()
error = ""
if not model_name:
error = ALERTS["err_no_model"][lang]
elif not model_path:
error = ALERTS["err_no_path"][lang]
elif not export_dir:
error = ALERTS["err_no_export_dir"][lang]
elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
error = ALERTS["err_no_dataset"][lang]
elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
error = ALERTS["err_no_adapter"][lang]
elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
error = ALERTS["err_gptq_lora"][lang]
try:
json.loads(extra_args)
except json.JSONDecodeError:
error = ALERTS["err_json_schema"][lang]
if error:
gr.Warning(error)
yield error
return
args = dict(
model_name_or_path=model_path,
cache_dir=user_config.get("cache_dir", None),
finetuning_type=finetuning_type,
template=template,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=export_size,
export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
trust_remote_code=True,
)
args.update(json.loads(extra_args))
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
yield ALERTS["info_exporting"][lang]
export_model(args)
torch_gc()
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
export_quantization_dataset = gr.Textbox(value="data/c4_demo.jsonl")
export_device = gr.Radio(choices=["cpu", "auto"], value="cpu")
export_legacy_format = gr.Checkbox()
with gr.Row():
export_dir = gr.Textbox()
export_hub_model_id = gr.Textbox()
extra_args = gr.Textbox(value="{}")
checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
export_btn.click(
save_model,
[
engine.manager.get_elem_by_id("top.lang"),
engine.manager.get_elem_by_id("top.model_name"),
engine.manager.get_elem_by_id("top.model_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.checkpoint_path"),
engine.manager.get_elem_by_id("top.template"),
export_size,
export_quantization_bit,
export_quantization_dataset,
export_device,
export_legacy_format,
export_dir,
export_hub_model_id,
extra_args,
],
[info_box],
)
return dict(
export_size=export_size,
export_quantization_bit=export_quantization_bit,
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id,
extra_args=extra_args,
export_btn=export_btn,
info_box=info_box,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment