Commit 5ed76316 authored by 雍大凯's avatar 雍大凯
Browse files

models add

parent b2379236
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional
@dataclass
class FreezeArguments:
r"""Arguments pertaining to the freeze (partial-parameter) training."""
freeze_trainable_layers: int = field(
default=2,
metadata={
"help": (
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
"Positive numbers mean the last n layers are set as trainable, "
"negative numbers mean the first n layers are set as trainable."
)
},
)
freeze_trainable_modules: str = field(
default="all",
metadata={
"help": (
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the available modules."
)
},
)
freeze_extra_modules: Optional[str] = field(
default=None,
metadata={
"help": (
"Name(s) of modules apart from hidden layers to be set as trainable "
"for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules."
)
},
)
@dataclass
class LoraArguments:
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
default=None,
metadata={
"help": (
"Name(s) of modules apart from LoRA layers to be set as trainable "
"and saved in the final checkpoint. "
"Use commas to separate multiple modules."
)
},
)
lora_alpha: Optional[int] = field(
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
)
lora_dropout: float = field(
default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
)
lora_rank: int = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
)
lora_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of target modules to apply LoRA. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
)
loraplus_lr_embedding: float = field(
default=1e-6,
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
)
use_rslora: bool = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
)
use_dora: bool = field(
default=False,
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
)
pissa_init: bool = field(
default=False,
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
)
pissa_iter: int = field(
default=16,
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
)
pissa_convert: bool = field(
default=False,
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
)
create_new_adapter: bool = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
)
@dataclass
class RLHFArguments:
r"""Arguments pertaining to the PPO, DPO and KTO training."""
pref_beta: float = field(
default=0.1,
metadata={"help": "The beta parameter in the preference loss."},
)
pref_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
)
dpo_label_smoothing: float = field(
default=0.0,
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
simpo_gamma: float = field(
default=0.5,
metadata={"help": "The target reward margin term in SimPO loss."},
)
ppo_buffer_size: int = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
)
ppo_epochs: int = field(
default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
)
ppo_score_norm: bool = field(
default=False,
metadata={"help": "Use score normalization in PPO training."},
)
ppo_target: float = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
)
ppo_whiten_rewards: bool = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
)
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
)
ref_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reference model."},
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."},
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reward model used for the PPO training."},
)
reward_model_adapters: Optional[str] = field(
default=None,
metadata={"help": "Path to the adapters of the reward model."},
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."},
)
reward_model_type: Literal["lora", "full", "api"] = field(
default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
ld_alpha: Optional[float] = field(
default=None,
metadata={
"help": (
"Alpha parameter from the LD-DPO paper, which controls the weighting of"
" the verbose token log-probabilities in responses."
)
},
)
@dataclass
class GaloreArguments:
r"""Arguments pertaining to the GaLore algorithm."""
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
)
galore_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
galore_rank: int = field(
default=16,
metadata={"help": "The rank of GaLore gradients."},
)
galore_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the GaLore projection."},
)
galore_scale: float = field(
default=2.0,
metadata={"help": "GaLore scaling coefficient."},
)
galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
default="std",
metadata={"help": "Type of GaLore projection."},
)
galore_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
@dataclass
class ApolloArguments:
r"""Arguments pertaining to the APOLLO algorithm."""
use_apollo: bool = field(
default=False,
metadata={"help": "Whether or not to use the APOLLO optimizer."},
)
apollo_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
apollo_rank: int = field(
default=16,
metadata={"help": "The rank of APOLLO gradients."},
)
apollo_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the APOLLO projection."},
)
apollo_scale: float = field(
default=32.0,
metadata={"help": "APOLLO scaling coefficient."},
)
apollo_proj: Literal["svd", "random"] = field(
default="random",
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
)
apollo_proj_type: Literal["std", "right", "left"] = field(
default="std",
metadata={"help": "Type of APOLLO projection."},
)
apollo_scale_type: Literal["channel", "tensor"] = field(
default="channel",
metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
)
apollo_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
apollo_scale_front: bool = field(
default=False,
metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
)
@dataclass
class BAdamArgument:
r"""Arguments pertaining to the BAdam optimizer."""
use_badam: bool = field(
default=False,
metadata={"help": "Whether or not to use the BAdam optimizer."},
)
badam_mode: Literal["layer", "ratio"] = field(
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_switch_interval: Optional[int] = field(
default=50,
metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
},
)
badam_update_ratio: float = field(
default=0.05,
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
)
badam_mask_mode: Literal["adjacent", "scatter"] = field(
default="adjacent",
metadata={
"help": (
"The mode of the mask for BAdam optimizer. "
"`adjacent` means that the trainable parameters are adjacent to each other, "
"`scatter` means that trainable parameters are randomly choosed from the weight."
)
},
)
badam_verbose: int = field(
default=0,
metadata={
"help": (
"The verbosity level of BAdam optimizer. "
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
)
},
)
@dataclass
class SwanLabArguments:
use_swanlab: bool = field(
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: Optional[str] = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: Optional[str] = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: Optional[str] = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_mode: Literal["cloud", "local"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: Optional[str] = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
swanlab_logdir: Optional[str] = field(
default=None,
metadata={"help": "The log directory for SwanLab."},
)
swanlab_lark_webhook_url: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
)
swanlab_lark_secret: Optional[str] = field(
default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
@dataclass
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
pure_bf16: bool = field(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
finetuning_type: Literal["lora", "freeze", "full"] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
use_llama_pro: bool = field(
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
use_adam_mini: bool = field(
default=False,
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
)
use_muon: bool = field(
default=False,
metadata={"help": "Whether or not to use the Muon optimizer."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
)
freeze_multi_modal_projector: bool = field(
default=True,
metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
)
freeze_language_model: bool = field(
default=False,
metadata={"help": "Whether or not to freeze the language model in MLLM training."},
)
compute_accuracy: bool = field(
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
disable_shuffling: bool = field(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
early_stopping_steps: Optional[int] = field(
default=None,
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
)
include_effective_tokens_per_second: bool = field(
default=False,
metadata={"help": "Whether or not to compute effective tokens per second."},
)
def __post_init__(self):
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
if int(self.use_galore) + int(self.use_apollo) + (self.use_badam) > 1:
raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")
if self.finetuning_type != "lora":
if self.loraplus_lr_ratio is not None:
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
if self.use_rslora:
raise ValueError("`use_rslora` is only valid for LoRA training.")
if self.use_dora:
raise ValueError("`use_dora` is only valid for LoRA training.")
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
r"""Arguments pertaining to specify the decoding parameters."""
do_sample: bool = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
)
temperature: float = field(
default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."},
)
top_p: float = field(
default=0.7,
metadata={
"help": (
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
)
},
)
top_k: int = field(
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
)
num_beams: int = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
)
max_length: int = field(
default=1024,
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
)
max_new_tokens: int = field(
default=1024,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
)
repetition_penalty: float = field(
default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
)
length_penalty: float = field(
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
if obey_generation_config:
generation_config = GenerationConfig()
for key in list(args.keys()):
if not hasattr(generation_config, key):
args.pop(key)
return args
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
@dataclass
class BaseModelArguments:
r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to the adapter weight or identifier from huggingface.co/models. "
"Use commas to separate multiple adapters."
)
},
)
adapter_folder: Optional[str] = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
)
split_special_tokens: bool = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
add_tokens: Optional[str] = field(
default=None,
metadata={
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
},
)
add_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
low_cpu_mem_usage: bool = field(
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[RopeScaling] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.AUTO,
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing (no need to install unsloth)."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
use_reentrant_gc: bool = field(
default=True,
metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: EngineName = field(
default=EngineName.HF,
metadata={"help": "Backend engine used at inference."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
)
use_cache: bool = field(
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
)
def __post_init__(self):
if self.model_name_or_path is None:
raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.add_tokens is not None: # support multiple tokens
self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
if self.add_special_tokens is not None: # support multiple special tokens
self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
@dataclass
class QuantizationArguments:
r"""Arguments pertaining to the quantization method."""
quantization_method: QuantizationMethod = field(
default=QuantizationMethod.BNB,
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
@dataclass
class ProcessorArguments:
r"""Arguments pertaining to the image processor."""
image_max_pixels: int = field(
default=768 * 768,
metadata={"help": "The maximum number of pixels of image inputs."},
)
image_min_pixels: int = field(
default=32 * 32,
metadata={"help": "The minimum number of pixels of image inputs."},
)
image_do_pan_and_scan: bool = field(
default=False,
metadata={"help": "Use pan and scan to process image for gemma3."},
)
crop_to_patches: bool = field(
default=False,
metadata={"help": "Whether to crop the image to patches for internvl."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
)
video_min_pixels: int = field(
default=16 * 16,
metadata={"help": "The minimum number of pixels of video inputs."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},
)
video_maxlen: int = field(
default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
audio_sampling_rate: int = field(
default=16000,
metadata={"help": "The sampling rate of audio inputs."},
)
def __post_init__(self):
if self.image_max_pixels < self.image_min_pixels:
raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
if self.video_max_pixels < self.video_min_pixels:
raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
@dataclass
class ExportArguments:
r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=5,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
def __post_init__(self):
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
@dataclass
class VllmArguments:
r"""Arguments pertaining to the vLLM worker."""
vllm_maxlen: int = field(
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.7,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
vllm_config: Optional[Union[dict, str]] = field(
default=None,
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
)
def __post_init__(self):
if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
@dataclass
class SGLangArguments:
r"""Arguments pertaining to the SGLang worker."""
sglang_maxlen: int = field(
default=4096,
metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
)
sglang_mem_fraction: float = field(
default=0.7,
metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
)
sglang_tp_size: int = field(
default=-1,
metadata={"help": "Tensor parallel size for the SGLang engine."},
)
sglang_config: Optional[Union[dict, str]] = field(
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
sglang_lora_backend: Literal["triton", "flashinfer"] = field(
default="triton",
metadata={
"help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
},
)
def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
@dataclass
class ModelArguments(
SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
):
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The class on the most right will be displayed first.
"""
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
)
block_diag_attn: bool = field(
default=False,
init=False,
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
)
def __post_init__(self):
BaseModelArguments.__post_init__(self)
ProcessorArguments.__post_init__(self)
ExportArguments.__post_init__(self)
VllmArguments.__post_init__(self)
SGLangArguments.__post_init__(self)
@classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
init_args, lazy_args = {}, {}
for attr in fields(source):
if attr.init:
init_args[attr.name] = getattr(source, attr.name)
else:
lazy_args[attr.name] = getattr(source, attr.name)
init_args.update(kwargs)
result = cls(**init_args)
for name, value in lazy_args.items():
setattr(result, name, value)
return result
def to_dict(self) -> dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from pathlib import Path
from typing import Any, Optional, Union
import torch
import transformers
from omegaconf import OmegaConf
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .training_args import RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args
if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
elif sys.argv[1].endswith(".json"):
override_config = OmegaConf.from_cli(sys.argv[2:])
dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
else:
return sys.argv[1:]
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
if unknown_args and not allow_extra_keys:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
return tuple(parsed_args)
def _set_transformers_logging() -> None:
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
def _set_env_vars() -> None:
if is_torch_npu_available():
# avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
# avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def _verify_model_args(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
) -> None:
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if finetuning_args.pissa_init:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
check_version("unsloth", mandatory=True)
if model_args.enable_liger_kernel:
check_version("liger-kernel", mandatory=True)
if model_args.mixture_of_depths is not None:
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.9.1")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.5")
check_version("sglang", mandatory=True)
if finetuning_args.use_galore:
check_version("galore_torch", mandatory=True)
if finetuning_args.use_apollo:
check_version("apollo_torch", mandatory=True)
if finetuning_args.use_badam:
check_version("badam>=1.2.1", mandatory=True)
if finetuning_args.use_adam_mini:
check_version("adam-mini", mandatory=True)
if finetuning_args.use_swanlab:
check_version("swanlab", mandatory=True)
if finetuning_args.plot_loss:
check_version("matplotlib", mandatory=True)
if training_args is not None:
if training_args.deepspeed:
# pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
if training_args.predict_with_generate:
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "sft":
if training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.")
if data_args.train_on_prompt or data_args.mask_history:
raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo":
if not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and data_args.dataset is None:
raise ValueError("Please specify dataset for training.")
if (training_args.do_eval or training_args.do_predict) and (
data_args.eval_dataset is None and data_args.val_size < 1e-6
):
raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled():
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.use_galore and finetuning_args.galore_layerwise:
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
raise ValueError("Distributed training does not support layer-wise APOLLO.")
if finetuning_args.use_badam:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
elif not is_deepspeed_zero3_enabled():
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning_rank0(
"Remember to add embedding layers to `additional_target` to make the added tokens trainable."
)
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning_rank0("We recommend enable mixed precision training.")
if (
training_args.do_train
and (finetuning_args.use_galore or finetuning_args.use_apollo)
and not finetuning_args.pure_bf16
):
logger.warning_rank0(
"Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal dataset
if finetuning_args.finetuning_type == "lora":
# https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
training_args.label_names = training_args.label_names or ["labels"]
if "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
training_args.report_to.remove("swanlab")
if (
training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.info_rank0("Set `ddp_find_unused_parameters` to False in DDP training since LoRA is enabled.")
training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None:
logger.warning_rank0("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and any(
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning_rank0(
f"Add {training_args.resume_from_checkpoint} to `adapter_name_or_path` to resume training from checkpoint."
)
# Post-process model arguments
if training_args.bf16 or finetuning_args.pure_bf16:
model_args.compute_dtype = torch.bfloat16
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
model_args.block_diag_attn = data_args.neat_packing
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
# Log on each process the small summary
logger.info(
f"Process rank: {training_args.process_index}, "
f"world size: {training_args.world_size}, device: {training_args.device}, "
f"distributed training: {training_args.parallel_mode == ParallelMode.DISTRIBUTED}, "
f"compute dtype: {str(model_args.compute_dtype)}"
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
# Setup logging
_set_transformers_logging()
# Check arguments
if model_args.infer_backend == "vllm":
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
# Post-process model arguments
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
if data_args.cutoff_len != DataArguments().cutoff_len: # override cutoff_len if it is not default
model_args.model_max_length = data_args.cutoff_len
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
# Setup logging
_set_transformers_logging()
# Check arguments
if model_args.infer_backend != EngineName.HF:
raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field(
default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
)
ray_storage_path: str = field(
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: Union[dict, str] = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
ray_init_kwargs: Optional[dict] = field(
default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
)
def __post_init__(self):
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
if self.ray_storage_filesystem is not None:
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
)
import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3":
self.ray_storage_filesystem = fs.S3FileSystem()
elif self.ray_storage_filesystem == "gs" or self.ray_storage_filesystem == "gcs":
self.ray_storage_filesystem = fs.GcsFileSystem()
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""Arguments pertaining to the trainer."""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp # use absolute import
def launch():
run_exp()
if __name__ == "__main__":
launch()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params
__all__ = [
"QuantizationMethod",
"find_all_linear_modules",
"load_config",
"load_model",
"load_tokenizer",
"load_valuehead_params",
]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import COMPOSITE_MODELS, get_forbidden_modules, patch_target_modules
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
logger = logging.get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return
logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
def _setup_freeze_tuning(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return
logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
config = model.config
num_layers = (
getattr(config, "num_hidden_layers", None)
or getattr(config, "num_layers", None)
or getattr(config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
f"`num_layers` {num_layers} should be "
f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}."
)
stride = num_layers // finetuning_args.freeze_trainable_layers
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2]) # remove weight/bias
trainable_layers = []
for module_name in finetuning_args.freeze_trainable_modules:
if module_name != "all" and module_name not in hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
if finetuning_args.freeze_extra_modules:
for module_name in finetuning_args.freeze_extra_modules:
if module_name not in non_hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
)
trainable_layers.append(module_name)
model_type = getattr(model.config, "model_type", None)
if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
if is_trainable:
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
else:
adapter_to_merge = model_args.adapter_name_or_path
init_kwargs = {
"subfolder": model_args.adapter_folder,
"offload_folder": model_args.offload_folder,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
target_modules = patch_target_modules(model, finetuning_args, target_modules)
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BNB
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if is_trainable and cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
return model
def init_adapter(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""Initialize the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
# cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 (zero3 already in fp32)
cast_trainable_params_to_fp32 = False
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and is_deepspeed_zero3_enabled():
logger.info_rank0("DeepSpeed ZeRO3 detected, remaining trainable params in float32.")
else:
logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
else:
raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")
return model
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForTextToWaveform,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
)
from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
logger = logging.get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> dict[str, Any]:
r"""Get arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return {
"trust_remote_code": model_args.trust_remote_code,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""Load pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try another one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=not model_args.use_fast_tokenizer,
padding_side="right",
**init_kwargs,
)
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
**init_kwargs,
)
except ValueError: # try another one
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
use_fast=not model_args.use_fast_tokenizer,
**init_kwargs,
)
except Exception as e:
logger.info_rank0(f"Failed to load processor: {e}.")
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if processor is not None and "Processor" not in processor.__class__.__name__:
logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
processor = None
if processor is not None:
patch_processor(processor, tokenizer, model_args)
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""Load model config."""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""Load pretrained model."""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args, finetuning_args)
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
load_class = AutoModelForTextToWaveform
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
model = model.thinker # use part of Omni model
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = (
f"trainable params: {trainable_params:,} || "
f"all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.4f}"
)
else:
param_stats = f"all params: {all_param:,}"
logger.info_rank0(param_stats)
if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
for name, param in model.named_parameters():
print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")
return model
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from ...extras import logging
from ...extras.constants import AttentionFunction
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
if model_args.flash_attn != AttentionFunction.FA2:
logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = AttentionFunction.FA2
else:
logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = AttentionFunction.DISABLED
elif model_args.flash_attn == AttentionFunction.SDPA:
logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == AttentionFunction.AUTO:
return
elif model_args.flash_attn == AttentionFunction.DISABLED:
requested_attn_implementation = "eager"
elif model_args.flash_attn == AttentionFunction.SDPA:
if not is_torch_sdpa_available():
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == AttentionFunction.FA2:
if not is_flash_attn_2_available():
logger.warning_rank0("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
elif getattr(config, "model_type", None) == "kimi_vl":
setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info_rank0("Using torch SDPA for faster training and inference.")
else:
logger.info_rank0("Using vanilla attention implementation.")
# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library,
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
# and the Unsloth library.
# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""Saves VRAM by smartly offloading to RAM."""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
outputs = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return outputs
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
outputs = ctx.forward_function(hidden_states, *ctx.args)
output = outputs[0] if isinstance(outputs, tuple) else outputs
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
return UnslothGradientCheckpointing.apply
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""Only applies gradient checkpointing to trainable layers."""
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
if isinstance(func, partial):
module: torch.nn.Module = func.func.__self__
else:
module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):
has_grad = True
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
break # assume the first tensor is always the hidden states
if has_grad:
return gradient_checkpointing_func(func, *args, **kwargs)
else:
return func(*args, **kwargs)
return custom_gradient_checkpointing_func
def _gradient_checkpointing_enable(
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
if use_unsloth_gc:
gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
else:
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""Prepare the model before training.
Include:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32.
"""
if model_args.upcast_layernorm:
logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning_rank0("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
gradient_checkpointing_enable = partial(
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
)
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras import logging
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""Resize token embeddings."""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras import logging
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable:
setattr(config, "use_cache", model_args.use_cache)
if hasattr(config, "text_config"):
setattr(config.text_config, "use_cache", model_args.use_cache)
if model_args.use_cache:
logger.info_rank0("KV cache is enabled for faster generation.")
else:
logger.info_rank0("KV cache is disabled.")
else:
setattr(config, "use_cache", False)
if hasattr(config, "text_config"):
setattr(config.text_config, "use_cache", False)
logger.info_rank0("KV cache is disabled during training.")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING
from ...extras import logging
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
is_trainable: bool,
require_logits: bool,
) -> None:
if not is_trainable or not model_args.enable_liger_kernel:
return
model_type = getattr(config, "model_type", None)
if model_type == "gemma":
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
elif model_type == "gemma2":
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
elif model_type == "gemma3":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "llava":
from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "olmo2":
from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
elif model_type == "qwen3":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
elif model_type == "qwen3_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False, "cross_entropy": True}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info_rank0("Liger kernel has been applied to the model.")
# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
#
# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# This code is also inspired by the original LongLoRA implementation.
# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
import transformers
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if not is_transformers_version_greater_than("4.48.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
transformers_logger = transformers.utils.logging.get_logger(__name__)
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_states.size(1),
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
)
return llama_attention_forward(
self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
if query_states.device.type == "cuda" and causal_mask is not None: # avoid pytorch bug
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _apply_llama_patch() -> None:
check_version("transformers>=4.45.0,<4.48.0", mandatory=True)
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
logger = logging.get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning_rank0("Current model does not support shift short attention.")
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras import logging
from .visual import COMPOSITE_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
if model_type == "chatglm":
forbidden_modules.add("output_layer")
elif model_type == "internlm2":
forbidden_modules.add("output")
if model_type in COMPOSITE_MODELS:
forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
if freeze_vision_tower and model_type in COMPOSITE_MODELS:
forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
module_names = set()
for name, module in model.named_modules():
if any(forbidden_module in name for forbidden_module in forbidden_modules):
continue
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1])
logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
r"""Find the modules in the expanded blocks to apply lora."""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
if num_layers % num_layer_trainable != 0:
raise ValueError(
f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
trainable_layer in name for trainable_layer in trainable_layers
):
module_names.append(name)
logger.info_rank0("Apply lora to layers: {}.".format(",".join(map(str, trainable_layer_ids))))
return module_names
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
from MoD import AutoMoDModelForCausalLM
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
def convert_pretrained_model_to_mod(
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
return model
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Union
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.misc import check_version
if TYPE_CHECKING:
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, leaf_modules)
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
if not is_deepspeed_zero3_enabled():
return
model_type = getattr(model.config, "model_type", None)
if model_type == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
_set_z3_leaf_modules(model, [DbrxFFN])
if model_type == "deepseek_v2":
# deepseek v2 uses custom code
_set_z3_leaf_modules(model, ["DeepseekV2MoE"])
if model_type == "deepseek_v3" or model_type == "kimi_vl":
# deepseek v3 and kimi vl use custom code
_set_z3_leaf_modules(model, ["DeepseekV3MoE"])
if model_type == "granitemoe":
from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
_set_z3_leaf_modules(model, [GraniteMoeMoE])
if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
_set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if model_type == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if model_type == "llama4":
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
_set_z3_leaf_modules(model, [Llama4TextMoe])
if model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
_set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])
if model_type == "phimoe":
from transformers.models.phimoe.modeling_phimoe import PhimoeSparseMoeBlock
_set_z3_leaf_modules(model, [PhimoeSparseMoeBlock])
if model_type == "qwen2_moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
if model_type == "qwen3_moe":
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:
return
model_type = getattr(config, "model_type", None)
if model_type in [
"dbrx",
"granitemoe",
"jamba",
"jetmoe",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
# Copyright 2025 Musab Gultekin and the LlamaFactory team.
#
# This code is based on the Musab Gultekin's functionary library.
# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2023 Musab Gultekin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from ...extras import logging
if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""Get the sequnce lengths in the current batch.
e.g.
```python
# input
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
# output
[2, 3, 1, 2, 3]
```
"""
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item()
counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
counts = counts.flatten()
seqlens = counts[counts.nonzero().squeeze(dim=-1)]
return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
r"""Prepare the indices and seqlens for flash attn varlen function.
Returns:
indices: indices of non-masked tokens from the flattened sequence.
cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
max_seqlen_in_batch: the largest seqlen in the current batch.
e.g.
```python
# input
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
# output
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
[0, 2, 5, 6, 8, 11]
3
```
"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment