Commit ca625f43 authored by shihm

uodata

parent 7164651d
@@ -16,22 +16,22 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Literal


@dataclass
class DataArguments:
    r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

    template: str | None = field(
        default=None,
        metadata={"help": "Which template to use for constructing prompts in training and inference."},
    )
    dataset: str | None = field(
        default=None,
        metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
    )
    eval_dataset: str | None = field(
        default=None,
        metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
    )
@@ -39,7 +39,7 @@ class DataArguments:
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    media_dir: str | None = field(
        default=None,
        metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
    )
@@ -67,7 +67,7 @@ class DataArguments:
        default="concat",
        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
    )
    interleave_probs: str | None = field(
        default=None,
        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
    )
@@ -79,15 +79,15 @@ class DataArguments:
        default=1000,
        metadata={"help": "The number of examples in one group in pre-processing."},
    )
    preprocessing_num_workers: int | None = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    max_samples: int | None = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
    )
    eval_num_beams: int | None = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
    )
@@ -103,7 +103,7 @@ class DataArguments:
        default=False,
        metadata={"help": "Whether or not to evaluate on each dataset separately."},
    )
    packing: bool | None = field(
        default=None,
        metadata={"help": "Enable sequence packing in training. It will be automatically enabled in pre-training."},
    )
@@ -111,19 +111,19 @@ class DataArguments:
        default=False,
        metadata={"help": "Enable sequence packing without cross-attention."},
    )
    tool_format: str | None = field(
        default=None,
        metadata={"help": "Tool format to use for constructing function calling examples."},
    )
    default_system: str | None = field(
        default=None,
        metadata={"help": "Override the default system message in the template."},
    )
    enable_thinking: bool | None = field(
        default=True,
        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
    )
    tokenized_path: str | None = field(
        default=None,
        metadata={
            "help": (
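For orientation, a minimal sketch of how the dataset-mixing options above might be combined when driving training programmatically. The dataset names, template, and the `interleave_under` strategy value are placeholders/assumptions for illustration, not values taken from this commit.

data_args_example = {
    "template": "qwen",                    # placeholder template name
    "dataset": "alpaca_en_demo,identity",  # comma-separated dataset names
    "dataset_dir": "data",
    "mix_strategy": "interleave_under",    # assumed alternative to the default "concat"
    "interleave_probs": "0.7,0.3",         # one sampling probability per dataset
    "max_samples": 1000,
    "packing": False,
}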
@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import Literal

from datasets import DownloadMode

@@ -46,7 +46,7 @@ class EvaluationArguments:
        default=5,
        metadata={"help": "Number of exemplars for few-shot learning."},
    )
    save_dir: str | None = field(
        default=None,
        metadata={"help": "Path to save the evaluation results."},
    )
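A small illustrative dict for the evaluation entry point. Only `n_shot` and `save_dir` appear in the hunk above; the model id and the `task` key are assumptions added for illustration, and the task name must match an evaluation task available in the project.

eval_args_example = {
    "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",  # placeholder model id
    "template": "qwen",
    "task": "mmlu_test",  # assumed task name
    "n_shot": 5,
    "save_dir": "saves/eval_results",
}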
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Literal


@dataclass
@@ -40,7 +40,7 @@ class FreezeArguments:
            )
        },
    )
    freeze_extra_modules: str | None = field(
        default=None,
        metadata={
            "help": (
@@ -56,7 +56,7 @@ class FreezeArguments:
class LoraArguments:
    r"""Arguments pertaining to the LoRA training."""

    additional_target: str | None = field(
        default=None,
        metadata={
            "help": (
@@ -66,7 +66,7 @@ class LoraArguments:
            )
        },
    )
    lora_alpha: int | None = field(
        default=None,
        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
    )
@@ -88,7 +88,7 @@ class LoraArguments:
            )
        },
    )
    loraplus_lr_ratio: float | None = field(
        default=None,
        metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
    )
@@ -122,6 +122,48 @@ class LoraArguments:
    )


@dataclass
class OFTArguments:
    r"""Arguments pertaining to the OFT training."""

    additional_target: str | None = field(
        default=None,
        metadata={
            "help": (
                "Name(s) of modules apart from LoRA layers to be set as trainable "
                "and saved in the final checkpoint. "
                "Use commas to separate multiple modules."
            )
        },
    )
    module_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate for the OFT fine-tuning."},
    )
    oft_rank: int = field(
        default=0,
        metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
    )
    oft_block_size: int = field(
        default=32,
        metadata={"help": "The block size for OFT fine-tuning."},
    )
    oft_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of target modules to apply OFT. "
                "Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
    )
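As a rough sketch, an OFT run can be configured much like a LoRA run by switching `finetuning_type`. The model, dataset, and output path are placeholders, and the note that `oft_rank` and `oft_block_size` are typically mutually exclusive follows PEFT's OFT configuration rather than anything stated in this commit.

oft_train_example = {
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "oft",
    "oft_target": "all",          # or e.g. "q_proj,v_proj"
    "oft_block_size": 32,         # keep oft_rank at 0 when a block size is given (assumption)
    "module_dropout": 0.0,
    "dataset": "alpaca_en_demo",  # placeholder dataset
    "template": "llama3",
    "output_dir": "saves/llama3-8b-oft",  # output_dir comes from the standard training arguments
}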
@dataclass
class RLHFArguments:
    r"""Arguments pertaining to the PPO, DPO and KTO training."""

@@ -134,6 +176,10 @@ class RLHFArguments:
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
    )
    pref_bco_weight: float = field(
        default=0.0,
        metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
    )
    pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
        default="sigmoid",
        metadata={"help": "The type of DPO loss to use."},
    )
@@ -174,27 +220,27 @@ class RLHFArguments:
        default=False,
        metadata={"help": "Whiten the rewards before computing advantages in PPO training."},
    )
    ref_model: str | None = field(
        default=None,
        metadata={"help": "Path to the reference model used for the PPO or DPO training."},
    )
    ref_model_adapters: str | None = field(
        default=None,
        metadata={"help": "Path to the adapters of the reference model."},
    )
    ref_model_quantization_bit: int | None = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reference model."},
    )
    reward_model: str | None = field(
        default=None,
        metadata={"help": "Path to the reward model used for the PPO training."},
    )
    reward_model_adapters: str | None = field(
        default=None,
        metadata={"help": "Path to the adapters of the reward model."},
    )
    reward_model_quantization_bit: int | None = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reward model."},
    )
@@ -202,7 +248,7 @@ class RLHFArguments:
        default="lora",
        metadata={"help": "The type of the reward model in PPO training. LoRA model only supports LoRA training."},
    )
    ld_alpha: float | None = field(
        default=None,
        metadata={
            "help": (
@@ -315,15 +361,15 @@ class BAdamArgument:
        default="layer",
        metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
    )
    badam_start_block: int | None = field(
        default=None,
        metadata={"help": "The starting block index for layer-wise BAdam."},
    )
    badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
        default="ascending",
        metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
    )
    badam_switch_interval: int | None = field(
        default=50,
        metadata={
            "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
@@ -360,15 +406,15 @@ class SwanLabArguments:
        default=False,
        metadata={"help": "Whether or not to use SwanLab (an experiment tracking and visualization tool)."},
    )
    swanlab_project: str | None = field(
        default="llamafactory",
        metadata={"help": "The project name in SwanLab."},
    )
    swanlab_workspace: str | None = field(
        default=None,
        metadata={"help": "The workspace name in SwanLab."},
    )
    swanlab_run_name: str | None = field(
        default=None,
        metadata={"help": "The experiment name in SwanLab."},
    )
@@ -376,19 +422,19 @@ class SwanLabArguments:
        default="cloud",
        metadata={"help": "The mode of SwanLab."},
    )
    swanlab_api_key: str | None = field(
        default=None,
        metadata={"help": "The API key for SwanLab."},
    )
    swanlab_logdir: str | None = field(
        default=None,
        metadata={"help": "The log directory for SwanLab."},
    )
    swanlab_lark_webhook_url: str | None = field(
        default=None,
        metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
    )
    swanlab_lark_secret: str | None = field(
        default=None,
        metadata={"help": "The Lark(飞书) secret for SwanLab."},
    )
@@ -396,7 +442,14 @@ class SwanLabArguments:

@dataclass
class FinetuningArguments(
    SwanLabArguments,
    BAdamArgument,
    ApolloArguments,
    GaloreArguments,
    RLHFArguments,
    LoraArguments,
    OFTArguments,
    FreezeArguments,
):
    r"""Arguments pertaining to which techniques we are going to fine-tune with."""
@@ -408,7 +461,7 @@ class FinetuningArguments(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )
    finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."},
    )
@@ -420,10 +473,23 @@ class FinetuningArguments(
        default=False,
        metadata={"help": "Whether or not to use the Adam-mini optimizer."},
    )
    use_mca: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to use MCA (Megatron Core Adapter) training. "
                "Controlled by USE_MCA environment variable."
            )
        },
    )
    use_muon: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the Muon optimizer."},
    )
    use_dft_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use the DFT loss."},
    )
    freeze_vision_tower: bool = field(
        default=True,
        metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
    )
@@ -444,7 +510,7 @@ class FinetuningArguments(
        default=False,
        metadata={"help": "Whether or not to disable the shuffling of the training set."},
    )
    early_stopping_steps: int | None = field(
        default=None,
        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
    )
@@ -464,15 +530,16 @@ class FinetuningArguments(
            return arg

        self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
        self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
        self.lora_target: list[str] = split_arg(self.lora_target)
        self.oft_target: list[str] = split_arg(self.oft_target)
        self.additional_target: list[str] | None = split_arg(self.additional_target)
        self.galore_target: list[str] = split_arg(self.galore_target)
        self.apollo_target: list[str] = split_arg(self.apollo_target)
        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

        assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
@@ -482,6 +549,9 @@ class FinetuningArguments(
        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
            raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")

        if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
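A sketch of a DPO configuration exercising the new `pref_bco_weight` knob. The model, dataset, and the 0.01 value are placeholders for illustration only, not recommendations from this commit.

dpo_train_example = {
    "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",  # placeholder
    "stage": "dpo",
    "do_train": True,
    "finetuning_type": "lora",
    "pref_loss": "sigmoid",
    "pref_bco_weight": 0.01,   # adds a Binary Classifier Optimization term to the DPO objective
    "dataset": "dpo_en_demo",  # placeholder preference dataset
    "template": "qwen",
    "output_dir": "saves/qwen2.5-7b-dpo",
}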
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
@@ -17,26 +17,30 @@
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Self

import torch
from omegaconf import OmegaConf
from transformers.training_args import _convert_str_dict

from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
from ..extras.logging import get_logger


logger = get_logger(__name__)


@dataclass
class BaseModelArguments:
    r"""Arguments pertaining to the model."""

    model_name_or_path: str | None = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    adapter_name_or_path: str | None = field(
        default=None,
        metadata={
            "help": (
@@ -45,11 +49,11 @@ class BaseModelArguments:
            )
        },
    )
    adapter_folder: str | None = field(
        default=None,
        metadata={"help": "The folder containing the adapter weights to load."},
    )
    cache_dir: str | None = field(
        default=None,
        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
    )
@@ -65,16 +69,38 @@ class BaseModelArguments:
        default=False,
        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
    )
    add_tokens: str | None = field(
        default=None,
        metadata={
            "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
        },
    )
    add_special_tokens: str | None = field(
        default=None,
        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
    )
    new_special_tokens_config: str | None = field(
        default=None,
        metadata={
            "help": (
                "Path to YAML config with special token descriptions for semantic initialization. "
                "If set, this takes precedence over add_special_tokens. "
                "YAML format: {'<token>': 'description text', ...}"
            )
        },
    )
    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
        default="noise_init",
        metadata={
            "help": (
                "Initialization method for new special tokens: "
                "'noise_init' (default, random noise around mean), "
                "'desc_init' (semantic initialization from descriptions), "
                "'desc_init_w_noise' (semantic + random noise). "
                "Note: 'desc_init' methods require new_special_tokens_config."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -83,7 +109,7 @@ class BaseModelArguments:
        default=True,
        metadata={"help": "Whether or not to use memory-efficient model loading."},
    )
    rope_scaling: RopeScaling | None = field(
        default=None,
        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
    )
@@ -95,7 +121,7 @@ class BaseModelArguments:
        default=False,
        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
    )
    mixture_of_depths: Literal["convert", "load"] | None = field(
        default=None,
        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
    )
@@ -111,7 +137,7 @@ class BaseModelArguments:
        default=False,
        metadata={"help": "Whether or not to enable liger kernel for faster training."},
    )
    moe_aux_loss_coef: float | None = field(
        default=None,
        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
    )
@@ -143,23 +169,27 @@ class BaseModelArguments:
        default="offload",
        metadata={"help": "Path to offload model weights."},
    )
    use_kv_cache: bool = field(
        default=True,
        metadata={"help": "Whether or not to use KV cache in generation."},
    )
    use_v1_kernels: bool | None = field(
        default=False,
        metadata={"help": "Whether or not to use high-performance kernels in training."},
    )
    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
        default="auto",
        metadata={"help": "Data type for model weights and activations at inference."},
    )
    hf_hub_token: str | None = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    ms_hub_token: str | None = field(
        default=None,
        metadata={"help": "Auth token to log in with ModelScope Hub."},
    )
    om_hub_token: str | None = field(
        default=None,
        metadata={"help": "Auth token to log in with Modelers Hub."},
    )
@@ -185,8 +215,63 @@ class BaseModelArguments:
        if self.add_tokens is not None:  # support multiple tokens
            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]

        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
        if self.new_special_tokens_config is not None:
            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
            try:
                cfg = OmegaConf.load(self.new_special_tokens_config)
                token_descriptions = OmegaConf.to_container(cfg)

                if not isinstance(token_descriptions, dict):
                    raise ValueError(
                        f"YAML config must be a dictionary mapping tokens to descriptions. "
                        f"Got: {type(token_descriptions)}"
                    )

                # Extract token list from config keys
                extracted_tokens = list(token_descriptions.keys())

                # Warn if both are set
                if self.add_special_tokens is not None:
                    logger.warning_rank0(
                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
                        f"Using tokens from config: {extracted_tokens}"
                    )

                # Override add_special_tokens with extracted tokens (as list)
                self.add_special_tokens = extracted_tokens

                # Store descriptions internally for later use (internal attribute)
                self._special_token_descriptions = token_descriptions

                logger.info_rank0(
                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
                    f"{self.new_special_tokens_config}"
                )
            except Exception as e:
                logger.error_rank0(
                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
                )
                raise
        elif self.add_special_tokens is not None:
            # Priority 2: Use simple comma-separated string (no descriptions)
            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
            self._special_token_descriptions = None
        else:
            # No special tokens to add
            self._special_token_descriptions = None

        # Validate init method
        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
            if self._special_token_descriptions is None:
                logger.warning_rank0(
                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
                    "Falling back to 'noise_init'"
                )
                self.init_special_tokens = "noise_init"
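To make the YAML format concrete, here is a small self-contained sketch of what such a config might contain and how the loading step above interprets it. The token names are made up for illustration; only the key-to-description mapping shape comes from the help string and code above.

from omegaconf import OmegaConf

# Hypothetical contents of a new_special_tokens_config YAML file.
yaml_text = """
"<tool_call>": "Marks the beginning of a structured tool invocation."
"</tool_call>": "Marks the end of a structured tool invocation."
"""

token_descriptions = OmegaConf.to_container(OmegaConf.create(yaml_text))
print(list(token_descriptions.keys()))  # -> ['<tool_call>', '</tool_call>']
# With init_special_tokens="desc_init", these descriptions would seed the new token embeddings.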
@dataclass
@@ -197,7 +282,7 @@ class QuantizationArguments:
        default=QuantizationMethod.BNB,
        metadata={"help": "Quantization method to use for on-the-fly quantization."},
    )
    quantization_bit: int | None = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
    )
@@ -209,10 +294,27 @@ class QuantizationArguments:
        default=True,
        metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
    )
    quantization_device_map: Literal["auto"] | None = field(
        default=None,
        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
    )
    fp8: bool = field(
        default=False,
        metadata={
            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
        },
    )
    fp8_backend: str = field(
        default="auto",
        metadata={
            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool = field(
        default=False,
        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
    )
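A hedged sketch of how the FP8 switches might be combined in a training config. The model and dataset values are placeholders, and the hardware/back-end requirements are only those stated in the help strings above.

fp8_train_example = {
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "fp8": True,                  # needs PyTorch 2.7+ and Hopper GPUs
    "fp8_backend": "torchao",     # or "auto" / "te" / "msamp"
    "fp8_enable_fsdp_float8_all_gather": True,  # only meaningful with FSDP2
    "dataset": "alpaca_en_demo",  # placeholder
    "template": "llama3",
    "output_dir": "saves/llama3-8b-fp8",
}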
@dataclass
@@ -272,7 +374,7 @@ class ProcessorArguments:
class ExportArguments:
    r"""Arguments pertaining to the model export."""

    export_dir: str | None = field(
        default=None,
        metadata={"help": "Path to the directory to save the exported model."},
    )
@@ -284,11 +386,11 @@ class ExportArguments:
        default="cpu",
        metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
    )
    export_quantization_bit: int | None = field(
        default=None,
        metadata={"help": "The number of bits to quantize the exported model."},
    )
    export_quantization_dataset: str | None = field(
        default=None,
        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
    )
@@ -304,7 +406,7 @@ class ExportArguments:
        default=False,
        metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
    )
    export_hub_model_id: str | None = field(
        default=None,
        metadata={"help": "The name of the repository when pushing the model to the Hugging Face hub."},
    )
@@ -334,7 +436,7 @@ class VllmArguments:
        default=32,
        metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
    )
    vllm_config: dict | str | None = field(
        default=None,
        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
    )
@@ -360,7 +462,7 @@ class SGLangArguments:
        default=-1,
        metadata={"help": "Tensor parallel size for the SGLang engine."},
    )
    sglang_config: dict | str | None = field(
        default=None,
        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
    )
@@ -376,26 +478,77 @@ class SGLangArguments:
            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
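Since `vllm_config` and `sglang_config` accept JSON strings, here is a quick sketch of how such a value is parsed, mirroring the `_convert_str_dict(json.loads(...))` step above. The keys shown are illustrative engine options, not a verified or exhaustive list.

import json

from transformers.training_args import _convert_str_dict

# e.g. passed on the CLI as: vllm_config='{"gpu_memory_utilization": 0.85, "enforce_eager": true}'
raw = '{"gpu_memory_utilization": 0.85, "enforce_eager": true}'
vllm_config = _convert_str_dict(json.loads(raw))
print(vllm_config)  # {'gpu_memory_utilization': 0.85, 'enforce_eager': True}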
@dataclass
class KTransformersArguments:
    r"""Arguments pertaining to the KT training."""

    use_kt: bool = field(
        default=False,
        metadata={"help": "Whether to use KTransformers optimizations for LoRA training."},
    )
    kt_optimize_rule: str | None = field(
        default=None,
        metadata={
            "help": "Path to the KTransformers optimize rule; see https://github.com/kvcache-ai/ktransformers/."
        },
    )
    cpu_infer: int | None = field(
        default=32,
        metadata={"help": "Number of CPU cores used for computation."},
    )
    chunk_size: int | None = field(
        default=8192,
        metadata={"help": "Chunk size used for CPU compute in KTransformers."},
    )
    mode: str | None = field(
        default="normal",
        metadata={"help": "`normal` or `long_context` for Llama models."},
    )
    kt_maxlen: int = field(
        default=4096,
        metadata={"help": "Maximum sequence (prompt + response) length of the KT engine."},
    )
    kt_use_cuda_graph: bool = field(
        default=True,
        metadata={"help": "Whether to use CUDA graphs for the KT engine."},
    )
    kt_mode: str = field(
        default="normal",
        metadata={"help": "`normal` or `long_context` mode for the KT engine."},
    )
    kt_force_think: bool = field(
        default=False,
        metadata={"help": "Force-think toggle for the KT engine."},
    )
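A sketch of a LoRA run that opts into the KTransformers path. The model, optimize-rule path, and dataset are placeholders; the exact rule file depends on the target model (see the KTransformers repository linked above).

kt_train_example = {
    "model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite-Chat",  # placeholder MoE model
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "use_kt": True,
    "kt_optimize_rule": "path/to/optimize_rule.yaml",  # placeholder path
    "cpu_infer": 32,
    "chunk_size": 8192,
    "dataset": "alpaca_en_demo",  # placeholder
    "template": "deepseek",
    "output_dir": "saves/deepseek-lora-kt",
}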
@dataclass
class ModelArguments(
    SGLangArguments,
    VllmArguments,
    KTransformersArguments,
    ExportArguments,
    ProcessorArguments,
    QuantizationArguments,
    BaseModelArguments,
):
    r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

    The rightmost class will be displayed first.
    """

    compute_dtype: torch.dtype | None = field(
        default=None,
        init=False,
        metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
    )
    device_map: str | dict[str, Any] | None = field(
        default=None,
        init=False,
        metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
    )
    model_max_length: int | None = field(
        default=None,
        init=False,
        metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
    )
@@ -18,7 +18,7 @@
import os
import sys
from pathlib import Path
from typing import Any, Optional

import torch
import transformers
@@ -32,6 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -52,8 +53,19 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]

if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
    from mcore_adapter import TrainingArguments as McaTrainingArguments

    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
    _TRAIN_MCA_CLS = tuple[
        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
    ]
else:
    _TRAIN_MCA_ARGS = []
    _TRAIN_MCA_CLS = tuple()


def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
    r"""Get arguments from the command line or a config file."""
    if args is not None:
        return args
@@ -71,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
def _parse_args(
    parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
) -> tuple[Any]:
    args = read_args(args)
    if isinstance(args, dict):
@@ -111,8 +123,8 @@ def _verify_model_args(
        raise ValueError("Adapter is only valid for the LoRA method.")

    if model_args.quantization_bit is not None:
        if finetuning_args.finetuning_type not in ["lora", "oft"]:
            raise ValueError("Quantization is only compatible with the LoRA or OFT method.")

        if finetuning_args.pissa_init:
            raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
@@ -130,12 +142,23 @@ def _verify_model_args(
        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
        model_args.use_fast_tokenizer = False

    # Validate advanced training features
    if model_args.fp8 and model_args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        model_args.fp8 = True


def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    training_args: Optional["TrainingArguments"] = None,
) -> None:
    if model_args.use_kt:
        check_version("ktransformers", mandatory=True)

    if model_args.use_unsloth:
        check_version("unsloth", mandatory=True)
@@ -146,7 +169,7 @@ def _check_extra_dependencies(
        check_version("mixture-of-depth>=1.1.6", mandatory=True)

    if model_args.infer_backend == EngineName.VLLM:
        check_version("vllm>=0.4.3,<=0.11.0")
        check_version("vllm", mandatory=True)
    elif model_args.infer_backend == EngineName.SGLANG:
        check_version("sglang>=0.4.5")
@@ -173,7 +196,8 @@ def _check_extra_dependencies(
    if training_args is not None:
        if training_args.deepspeed:
            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
            check_version("deepspeed", mandatory=True)
            check_version("deepspeed>=0.10.0,<=0.16.9")

        if training_args.predict_with_generate:
            check_version("jieba", mandatory=True)
@@ -181,32 +205,57 @@ def _check_extra_dependencies(
            check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
    parser = HfArgumentParser(_TRAIN_ARGS)
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
    parser = HfArgumentParser(_TRAIN_MCA_ARGS)
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
        parser, args, allow_extra_keys=allow_extra_keys
    )
    _configure_mca_training_args(training_args, data_args, finetuning_args)
    return model_args, data_args, training_args, finetuning_args, generating_args


def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
    """Patch training args to avoid args checking errors and sync MCA settings."""
    training_args.predict_with_generate = False
    training_args.generation_max_length = data_args.cutoff_len
    training_args.generation_num_beams = 1
    training_args.use_mca = True
    finetuning_args.use_mca = True


def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
    parser = HfArgumentParser(_INFER_ARGS)
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
    parser = HfArgumentParser(_EVAL_ARGS)
    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
    return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
    parser = HfArgumentParser(RayArguments)
    (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
    return ray_args


def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
    if is_env_enabled("USE_MCA"):
        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
    else:
        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
        finetuning_args.use_mca = False
    # Setup logging
    if training_args.should_log:
@@ -236,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
        if model_args.shift_attn:
            raise ValueError("PPO training is incompatible with S^2-Attn.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
            raise ValueError("KTransformers does not support lora reward model.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
            raise ValueError("Unsloth does not support lora reward model.")

        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
            raise ValueError("PPO only accepts wandb or tensorboard logger.")

    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")

    if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
@@ -254,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
    if training_args.do_train and data_args.dataset is None:
        raise ValueError("Please specify dataset for training.")

    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
        data_args.eval_dataset is None and data_args.val_size < 1e-6
    ):
        raise ValueError("Please make sure that `eval_dataset` is provided or `val_size` is larger than 1e-6.")

    if training_args.predict_with_generate:
        if is_deepspeed_zero3_enabled():
            raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

        if finetuning_args.compute_accuracy:
            raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
@@ -304,6 +353,12 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

    if model_args.use_kt and is_deepspeed_zero3_enabled():
        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")

    if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
        raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")

    _set_env_vars()
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -418,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
    return model_args, data_args, training_args, finetuning_args, generating_args


def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

    # Setup logging
@@ -453,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
    return model_args, data_args, finetuning_args, generating_args


def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

    # Setup logging
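A brief sketch of how the Megatron-Core Adapter path is toggled and how the programmatic entry point consumes a plain dict. All argument values are placeholders, and the final comment only restates the validation performed above.

import os

# Toggle the Megatron-Core Adapter code path; must be set before llamafactory.hparams is imported.
os.environ["USE_MCA"] = "0"  # "1" requires the mcore_adapter package

train_args = {
    "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "dataset": "alpaca_en_demo",  # placeholder
    "template": "qwen",
    "output_dir": "saves/demo",
}
# This dict could then be passed to get_train_args(train_args) under a distributed launcher
# (llamafactory-cli / torchrun), which is what the checks above expect for training.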
...@@ -14,19 +14,33 @@ ...@@ -14,19 +14,33 @@
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional, Union from typing import Literal
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray from ..extras.misc import is_env_enabled, use_ray
from ..extras.packages import is_mcore_adapter_available
if is_env_enabled("USE_MCA"):
if not is_mcore_adapter_available():
raise ImportError(
"mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
)
from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
BaseTrainingArguments = McaSeq2SeqTrainingArguments
else:
BaseTrainingArguments = Seq2SeqTrainingArguments
@dataclass @dataclass
class RayArguments: class RayArguments:
r"""Arguments pertaining to the Ray training.""" r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field( ray_run_name: str | None = field(
default=None, default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."}, metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
) )
...@@ -34,7 +48,7 @@ class RayArguments: ...@@ -34,7 +48,7 @@ class RayArguments:
default="./saves", default="./saves",
metadata={"help": "The storage path to save training results to"}, metadata={"help": "The storage path to save training results to"},
) )
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field( ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
default=None, default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."}, metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
) )
...@@ -42,7 +56,7 @@ class RayArguments: ...@@ -42,7 +56,7 @@ class RayArguments:
default=1, default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
) )
resources_per_worker: Union[dict, str] = field( resources_per_worker: dict | str = field(
default_factory=lambda: {"GPU": 1}, default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
) )
...@@ -50,7 +64,7 @@ class RayArguments: ...@@ -50,7 +64,7 @@ class RayArguments:
default="PACK", default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."}, metadata={"help": "The placement strategy for Ray training. Default is PACK."},
) )
ray_init_kwargs: Optional[dict] = field( ray_init_kwargs: dict | str | None = field(
default=None, default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."}, metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
) )
...@@ -59,10 +73,14 @@ class RayArguments: ...@@ -59,10 +73,14 @@ class RayArguments:
self.use_ray = use_ray() self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"): if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker)) self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
if self.ray_storage_filesystem is not None: if self.ray_storage_filesystem is not None:
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]: if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
raise ValueError( raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}" f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
) )
import pyarrow.fs as fs import pyarrow.fs as fs
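A minimal sketch of how the JSON-string fields above (`resources_per_worker`, `ray_init_kwargs`) are turned into dicts in `__post_init__`; the values are hypothetical:

    import json
    from transformers.training_args import _convert_str_dict

    resources_per_worker = '{"GPU": 2, "CPU": 8}'  # hypothetical CLI/YAML value
    if isinstance(resources_per_worker, str) and resources_per_worker.startswith("{"):
        resources_per_worker = _convert_str_dict(json.loads(resources_per_worker))
    # resources_per_worker is now a dict, e.g. {"GPU": 2, "CPU": 8}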
...@@ -74,9 +92,14 @@ class RayArguments: ...@@ -74,9 +92,14 @@ class RayArguments:
@dataclass @dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments): class TrainingArguments(RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer.""" r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(
default=False,
metadata={"help": "deprecated"},
)
def __post_init__(self): def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self) RayArguments.__post_init__(self)
BaseTrainingArguments.__post_init__(self)
...@@ -12,12 +12,174 @@ ...@@ -12,12 +12,174 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from llamafactory.train.tuner import run_exp # use absolute import import os
import subprocess
import sys
from copy import deepcopy
USAGE = (
"-" * 70
+ "\n"
+ "| Usage: |\n"
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ "| llamafactory-cli train -h: train models |\n"
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
+ "| llamafactory-cli env: show environment info |\n"
+ "| llamafactory-cli version: show version info |\n"
+ "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
+ "-" * 70
)
def launch(): def launch():
run_exp() from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_kt, use_ray
logger = logging.get_logger(__name__)
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if is_env_enabled("USE_MCA"): # force use torchrun
os.environ["FORCE_TORCHRUN"] = "1"
if command == "train" and (
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
# elastic launch support
max_restarts = os.getenv("MAX_RESTARTS", "0")
rdzv_id = os.getenv("RDZV_ID")
min_nnodes = os.getenv("MIN_NNODES")
max_nnodes = os.getenv("MAX_NNODES")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
if rdzv_id is not None:
# launch elastic job with fault tolerant support when possible
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
rdzv_nnodes = nnodes
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
if min_nnodes is not None and max_nnodes is not None:
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
process = subprocess.run(
(
"torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
"--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
"--max-restarts {max_restarts} {file_name} {args}"
)
.format(
rdzv_nnodes=rdzv_nnodes,
nproc_per_node=nproc_per_node,
rdzv_id=rdzv_id,
master_addr=master_addr,
master_port=master_port,
max_restarts=max_restarts,
file_name=__file__,
args=" ".join(sys.argv[1:]),
)
.split(),
env=env,
check=True,
)
else:
# NOTE: DO NOT USE shell=True to avoid security risk
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
)
.format(
nnodes=nnodes,
node_rank=node_rank,
nproc_per_node=nproc_per_node,
master_addr=master_addr,
master_port=master_port,
file_name=__file__,
args=" ".join(sys.argv[1:]),
)
.split(),
env=env,
check=True,
)
sys.exit(process.returncode)
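An illustrative sketch of the non-elastic torchrun command assembled above, using hypothetical environment values and a made-up entry script / config path:

    env = {"NNODES": "2", "NODE_RANK": "0", "NPROC_PER_NODE": "8",
           "MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "29500"}
    command = (
        "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
        "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
    ).format(
        nnodes=env["NNODES"],
        node_rank=env["NODE_RANK"],
        nproc_per_node=env["NPROC_PER_NODE"],
        master_addr=env["MASTER_ADDR"],
        master_port=env["MASTER_PORT"],
        file_name="launcher.py",  # hypothetical entry script
        args="examples/train_lora/llama3_lora_sft.yaml",  # hypothetical config path
    )
    # torchrun --nnodes 2 --node_rank 0 --nproc_per_node 8 --master_addr 10.0.0.1
    #   --master_port 29500 launcher.py examples/train_lora/llama3_lora_sft.yaml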
elif command == "api":
from .api.app import run_api
run_api()
elif command == "chat":
from .chat.chat_model import run_chat
run_chat()
elif command == "eval":
raise NotImplementedError("Evaluation will be deprecated in the future.")
elif command == "export":
from .train.tuner import export_model
export_model()
elif command == "train":
from .train.tuner import run_exp
run_exp()
elif command == "webchat":
from .webui.interface import run_web_demo
run_web_demo()
elif command == "webui":
from .webui.interface import run_web_ui
run_web_ui()
elif command == "env":
print_env()
elif command == "version":
print(WELCOME)
elif command == "help":
print(USAGE)
else:
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__": if __name__ == "__main__":
launch() from llamafactory.train.tuner import run_exp # use absolute import
run_exp()
...@@ -16,10 +16,12 @@ import re ...@@ -16,10 +16,12 @@ import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
...@@ -147,7 +149,10 @@ def _setup_lora_tuning( ...@@ -147,7 +149,10 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
) -> "PeftModel": ) -> "PeftModel":
if is_trainable: if is_trainable:
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) if finetuning_args.finetuning_type == "oft":
logger.info_rank0("Fine-tuning method: OFT")
else:
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None adapter_to_resume = None
...@@ -161,6 +166,10 @@ def _setup_lora_tuning( ...@@ -161,6 +166,10 @@ def _setup_lora_tuning(
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False is_mergeable = False
if model_args.use_kt:
assert len(model_args.adapter_name_or_path) == 1, "KTransformers model only accepts a single adapter."
is_mergeable = False
if model_args.use_unsloth: if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False is_mergeable = False
...@@ -179,6 +188,12 @@ def _setup_lora_tuning( ...@@ -179,6 +188,12 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
} }
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for adapter in adapter_to_merge: for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs) model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload() model = model.merge_and_unload()
...@@ -187,7 +202,9 @@ def _setup_lora_tuning( ...@@ -187,7 +202,9 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).") logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth: if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable) model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else: else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
...@@ -200,6 +217,16 @@ def _setup_lora_tuning( ...@@ -200,6 +217,16 @@ def _setup_lora_tuning(
else: else:
target_modules = finetuning_args.lora_target target_modules = finetuning_args.lora_target
if model_args.use_kt:
new_list = []
for m in target_modules:
if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
if finetuning_args.use_llama_pro: if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
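A toy sketch of the KTransformers target-module remapping a few lines above, assuming a typical LoRA target list:

    target_modules = ["q_proj", "down_proj", "generate_linear"]
    new_list = []
    for m in target_modules:
        if m in ("down_proj", "up_proj", "gate_proj"):
            new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
        elif m not in ("generate_linear", "orig_module", "prefill_linear"):
            new_list.append(m)
    target_modules[:] = new_list
    # target_modules == ["q_proj", "mlp.down_proj", "shared_experts.down_proj"]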
...@@ -223,17 +250,43 @@ def _setup_lora_tuning( ...@@ -223,17 +250,43 @@ def _setup_lora_tuning(
finetuning_args.additional_target = module_names finetuning_args.additional_target = module_names
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = { if finetuning_args.finetuning_type == "lora":
"r": finetuning_args.lora_rank, peft_kwargs = {
"target_modules": target_modules, "r": finetuning_args.lora_rank,
"lora_alpha": finetuning_args.lora_alpha, "target_modules": target_modules,
"lora_dropout": finetuning_args.lora_dropout, "lora_alpha": finetuning_args.lora_alpha,
"use_rslora": finetuning_args.use_rslora, "lora_dropout": finetuning_args.lora_dropout,
"use_dora": finetuning_args.use_dora, "use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target, "use_dora": finetuning_args.use_dora,
} "modules_to_save": finetuning_args.additional_target,
}
elif finetuning_args.finetuning_type == "oft":
peft_kwargs = {
"r": finetuning_args.oft_rank,
"oft_block_size": finetuning_args.oft_block_size,
"target_modules": target_modules,
"module_dropout": finetuning_args.module_dropout,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs) model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else: else:
if finetuning_args.pissa_init: if finetuning_args.pissa_init:
...@@ -244,12 +297,19 @@ def _setup_lora_tuning( ...@@ -244,12 +297,19 @@ def _setup_lora_tuning(
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.") logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}" peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig( if finetuning_args.finetuning_type == "lora":
task_type=TaskType.CAUSAL_LM, peft_config = LoraConfig(
inference_mode=False, task_type=TaskType.CAUSAL_LM,
**peft_kwargs, inference_mode=False,
) **peft_kwargs,
model = get_peft_model(model, lora_config) )
elif finetuning_args.finetuning_type == "oft":
peft_config = OFTConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
model = get_peft_model(model, peft_config)
if is_trainable and cast_trainable_params_to_fp32: if is_trainable and cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()): for param in filter(lambda p: p.requires_grad, model.parameters()):
...@@ -272,8 +332,8 @@ def init_adapter( ...@@ -272,8 +332,8 @@ def init_adapter(
Note that the trainable parameters must be cast to float32. Note that the trainable parameters must be cast to float32.
""" """
if is_trainable and getattr(model, "quantization_method", None) is not None: if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type not in ["lora", "oft"]:
raise ValueError("Quantized models can only be used for the LoRA tuning.") raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")
if finetuning_args.pissa_init: if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.") raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
...@@ -296,7 +356,7 @@ def init_adapter( ...@@ -296,7 +356,7 @@ def init_adapter(
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32) _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze": elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32) _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora": elif finetuning_args.finetuning_type in ["lora", "oft"]:
model = _setup_lora_tuning( model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
) )
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import os import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
...@@ -31,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead ...@@ -31,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
...@@ -143,7 +143,12 @@ def load_model( ...@@ -143,7 +143,12 @@ def load_model(
model = None model = None
lazy_load = False lazy_load = False
if model_args.use_unsloth: if model_args.use_kt:
from ktransformers.sft.monkey_patch_torch_module import install_patch
install_patch()
model = load_kt_pretrained_model(config, model_args)
elif model_args.use_unsloth:
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
lazy_load = True lazy_load = True
elif is_trainable: elif is_trainable:
...@@ -152,17 +157,18 @@ def load_model( ...@@ -152,17 +157,18 @@ def load_model(
if model is None and not lazy_load: if model is None and not lazy_load:
init_kwargs["config"] = config init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
if model_args.mixture_of_depths == "load": if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs) model = load_mod_pretrained_model(**init_kwargs)
else: else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text if type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen omni
load_class = AutoModelForTextToWaveform load_class = AutoModelForTextToWaveform
else: else:
load_class = AutoModelForCausalLM load_class = AutoModelForCausalLM
...@@ -171,8 +177,8 @@ def load_model( ...@@ -171,8 +177,8 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code) model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else: else:
model = load_class.from_pretrained(**init_kwargs) model = load_class.from_pretrained(**init_kwargs)
if getattr(model.config, "model_type", None) == "qwen2_5_omni": if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
model = model.thinker # use part of Omni model model = getattr(model, "thinker")
if model_args.mixture_of_depths == "convert": if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args) model = convert_pretrained_model_to_mod(model, config, model_args)
...@@ -199,14 +205,21 @@ def load_model( ...@@ -199,14 +205,21 @@ def load_model(
if not is_trainable: if not is_trainable:
model.requires_grad_(False) model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval() model.eval()
else: else:
model.train() model.train()
# Temporarily borrow the v1 kernel plugin mechanism to apply NPU fusion operators to v0.
# It is disabled by default and can be removed once the transition period ends.
if model_args.use_v1_kernels and is_trainable:
logger.warning_rank0(
"You are try to using future feature about kernels, please note that this feature "
"is not supported for all models. If get any error, please disable this feature, or report the issue."
)
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model) trainable_params, all_param = count_parameters(model)
if is_trainable: if is_trainable:
param_stats = ( param_stats = (
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from ...extras import logging from ...extras import logging
from ...extras.constants import AttentionFunction from ...extras.constants import AttentionFunction
from ...extras.packages import is_torch_version_greater_than
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -30,6 +29,20 @@ logger = logging.get_logger(__name__) ...@@ -30,6 +29,20 @@ logger = logging.get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
from transformers.utils import is_flash_attn_2_available
if getattr(config, "model_type", None) == "gpt_oss":
from transformers.integrations.hub_kernels import load_and_register_kernel
flash_attn3_kernel = "kernels-community/vllm-flash-attn3"
load_and_register_kernel(flash_attn3_kernel)
setattr(config, "_attn_implementation", flash_attn3_kernel)
setattr(config, "_attn_implementation_internal", flash_attn3_kernel)
model_args.flash_attn = AttentionFunction.FA3
logger.info_rank0("Using FlashAttention-3 with attention sink for the gpt-oss model.")
return
if getattr(config, "model_type", None) == "gemma2": if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2: if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available(): if is_flash_attn_2_available():
...@@ -51,13 +64,15 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model ...@@ -51,13 +64,15 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
requested_attn_implementation = "eager" requested_attn_implementation = "eager"
elif model_args.flash_attn == AttentionFunction.SDPA: elif model_args.flash_attn == AttentionFunction.SDPA:
if not is_torch_sdpa_available(): if not is_torch_version_greater_than("2.1.1"):
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.") logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return return
requested_attn_implementation = "sdpa" requested_attn_implementation = "sdpa"
elif model_args.flash_attn == AttentionFunction.FA2: elif model_args.flash_attn == AttentionFunction.FA2:
if not is_flash_attn_2_available(): from transformers import is_torch_npu_available
if not (is_flash_attn_2_available() or is_torch_npu_available()):
logger.warning_rank0("FlashAttention-2 is not installed.") logger.warning_rank0("FlashAttention-2 is not installed.")
return return
......
...@@ -19,9 +19,11 @@ ...@@ -19,9 +19,11 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import os
from collections.abc import Callable
from functools import WRAPPER_ASSIGNMENTS, partial, wraps from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import torch import torch
...@@ -152,6 +154,13 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum ...@@ -152,6 +154,13 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)
if (
os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
and int(os.environ.get("FSDP_VERSION", "1")) == 2
):
model_args.use_reentrant_gc = False
logger.warning_rank0("You are using fsdp2, `use_reentrant_gc` has been set to False.")
if not model_args.disable_gradient_checkpointing: if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False): if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning_rank0("Current model does not support gradient checkpointing.") logger.warning_rank0("Current model does not support gradient checkpointing.")
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import math import math
from contextlib import nullcontext from contextlib import nullcontext
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
...@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__) ...@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
"""
embedding_dim = embed_weight.size(1) embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
...@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int ...@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
embed_weight[-num_new_tokens:] = avg_weight + noise_weight embed_weight[-num_new_tokens:] = avg_weight + noise_weight
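A toy usage sketch of the rule above: extend a small embedding matrix by two rows and initialize them (the sizes are made up):

    import torch

    embed_weight = torch.zeros(7, 4)
    embed_weight[:5] = torch.randn(5, 4)  # pretend these are the existing token embeddings
    _noisy_mean_initialization(embed_weight, num_new_tokens=2)
    # the last two rows are now the mean of the first five rows plus small Gaussian noise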
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: def _description_based_initialization(
r"""Resize token embeddings.""" embed_weight: "torch.Tensor",
num_new_tokens: int,
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
For each new token, this function:
1. Tokenizes its description text
2. Gets embeddings of the description tokens
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
descriptions: Dict mapping token string to its description text
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
add_noise: Whether to add Gaussian noise to the initialization
Example:
descriptions = {
"<|START_OF_SVG|>": "Marks the beginning of an SVG document",
"<|END_OF_SVG|>": "Marks the end of an SVG document"
}
"""
embedding_dim = embed_weight.size(1)
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
def _initialize_embeddings(
embed_weight: "torch.Tensor",
num_new_tokens: int,
init_method: str,
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
) -> None:
"""Single source of truth for embedding initialization.
This function selects the appropriate initialization method and applies it.
Args:
embed_weight: The embedding weight matrix to initialize
num_new_tokens: Number of new tokens added
init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
)
else:
if init_method != "noise_init":
logger.warning_rank0(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore import deepspeed # type: ignore
...@@ -64,7 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken ...@@ -64,7 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
with context_maybe_zero3: with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0) new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) logger.info_rank0(
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
model.config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.") logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util as _u
from typing import TYPE_CHECKING, Any
import torch
from ...extras import logging
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
KT_AVAILABLE = _u.find_spec("ktransformers") is not None
if KT_AVAILABLE:
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.server.config.config import Config
from ktransformers.sft.lora import inject_lora_layer
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.globals import GLOBAL_CONFIG
from ktransformers.util.utils import load_weights
logger = logging.get_logger(__name__)
def _get_kt_kwargs(
config: "PretrainedConfig",
model_name_or_path: str,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"full_finetuning": finetuning_args.finetuning_type == "full",
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": model_args.trust_remote_code,
"use_gradient_checkpointing": "ktransformers",
}
def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
r"""Optionally load pretrained model with KTransformers. Used in training."""
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
Config().cpu_infer = model_args.cpu_infer
Config().chunk_size = model_args.chunk_size
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
if model_args.mode == "long_context":
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM supports long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation=attn_implementation
)
optimize_config_path = model_args.kt_optimize_rule
gguf_path = model_args.model_name_or_path
assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."
GLOBAL_CONFIG._config["mod"] = "infer"
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
return model
def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
from ktransformers.sft.peft_utils.mapping import get_peft_model
return get_peft_model(model, peft_kwargs)
def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
r"""Load peft model with KTransformers. Used in both training and inference."""
load_adapter_name_or_path = model_args.adapter_name_or_path[0]
if load_adapter_name_or_path.endswith(".gguf"):
inject_lora_layer(model, load_adapter_name_or_path)
adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
load_weights(model, adapter_gguf_loader, adapter_gguf=True)
model.train()
else:
inject_lora_layer(model, load_adapter_name_or_path)
adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
device = next(model.parameters()).device
for key in adapter_loader.tensor_file_map.keys():
try:
tensor = adapter_loader.load_tensor(key, device=device)
model_key = key.replace("base_model.model.", "")
model_key = model_key.replace(".weight", ".default.weight")
model_key = model_key.replace(".default.default.weight", ".default.weight")
param = model.get_parameter(model_key)
param.data.copy_(tensor.data)
print(f"Loaded adapter weight: {key} -> {model_key}")
except AttributeError:
print(f"Skipping {key}: not a model parameter")
except KeyError:
print(f"Key not found in model: {model_key} (original: {key})")
return model
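A minimal sketch of the safetensors key remapping above, with an assumed adapter tensor name:

    key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    model_key = key.replace("base_model.model.", "")
    model_key = model_key.replace(".weight", ".default.weight")
    model_key = model_key.replace(".default.default.weight", ".default.weight")
    # model_key == "model.layers.0.self_attn.q_proj.lora_A.default.weight"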
...@@ -28,11 +28,11 @@ if TYPE_CHECKING: ...@@ -28,11 +28,11 @@ if TYPE_CHECKING:
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable: if not is_trainable:
setattr(config, "use_cache", model_args.use_cache) setattr(config, "use_cache", model_args.use_kv_cache)
if hasattr(config, "text_config"): if hasattr(config, "text_config"):
setattr(config.text_config, "use_cache", model_args.use_cache) setattr(config.text_config, "use_cache", model_args.use_kv_cache)
if model_args.use_cache: if model_args.use_kv_cache:
logger.info_rank0("KV cache is enabled for faster generation.") logger.info_rank0("KV cache is enabled for faster generation.")
else: else:
logger.info_rank0("KV cache is disabled.") logger.info_rank0("KV cache is disabled.")
......
...@@ -47,6 +47,8 @@ def apply_liger_kernel( ...@@ -47,6 +47,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "glm4": elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "glm4v":
from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
elif model_type == "granite": elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama": elif model_type == "llama":
...@@ -75,6 +77,12 @@ def apply_liger_kernel( ...@@ -75,6 +77,12 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
elif model_type == "qwen3_moe": elif model_type == "qwen3_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
except ImportError:
logger.warning_rank0("Please install liger-kernel from https://github.com/Comet0322/Liger-Kernel.")
return
else: else:
logger.warning_rank0("Current model does not support liger kernel.") logger.warning_rank0("Current model does not support liger kernel.")
return return
......
...@@ -14,9 +14,13 @@ ...@@ -14,9 +14,13 @@
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
import torch
from torch import nn
from torch.nn import functional as F
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.misc import check_version from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -25,6 +29,9 @@ if TYPE_CHECKING: ...@@ -25,6 +29,9 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
if is_transformers_version_greater_than("4.57.0"):
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None: def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
check_version("deepspeed>=0.13.0") check_version("deepspeed>=0.13.0")
...@@ -39,6 +46,9 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: ...@@ -39,6 +46,9 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
return return
model_type = getattr(model.config, "model_type", None) model_type = getattr(model.config, "model_type", None)
text_config = getattr(model.config, "text_config", None)
text_model_type = getattr(text_config, "model_type", None)
if model_type == "dbrx": if model_type == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN from transformers.models.dbrx.modeling_dbrx import DbrxFFN
...@@ -52,11 +62,31 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: ...@@ -52,11 +62,31 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
# deepseek v3 and kimi vl use custom code # deepseek v3 and kimi vl use custom code
_set_z3_leaf_modules(model, ["DeepseekV3MoE"]) _set_z3_leaf_modules(model, ["DeepseekV3MoE"])
if model_type == "ernie4_5_moe":
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])
if model_type == "granitemoe": if model_type == "granitemoe":
from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
_set_z3_leaf_modules(model, [GraniteMoeMoE]) _set_z3_leaf_modules(model, [GraniteMoeMoE])
if model_type == "glm4_moe":
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE
_set_z3_leaf_modules(model, [Glm4MoeMoE])
if model_type == "glm4v_moe":
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextMoE
_set_z3_leaf_modules(model, [Glm4vMoeTextMoE])
if model_type == "gpt_oss":
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
_set_z3_leaf_modules(model, [GptOssMLP])
if model_type == "jamba": if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
...@@ -92,19 +122,32 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: ...@@ -92,19 +122,32 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
if model_type == "qwen3_moe": if model_type == "qwen3_moe" or text_model_type == "qwen3_moe": # internvl 3.5
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock]) _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
if model_type == "qwen3_vl_moe":
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
if model_type in ("qwen3_omni_moe", "qwen3_omni_moe_thinker"):
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef: if not is_trainable or not model_args.moe_aux_loss_coef:
return return
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
text_config = getattr(config, "text_config", None) # for multimodal model
if model_type in [ if model_type in [
"dbrx", "dbrx",
"ernie4_5_moe",
"granitemoe", "granitemoe",
"jamba", "jamba",
"jetmoe", "jetmoe",
...@@ -117,11 +160,93 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t ...@@ -117,11 +160,93 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
]: ]:
setattr(config, "output_router_logits", True) setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]: if text_config and getattr(text_config, "model_type", None) in [
"glm4v_moe_text", # glmv4_5
"qwen3_moe", # internvl_3_5
]:
setattr(text_config, "output_router_logits", True)
if model_type in [
"ernie4_5_moe",
"granitemoe",
"jamba",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
setattr(text_config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek": elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef) setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe": elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef) setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
# gating
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = nn.ModuleList(
[
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
config, intermediate_size=config.moe_intermediate_size
)
for _ in range(self.num_experts)
]
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
# Calculate the routing weights for all experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# Keep the top_k weights and set the weights of all other experts to 0 (instead of keeping only the top_k experts)
top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
# Initialize an all-zero weight matrix (same shape as routing_weights, i.e. one column per expert)
full_routing_weights = torch.zeros_like(routing_weights)
# Only the weight of top_k experts is retained, and the weight of the rest of the experts remains at 0
full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
# Normalize the top_k weights (consistent with the original logic)
if self.norm_topk_prob:
# Compute the per-row sum of the top_k weights (for normalization)
top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
# Avoid dividing by zero
top_k_sum = torch.clamp(top_k_sum, min=1e-9)
full_routing_weights /= top_k_sum
# Convert back to the input data type
full_routing_weights = full_routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# Go through all the experts (not just the selected ones)
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
# Get the weight of the current expert (inactive expert has a weight of 0 here)
expert_weights = full_routing_weights[:, expert_idx, None] # shape: (batch*seq, 1)
# Every token passes through the current expert; its weight may be 0
current_hidden_states = expert_layer(hidden_states) * expert_weights
# Accumulate all expert outputs (experts with a weight of 0 do not affect the result)
final_hidden_states += current_hidden_states
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
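A toy sketch of the dense routing trick used in the forward pass above: keep only the top-k weights, zero the rest, then renormalize, so every expert still participates in the weighted sum:

    import torch
    import torch.nn.functional as F

    router_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])  # 1 token, 4 experts
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    top_k_weights, top_k_indices = torch.topk(routing_weights, k=2, dim=-1)
    full_routing_weights = torch.zeros_like(routing_weights)
    full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
    full_routing_weights /= full_routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-9)
    # full_routing_weights is non-zero only for the two selected experts and sums to 1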
...@@ -53,7 +53,7 @@ logger = logging.get_logger(__name__) ...@@ -53,7 +53,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""Get the sequnce lengths in the current batch. r"""Get the sequence lengths in the current batch.
e.g. e.g.
```python ```python
......
...@@ -83,6 +83,7 @@ def configure_quantization( ...@@ -83,6 +83,7 @@ def configure_quantization(
config: "PretrainedConfig", config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
is_trainable: bool,
init_kwargs: dict[str, Any], init_kwargs: dict[str, Any],
) -> None: ) -> None:
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer).""" r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
...@@ -90,12 +91,29 @@ def configure_quantization( ...@@ -90,12 +91,29 @@ def configure_quantization(
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.") logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None) quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "") quant_method = quantization_config.get("quant_method", "")
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
):
# mxfp4 will dequantize the model weights
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
if quant_method == QuantizationMethod.MXFP4:
from transformers import Mxfp4Config
quant_config = Mxfp4Config(dequantize=True)
init_kwargs["quantization_config"] = quant_config
init_kwargs["ignore_mismatched_sizes"] = True
if quant_method == QuantizationMethod.FP8:
from transformers import FineGrainedFP8Config
quant_config = FineGrainedFP8Config(dequantize=True)
init_kwargs["quantization_config"] = quant_config
init_kwargs["ignore_mismatched_sizes"] = True
if quant_method == QuantizationMethod.GPTQ: if quant_method == QuantizationMethod.GPTQ:
check_version("gptqmodel>=2.0.0", mandatory=True) check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args quantization_config.pop("disable_exllama", None) # remove deprecated args
......
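A sketch (with a hypothetical MXFP4-quantized checkpoint) of the dequantized-loading path added above; `Mxfp4Config` is the transformers quantization config used in the diff:

    from transformers import AutoConfig, AutoModelForCausalLM, Mxfp4Config

    model_id = "openai/gpt-oss-20b"  # hypothetical MXFP4-quantized checkpoint
    config = AutoConfig.from_pretrained(model_id)
    quant_method = getattr(config, "quantization_config", {}).get("quant_method", "")
    init_kwargs = {}
    if quant_method == "mxfp4":
        init_kwargs["quantization_config"] = Mxfp4Config(dequantize=True)
        init_kwargs["ignore_mismatched_sizes"] = True
    model = AutoModelForCausalLM.from_pretrained(model_id, **init_kwargs)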
...@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> ...@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") ->
logger.warning_rank0("Current model does not support RoPE scaling.") logger.warning_rank0("Current model does not support RoPE scaling.")
return return
if hasattr(config, "max_position_embeddings"): rope_scaling = getattr(config, "rope_scaling", None)
if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
old_max_length = rope_scaling["original_max_position_embeddings"]
elif hasattr(config, "max_position_embeddings"):
old_max_length = getattr(config, "max_position_embeddings", None) old_max_length = getattr(config, "max_position_embeddings", None)
else: else:
logger.warning_rank0("Cannot find the max position embeddings in the config.") logger.warning_rank0("Cannot find the max position embeddings in the config.")
......
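The reordered lookup above matters because a checkpoint that already ships RoPE scaling reports the scaled context window in `max_position_embeddings`, while the pre-scaling length sits in `rope_scaling`. A small sketch with made-up config values:

```python
# Hypothetical config values for illustration only.
config = {
    "max_position_embeddings": 131072,  # already reflects the shipped 4x scaling
    "rope_scaling": {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
}

rope_scaling = config.get("rope_scaling")
if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
    old_max_length = rope_scaling["original_max_position_embeddings"]  # 32768: the true base length
else:
    old_max_length = config.get("max_position_embeddings")

# Deriving a new scaling factor from the already-scaled 131072 instead of 32768 would
# double-count the factor the checkpoint already applies.
print(old_max_length)  # 32768
```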
...@@ -199,6 +199,15 @@ def patch_target_modules( ...@@ -199,6 +199,15 @@ def patch_target_modules(
return target_modules return target_modules
_register_composite_model(
model_type="dots_ocr",
projector_key="vision_tower.merger",
vision_model_keys=["vision_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["merger"],
)
_register_composite_model( _register_composite_model(
model_type="gemma3", model_type="gemma3",
) )
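For readers skimming the new registrations: each entry describes how a multimodal checkpoint splits into projector, vision tower, and language model, and which module names LoRA should stay away from. The sketch below only illustrates how such an entry could be used to filter LoRA targets; it is not LLaMA-Factory's actual `_register_composite_model` implementation, and the class and helper names are made up.

```python
from dataclasses import dataclass, field


@dataclass
class CompositeEntry:
    """Illustrative stand-in for one registry entry like those registered above."""
    model_type: str
    projector_key: str = "multi_modal_projector"
    vision_model_keys: list[str] = field(default_factory=lambda: ["vision_tower"])
    language_model_keys: list[str] = field(default_factory=lambda: ["language_model", "lm_head"])
    lora_conflict_keys: list[str] = field(default_factory=list)


def filter_lora_targets(entry: CompositeEntry, module_names: list[str]) -> list[str]:
    """Drop modules whose names collide with lora_conflict_keys so LoRA is not attached to them."""
    return [name for name in module_names if not any(key in name for key in entry.lora_conflict_keys)]


entry = CompositeEntry(
    model_type="dots_ocr",
    projector_key="vision_tower.merger",
    vision_model_keys=["vision_tower"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["merger"],
)
print(filter_lora_targets(entry, ["model.layers.0.self_attn.q_proj", "vision_tower.merger.mlp"]))
# ['model.layers.0.self_attn.q_proj']
```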
...@@ -221,10 +230,36 @@ _register_composite_model( ...@@ -221,10 +230,36 @@ _register_composite_model(
) )
_register_composite_model(
model_type="glm4v_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model( _register_composite_model(
model_type="internvl", model_type="internvl",
) )
_register_composite_model(
model_type="interns1",
)
_register_composite_model(
model_type="Keye",
projector_key="mlp_AR",
vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embedding"],
)
_register_composite_model(
model_type="kimi_vl",
)
_register_composite_model( _register_composite_model(
model_type="llama4", model_type="llama4",
...@@ -263,8 +298,10 @@ _register_composite_model( ...@@ -263,8 +298,10 @@ _register_composite_model(
lora_conflict_keys=["audio_projection_layer"], lora_conflict_keys=["audio_projection_layer"],
) )
_register_composite_model( _register_composite_model(
model_type="mistral3", model_type="mistral3",
projector_key="model.multi_modal_projector",
) )
...@@ -316,6 +353,33 @@ _register_composite_model( ...@@ -316,6 +353,33 @@ _register_composite_model(
) )
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model( _register_composite_model(
model_type="video_llava", model_type="video_llava",
) )