Commit ca625f43 authored by shihm's avatar shihm
Browse files

uodata

parent 7164651d
......@@ -16,22 +16,22 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional
from typing import Any, Literal
@dataclass
class DataArguments:
r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""
template: Optional[str] = field(
template: str | None = field(
default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."},
)
dataset: Optional[str] = field(
dataset: str | None = field(
default=None,
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
)
eval_dataset: Optional[str] = field(
eval_dataset: str | None = field(
default=None,
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
)
......@@ -39,7 +39,7 @@ class DataArguments:
default="data",
metadata={"help": "Path to the folder containing the datasets."},
)
media_dir: Optional[str] = field(
media_dir: str | None = field(
default=None,
metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
)
......@@ -67,7 +67,7 @@ class DataArguments:
default="concat",
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
)
interleave_probs: Optional[str] = field(
interleave_probs: str | None = field(
default=None,
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
)
......@@ -79,15 +79,15 @@ class DataArguments:
default=1000,
metadata={"help": "The number of examples in one group in pre-processing."},
)
preprocessing_num_workers: Optional[int] = field(
preprocessing_num_workers: int | None = field(
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
)
max_samples: Optional[int] = field(
max_samples: int | None = field(
default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
)
eval_num_beams: Optional[int] = field(
eval_num_beams: int | None = field(
default=None,
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
)
......@@ -103,7 +103,7 @@ class DataArguments:
default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."},
)
packing: Optional[bool] = field(
packing: bool | None = field(
default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
)
......@@ -111,19 +111,19 @@ class DataArguments:
default=False,
metadata={"help": "Enable sequence packing without cross-attention."},
)
tool_format: Optional[str] = field(
tool_format: str | None = field(
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
default_system: Optional[str] = field(
default_system: str | None = field(
default=None,
metadata={"help": "Override the default system message in the template."},
)
enable_thinking: Optional[bool] = field(
enable_thinking: bool | None = field(
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
tokenized_path: Optional[str] = field(
tokenized_path: str | None = field(
default=None,
metadata={
"help": (
......
......@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Literal
from datasets import DownloadMode
......@@ -46,7 +46,7 @@ class EvaluationArguments:
default=5,
metadata={"help": "Number of examplars for few-shot learning."},
)
save_dir: Optional[str] = field(
save_dir: str | None = field(
default=None,
metadata={"help": "Path to save the evaluation results."},
)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional
from typing import Any, Literal
@dataclass
......@@ -40,7 +40,7 @@ class FreezeArguments:
)
},
)
freeze_extra_modules: Optional[str] = field(
freeze_extra_modules: str | None = field(
default=None,
metadata={
"help": (
......@@ -56,7 +56,7 @@ class FreezeArguments:
class LoraArguments:
r"""Arguments pertaining to the LoRA training."""
additional_target: Optional[str] = field(
additional_target: str | None = field(
default=None,
metadata={
"help": (
......@@ -66,7 +66,7 @@ class LoraArguments:
)
},
)
lora_alpha: Optional[int] = field(
lora_alpha: int | None = field(
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
)
......@@ -88,7 +88,7 @@ class LoraArguments:
)
},
)
loraplus_lr_ratio: Optional[float] = field(
loraplus_lr_ratio: float | None = field(
default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
)
......@@ -122,6 +122,48 @@ class LoraArguments:
)
@dataclass
class OFTArguments:
r"""Arguments pertaining to the OFT training."""
additional_target: str | None = field(
default=None,
metadata={
"help": (
"Name(s) of modules apart from LoRA layers to be set as trainable "
"and saved in the final checkpoint. "
"Use commas to separate multiple modules."
)
},
)
module_dropout: float = field(
default=0.0,
metadata={"help": "Dropout rate for the OFT fine-tuning."},
)
oft_rank: int = field(
default=0,
metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
)
oft_block_size: int = field(
default=32,
metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
)
oft_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of target modules to apply OFT. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
create_new_adapter: bool = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
)
@dataclass
class RLHFArguments:
r"""Arguments pertaining to the PPO, DPO and KTO training."""
......@@ -134,6 +176,10 @@ class RLHFArguments:
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
pref_bco_weight: float = field(
default=0.0,
metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
)
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
......@@ -174,27 +220,27 @@ class RLHFArguments:
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
)
ref_model: Optional[str] = field(
ref_model: str | None = field(
default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
)
ref_model_adapters: Optional[str] = field(
ref_model_adapters: str | None = field(
default=None,
metadata={"help": "Path to the adapters of the reference model."},
)
ref_model_quantization_bit: Optional[int] = field(
ref_model_quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."},
)
reward_model: Optional[str] = field(
reward_model: str | None = field(
default=None,
metadata={"help": "Path to the reward model used for the PPO training."},
)
reward_model_adapters: Optional[str] = field(
reward_model_adapters: str | None = field(
default=None,
metadata={"help": "Path to the adapters of the reward model."},
)
reward_model_quantization_bit: Optional[int] = field(
reward_model_quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."},
)
......@@ -202,7 +248,7 @@ class RLHFArguments:
default="lora",
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
)
ld_alpha: Optional[float] = field(
ld_alpha: float | None = field(
default=None,
metadata={
"help": (
......@@ -315,15 +361,15 @@ class BAdamArgument:
default="layer",
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
)
badam_start_block: Optional[int] = field(
badam_start_block: int | None = field(
default=None,
metadata={"help": "The starting block index for layer-wise BAdam."},
)
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
default="ascending",
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
)
badam_switch_interval: Optional[int] = field(
badam_switch_interval: int | None = field(
default=50,
metadata={
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
......@@ -360,15 +406,15 @@ class SwanLabArguments:
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: Optional[str] = field(
swanlab_project: str | None = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: Optional[str] = field(
swanlab_workspace: str | None = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: Optional[str] = field(
swanlab_run_name: str | None = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
......@@ -376,19 +422,19 @@ class SwanLabArguments:
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: Optional[str] = field(
swanlab_api_key: str | None = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
swanlab_logdir: Optional[str] = field(
swanlab_logdir: str | None = field(
default=None,
metadata={"help": "The log directory for SwanLab."},
)
swanlab_lark_webhook_url: Optional[str] = field(
swanlab_lark_webhook_url: str | None = field(
default=None,
metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
)
swanlab_lark_secret: Optional[str] = field(
swanlab_lark_secret: str | None = field(
default=None,
metadata={"help": "The Lark(飞书) secret for SwanLab."},
)
......@@ -396,7 +442,14 @@ class SwanLabArguments:
@dataclass
class FinetuningArguments(
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
SwanLabArguments,
BAdamArgument,
ApolloArguments,
GaloreArguments,
RLHFArguments,
LoraArguments,
OFTArguments,
FreezeArguments,
):
r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
......@@ -408,7 +461,7 @@ class FinetuningArguments(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
finetuning_type: Literal["lora", "freeze", "full"] = field(
finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
......@@ -420,10 +473,23 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
)
use_mca: bool = field(
default=False,
metadata={
"help": (
"Whether or not to use MCA (Megatron Core Adapter) training. "
"Controlled by USE_MCA environment variable."
)
},
)
use_muon: bool = field(
default=False,
metadata={"help": "Whether or not to use the Muon optimizer."},
)
use_dft_loss: bool = field(
default=False,
metadata={"help": "Whether to use the DFT loss."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
......@@ -444,7 +510,7 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
early_stopping_steps: Optional[int] = field(
early_stopping_steps: int | None = field(
default=None,
metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
)
......@@ -464,15 +530,16 @@ class FinetuningArguments(
return arg
self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: list[str] = split_arg(self.lora_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.oft_target: list[str] = split_arg(self.oft_target)
self.additional_target: list[str] | None = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
self.apollo_target: list[str] = split_arg(self.apollo_target)
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
......@@ -482,6 +549,9 @@ class FinetuningArguments(
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
......
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -17,26 +17,30 @@
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Self
import torch
from omegaconf import OmegaConf
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
from ..extras.logging import get_logger
logger = get_logger(__name__)
@dataclass
class BaseModelArguments:
r"""Arguments pertaining to the model."""
model_name_or_path: Optional[str] = field(
model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
adapter_name_or_path: str | None = field(
default=None,
metadata={
"help": (
......@@ -45,11 +49,11 @@ class BaseModelArguments:
)
},
)
adapter_folder: Optional[str] = field(
adapter_folder: str | None = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
cache_dir: str | None = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
......@@ -65,16 +69,38 @@ class BaseModelArguments:
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
add_tokens: Optional[str] = field(
add_tokens: str | None = field(
default=None,
metadata={
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
},
)
add_special_tokens: Optional[str] = field(
add_special_tokens: str | None = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
new_special_tokens_config: str | None = field(
default=None,
metadata={
"help": (
"Path to YAML config with special token descriptions for semantic initialization. "
"If set, this takes precedence over add_special_tokens. "
"YAML format: {'<token>': 'description text', ...}"
)
},
)
init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
default="noise_init",
metadata={
"help": (
"Initialization method for new special tokens: "
"'noise_init' (default, random noise around mean), "
"'desc_init' (semantic initialization from descriptions), "
"'desc_init_w_noise' (semantic + random noise). "
"Note: 'desc_init' methods require new_special_tokens_config."
)
},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
......@@ -83,7 +109,7 @@ class BaseModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[RopeScaling] = field(
rope_scaling: RopeScaling | None = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
......@@ -95,7 +121,7 @@ class BaseModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
mixture_of_depths: Literal["convert", "load"] | None = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
......@@ -111,7 +137,7 @@ class BaseModelArguments:
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
moe_aux_loss_coef: float | None = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
......@@ -143,23 +169,27 @@ class BaseModelArguments:
default="offload",
metadata={"help": "Path to offload model weights."},
)
use_cache: bool = field(
use_kv_cache: bool = field(
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
use_v1_kernels: bool | None = field(
default=False,
metadata={"help": "Whether or not to use high-performance kernels in training."},
)
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
hf_hub_token: str | None = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
ms_hub_token: str | None = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
om_hub_token: str | None = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
......@@ -185,8 +215,63 @@ class BaseModelArguments:
if self.add_tokens is not None: # support multiple tokens
self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
if self.add_special_tokens is not None: # support multiple special tokens
# Process special tokens with priority: new_special_tokens_config > add_special_tokens
if self.new_special_tokens_config is not None:
# Priority 1: Load from YAML config (extracts both tokens and descriptions)
try:
cfg = OmegaConf.load(self.new_special_tokens_config)
token_descriptions = OmegaConf.to_container(cfg)
if not isinstance(token_descriptions, dict):
raise ValueError(
f"YAML config must be a dictionary mapping tokens to descriptions. "
f"Got: {type(token_descriptions)}"
)
# Extract token list from config keys
extracted_tokens = list(token_descriptions.keys())
# Warn if both are set
if self.add_special_tokens is not None:
logger.warning_rank0(
"Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
f"Using tokens from config: {extracted_tokens}"
)
# Override add_special_tokens with extracted tokens (as list)
self.add_special_tokens = extracted_tokens
# Store descriptions internally for later use (internal attribute)
self._special_token_descriptions = token_descriptions
logger.info_rank0(
f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
f"{self.new_special_tokens_config}"
)
except Exception as e:
logger.error_rank0(
f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
)
raise
elif self.add_special_tokens is not None:
# Priority 2: Use simple comma-separated string (no descriptions)
self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
self._special_token_descriptions = None
else:
# No special tokens to add
self._special_token_descriptions = None
# Validate init method
if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
if self._special_token_descriptions is None:
logger.warning_rank0(
f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
"Falling back to 'noise_init'"
)
self.init_special_tokens = "noise_init"
@dataclass
......@@ -197,7 +282,7 @@ class QuantizationArguments:
default=QuantizationMethod.BNB,
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
......@@ -209,10 +294,27 @@ class QuantizationArguments:
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
quantization_device_map: Literal["auto"] | None = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass
......@@ -272,7 +374,7 @@ class ProcessorArguments:
class ExportArguments:
r"""Arguments pertaining to the model export."""
export_dir: Optional[str] = field(
export_dir: str | None = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
......@@ -284,11 +386,11 @@ class ExportArguments:
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
export_quantization_bit: int | None = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
export_quantization_dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
......@@ -304,7 +406,7 @@ class ExportArguments:
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
export_hub_model_id: str | None = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
......@@ -334,7 +436,7 @@ class VllmArguments:
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
vllm_config: Optional[Union[dict, str]] = field(
vllm_config: dict | str | None = field(
default=None,
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
)
......@@ -360,7 +462,7 @@ class SGLangArguments:
default=-1,
metadata={"help": "Tensor parallel size for the SGLang engine."},
)
sglang_config: Optional[Union[dict, str]] = field(
sglang_config: dict | str | None = field(
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
......@@ -376,26 +478,77 @@ class SGLangArguments:
self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
@dataclass
class KTransformersArguments:
r"""Arguments pertaining to the KT training."""
use_kt: bool = field(
default=False,
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
)
kt_optimize_rule: str | None = field(
default=None,
metadata={
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
},
)
cpu_infer: int | None = field(
default=32,
metadata={"help": "Number Of CPU Cores Used For Computation."},
)
chunk_size: int | None = field(
default=8192,
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
)
mode: str | None = field(
default="normal",
metadata={"help": "Normal Or Long_Context For Llama Models."},
)
kt_maxlen: int = field(
default=4096,
metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
)
kt_use_cuda_graph: bool = field(
default=True,
metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
)
kt_mode: str = field(
default="normal",
metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
)
kt_force_think: bool = field(
default=False,
metadata={"help": "Force-Think Toggle For The KT Engine."},
)
@dataclass
class ModelArguments(
SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
SGLangArguments,
VllmArguments,
KTransformersArguments,
ExportArguments,
ProcessorArguments,
QuantizationArguments,
BaseModelArguments,
):
r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The class on the most right will be displayed first.
"""
compute_dtype: Optional[torch.dtype] = field(
compute_dtype: torch.dtype | None = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, dict[str, Any]]] = field(
device_map: str | dict[str, Any] | None = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
model_max_length: int | None = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
......
......@@ -18,7 +18,7 @@
import os
import sys
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional
import torch
import transformers
......@@ -32,6 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
......@@ -52,8 +53,19 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
]
else:
_TRAIN_MCA_ARGS = []
_TRAIN_MCA_CLS = tuple()
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
r"""Get arguments from the command line or a config file."""
if args is not None:
return args
......@@ -71,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
) -> tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
......@@ -111,8 +123,8 @@ def _verify_model_args(
raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if finetuning_args.finetuning_type not in ["lora", "oft"]:
raise ValueError("Quantization is only compatible with the LoRA or OFT method.")
if finetuning_args.pissa_init:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
......@@ -130,12 +142,23 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_kt:
check_version("ktransformers", mandatory=True)
if model_args.use_unsloth:
check_version("unsloth", mandatory=True)
......@@ -146,7 +169,7 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.9.1")
check_version("vllm>=0.4.3,<=0.11.0")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.5")
......@@ -173,7 +196,8 @@ def _check_extra_dependencies(
if training_args is not None:
if training_args.deepspeed:
# pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
check_version("deepspeed", mandatory=True)
check_version("deepspeed>=0.10.0,<=0.16.9")
if training_args.predict_with_generate:
check_version("jieba", mandatory=True)
......@@ -181,32 +205,57 @@ def _check_extra_dependencies(
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
parser, args, allow_extra_keys=allow_extra_keys
)
_configure_mca_training_args(training_args, data_args, finetuning_args)
return model_args, data_args, training_args, finetuning_args, generating_args
def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
"""Patch training args to avoid args checking errors and sync MCA settings."""
training_args.predict_with_generate = False
training_args.generation_max_length = data_args.cutoff_len
training_args.generation_num_beams = 1
training_args.use_mca = True
finetuning_args.use_mca = True
def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
if is_env_enabled("USE_MCA"):
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
else:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
finetuning_args.use_mca = False
# Setup logging
if training_args.should_log:
......@@ -236,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
raise ValueError("KTransformers does not support lora reward model.")
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
......@@ -254,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if training_args.do_train and data_args.dataset is None:
raise ValueError("Please specify dataset for training.")
if (training_args.do_eval or training_args.do_predict) and (
if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
data_args.eval_dataset is None and data_args.val_size < 1e-6
):
raise ValueError("Please specify dataset for evaluation.")
raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")
if training_args.predict_with_generate:
if is_deepspeed_zero3_enabled():
raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
if data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
......@@ -304,6 +353,12 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if model_args.use_kt and is_deepspeed_zero3_enabled():
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
_set_env_vars()
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
......@@ -418,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
# Setup logging
......@@ -453,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
# Setup logging
......
......@@ -14,19 +14,33 @@
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from typing import Literal
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray
from ..extras.misc import is_env_enabled, use_ray
from ..extras.packages import is_mcore_adapter_available
if is_env_enabled("USE_MCA"):
if not is_mcore_adapter_available():
raise ImportError(
"mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
)
from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
BaseTrainingArguments = McaSeq2SeqTrainingArguments
else:
BaseTrainingArguments = Seq2SeqTrainingArguments
@dataclass
class RayArguments:
r"""Arguments pertaining to the Ray training."""
ray_run_name: Optional[str] = field(
ray_run_name: str | None = field(
default=None,
metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
)
......@@ -34,7 +48,7 @@ class RayArguments:
default="./saves",
metadata={"help": "The storage path to save training results to"},
)
ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
default=None,
metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
)
......@@ -42,7 +56,7 @@ class RayArguments:
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: Union[dict, str] = field(
resources_per_worker: dict | str = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
......@@ -50,7 +64,7 @@ class RayArguments:
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
ray_init_kwargs: Optional[dict] = field(
ray_init_kwargs: dict | str | None = field(
default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
)
......@@ -59,10 +73,14 @@ class RayArguments:
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
if self.ray_storage_filesystem is not None:
if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
)
import pyarrow.fs as fs
......@@ -74,9 +92,14 @@ class RayArguments:
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
class TrainingArguments(RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(
default=False,
metadata={"help": "deprecated"},
)
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
BaseTrainingArguments.__post_init__(self)
......@@ -12,12 +12,174 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp # use absolute import
import os
import subprocess
import sys
from copy import deepcopy
USAGE = (
"-" * 70
+ "\n"
+ "| Usage: |\n"
+ "| llamafactory-cli api -h: launch an OpenAI-style API server |\n"
+ "| llamafactory-cli chat -h: launch a chat interface in CLI |\n"
+ "| llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+ "| llamafactory-cli train -h: train models |\n"
+ "| llamafactory-cli webchat -h: launch a chat interface in Web UI |\n"
+ "| llamafactory-cli webui: launch LlamaBoard |\n"
+ "| llamafactory-cli env: show environment info |\n"
+ "| llamafactory-cli version: show version info |\n"
+ "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`. |\n"
+ "-" * 70
)
def launch():
run_exp()
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_kt, use_ray
logger = logging.get_logger(__name__)
WELCOME = (
"-" * 58
+ "\n"
+ f"| Welcome to LLaMA Factory, version {VERSION}"
+ " " * (21 - len(VERSION))
+ "|\n|"
+ " " * 56
+ "|\n"
+ "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+ "-" * 58
)
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if is_env_enabled("USE_MCA"): # force use torchrun
os.environ["FORCE_TORCHRUN"] = "1"
if command == "train" and (
is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
# elastic launch support
max_restarts = os.getenv("MAX_RESTARTS", "0")
rdzv_id = os.getenv("RDZV_ID")
min_nnodes = os.getenv("MIN_NNODES")
max_nnodes = os.getenv("MAX_NNODES")
env = deepcopy(os.environ)
if is_env_enabled("OPTIM_TORCH", "1"):
# optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
if rdzv_id is not None:
# launch elastic job with fault tolerant support when possible
# see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
rdzv_nnodes = nnodes
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
if min_nnodes is not None and max_nnodes is not None:
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
process = subprocess.run(
(
"torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
"--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
"--max-restarts {max_restarts} {file_name} {args}"
)
.format(
rdzv_nnodes=rdzv_nnodes,
nproc_per_node=nproc_per_node,
rdzv_id=rdzv_id,
master_addr=master_addr,
master_port=master_port,
max_restarts=max_restarts,
file_name=__file__,
args=" ".join(sys.argv[1:]),
)
.split(),
env=env,
check=True,
)
else:
# NOTE: DO NOT USE shell=True to avoid security risk
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
)
.format(
nnodes=nnodes,
node_rank=node_rank,
nproc_per_node=nproc_per_node,
master_addr=master_addr,
master_port=master_port,
file_name=__file__,
args=" ".join(sys.argv[1:]),
)
.split(),
env=env,
check=True,
)
sys.exit(process.returncode)
elif command == "api":
from .api.app import run_api
run_api()
elif command == "chat":
from .chat.chat_model import run_chat
run_chat()
elif command == "eval":
raise NotImplementedError("Evaluation will be deprecated in the future.")
elif command == "export":
from .train.tuner import export_model
export_model()
elif command == "train":
from .train.tuner import run_exp
run_exp()
elif command == "webchat":
from .webui.interface import run_web_demo
run_web_demo()
elif command == "webui":
from .webui.interface import run_web_ui
run_web_ui()
elif command == "env":
print_env()
elif command == "version":
print(WELCOME)
elif command == "help":
print(USAGE)
else:
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__":
launch()
from llamafactory.train.tuner import run_exp # use absolute import
run_exp()
......@@ -16,10 +16,12 @@ import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.ktransformers import get_kt_peft_model, load_kt_peft_model
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
......@@ -147,7 +149,10 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
if is_trainable:
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
if finetuning_args.finetuning_type == "oft":
logger.info_rank0("Fine-tuning method: OFT")
else:
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
......@@ -161,6 +166,10 @@ def _setup_lora_tuning(
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_kt:
assert len(model_args.adapter_name_or_path) == 1, "KTransformers model only accepts a single adapter"
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
......@@ -179,6 +188,12 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}
if model_args.use_kt:
if model_args.infer_backend != EngineName.KT:
raise ValueError(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for adapter in adapter_to_merge:
model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
......@@ -187,7 +202,9 @@ def _setup_lora_tuning(
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
if model_args.use_kt:
model = load_kt_peft_model(model_args, model)
elif model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
......@@ -200,6 +217,16 @@ def _setup_lora_tuning(
else:
target_modules = finetuning_args.lora_target
if model_args.use_kt:
new_list = []
for m in target_modules:
if m in ("down_proj", "up_proj", "gate_proj"):
new_list.extend([f"mlp.{m}", f"shared_experts.{m}"])
elif m not in ("generate_linear", "orig_module", "prefill_linear"):
new_list.append(m)
target_modules[:] = new_list
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
......@@ -223,17 +250,43 @@ def _setup_lora_tuning(
finetuning_args.additional_target = module_names
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
if finetuning_args.finetuning_type == "lora":
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
elif finetuning_args.finetuning_type == "oft":
peft_kwargs = {
"r": finetuning_args.oft_rank,
"oft_block_size": finetuning_args.oft_block_size,
"target_modules": target_modules,
"module_dropout": finetuning_args.module_dropout,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_kt:
if finetuning_args.finetuning_type == "oft":
raise ValueError("KTransformers is currently not supported for OFT.")
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
else:
raise ValueError("KTransformers is currently only supported for LoRA.")
model = get_kt_peft_model(model, peft_config)
print(f"KT_model:{model}")
elif model_args.use_unsloth:
if finetuning_args.finetuning_type == "oft":
raise ValueError("Unsloth is currently not supported for OFT.")
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
if finetuning_args.pissa_init:
......@@ -244,12 +297,19 @@ def _setup_lora_tuning(
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if finetuning_args.finetuning_type == "lora":
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
elif finetuning_args.finetuning_type == "oft":
peft_config = OFTConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
model = get_peft_model(model, peft_config)
if is_trainable and cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
......@@ -272,8 +332,8 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if finetuning_args.finetuning_type not in ["lora", "oft"]:
raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
......@@ -296,7 +356,7 @@ def init_adapter(
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
elif finetuning_args.finetuning_type in ["lora", "oft"]:
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
......
......@@ -15,7 +15,6 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
......@@ -31,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
......@@ -143,7 +143,12 @@ def load_model(
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.use_kt:
from ktransformers.sft.monkey_patch_torch_module import install_patch
install_patch()
model = load_kt_pretrained_model(config, model_args)
elif model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
......@@ -152,17 +157,18 @@ def load_model(
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
init_kwargs["torch_dtype"] = "auto"
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
if type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen omni
load_class = AutoModelForTextToWaveform
else:
load_class = AutoModelForCausalLM
......@@ -171,8 +177,8 @@ def load_model(
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
if getattr(model.config, "model_type", None) == "qwen2_5_omni":
model = model.thinker # use part of Omni model
if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
model = getattr(model, "thinker")
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
......@@ -199,14 +205,21 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()
# Borrowing the kernel plugins ability of v1 to temporarily apply the NPU fusion operator to v0,
# it is turned off by default, and can be discarded after the transition period ends.
if model_args.use_v1_kernels and is_trainable:
logger.warning_rank0(
"You are try to using future feature about kernels, please note that this feature "
"is not supported for all models. If get any error, please disable this feature, or report the issue."
)
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = (
......
......@@ -14,10 +14,9 @@
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from ...extras import logging
from ...extras.constants import AttentionFunction
from ...extras.packages import is_torch_version_greater_than
if TYPE_CHECKING:
......@@ -30,6 +29,20 @@ logger = logging.get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
from transformers.utils import is_flash_attn_2_available
if getattr(config, "model_type", None) == "gpt_oss":
from transformers.integrations.hub_kernels import load_and_register_kernel
flash_attn3_kernel = "kernels-community/vllm-flash-attn3"
load_and_register_kernel(flash_attn3_kernel)
setattr(config, "_attn_implementation", flash_attn3_kernel)
setattr(config, "_attn_implementation_internal", flash_attn3_kernel)
model_args.flash_attn = AttentionFunction.FA3
logger.info_rank0("Using FlashAttention-3 with attention sink for the gpt-oss model.")
return
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
......@@ -51,13 +64,15 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
requested_attn_implementation = "eager"
elif model_args.flash_attn == AttentionFunction.SDPA:
if not is_torch_sdpa_available():
if not is_torch_version_greater_than("2.1.1"):
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == AttentionFunction.FA2:
if not is_flash_attn_2_available():
from transformers import is_torch_npu_available
if not (is_flash_attn_2_available() or is_torch_npu_available()):
logger.warning_rank0("FlashAttention-2 is not installed.")
return
......
......@@ -19,9 +19,11 @@
# limitations under the License.
import inspect
import os
from collections.abc import Callable
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
......@@ -152,6 +154,13 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if (
os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
and int(os.environ.get("FSDP_VERSION", "1")) == 2
):
model_args.use_reentrant_gc = False
logger.warning_rank0("You are using fsdp2, `use_reentrant_gc` has been set to False.")
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning_rank0("Current model does not support gradient checkpointing.")
......
......@@ -14,7 +14,7 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
......@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
"""
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
......@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""Resize token embeddings."""
def _description_based_initialization(
embed_weight: "torch.Tensor",
num_new_tokens: int,
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
For each new token, this function:
1. Tokenizes its description text
2. Gets embeddings of the description tokens
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
descriptions: Dict mapping token string to its description text
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
add_noise: Whether to add Gaussian noise to the initialization
Example:
descriptions = {
"<|START_OF_SVG|>": "Marks the beginning of an SVG document",
"<|END_OF_SVG|>": "Marks the end of an SVG document"
}
"""
embedding_dim = embed_weight.size(1)
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
def _initialize_embeddings(
embed_weight: "torch.Tensor",
num_new_tokens: int,
init_method: str,
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
) -> None:
"""Single source of truth for embedding initialization.
This function selects the appropriate initialization method and applies it.
Args:
embed_weight: The embedding weight matrix to initialize
num_new_tokens: Number of new tokens added
init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
)
else:
if init_method != "noise_init":
logger.warning_rank0(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
......@@ -64,7 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
model.config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util as _u
from typing import TYPE_CHECKING, Any
import torch
from ...extras import logging
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
KT_AVAILABLE = _u.find_spec("ktransformers") is not None
if KT_AVAILABLE:
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.server.config.config import Config
from ktransformers.sft.lora import inject_lora_layer
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
from ktransformers.util.globals import GLOBAL_CONFIG
from ktransformers.util.utils import load_weights
logger = logging.get_logger(__name__)
def _get_kt_kwargs(
config: "PretrainedConfig",
model_name_or_path: str,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"full_finetuning": finetuning_args.finetuning_type == "full",
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": model_args.trust_remote_code,
"use_gradient_checkpointing": "ktransformers",
}
def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
r"""Optionally load pretrained model with KTransformers. Used in training."""
custom_models = {
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
Config().cpu_infer = model_args.cpu_infer
Config().chunk_size = model_args.chunk_size
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
if model_args.mode == "long_context":
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
torch.set_default_dtype(torch.float16)
else:
torch.set_default_dtype(config.torch_dtype)
with torch.device("meta"):
if config.architectures[0] in custom_models:
print("using custom modeling_xxx.py.")
if "Qwen2Moe" in config.architectures[0]: # Qwen2Moe must use flash_attention_2 to avoid overflow.
config._attn_implementation = "flash_attention_2"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
config._attn_implementation = "flash_attention_2"
model = custom_models[config.architectures[0]](config)
else:
attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=True, attn_implementation=attn_implementation
)
optimize_config_path = model_args.kt_optimize_rule
gguf_path = model_args.model_name_or_path
assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."
GLOBAL_CONFIG._config["mod"] = "infer"
optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
return model
def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
from ktransformers.sft.peft_utils.mapping import get_peft_model
return get_peft_model(model, peft_kwargs)
def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
r"""Load peft model with KTransformers. Used in both training and inference."""
load_adapter_name_or_path = model_args.adapter_name_or_path[0]
if load_adapter_name_or_path.endswith(".gguf"):
inject_lora_layer(model, load_adapter_name_or_path)
adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
load_weights(model, adapter_gguf_loader, adapter_gguf=True)
model.train()
else:
inject_lora_layer(model, load_adapter_name_or_path)
adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
device = next(model.parameters()).device
for key in adapter_loader.tensor_file_map.keys():
try:
tensor = adapter_loader.load_tensor(key, device=device)
model_key = key.replace("base_model.model.", "")
model_key = model_key.replace(".weight", ".default.weight")
model_key = model_key.replace(".default.default.weight", ".default.weight")
param = model.get_parameter(model_key)
param.data.copy_(tensor.data)
print(f"Loaded adapter weight: {key} -> {model_key}")
except AttributeError:
print(f"Skipping {key}: not a model parameter")
except KeyError:
print(f"Key not found in model: {model_key} (original: {key})")
return model
......@@ -28,11 +28,11 @@ if TYPE_CHECKING:
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable:
setattr(config, "use_cache", model_args.use_cache)
setattr(config, "use_cache", model_args.use_kv_cache)
if hasattr(config, "text_config"):
setattr(config.text_config, "use_cache", model_args.use_cache)
setattr(config.text_config, "use_cache", model_args.use_kv_cache)
if model_args.use_cache:
if model_args.use_kv_cache:
logger.info_rank0("KV cache is enabled for faster generation.")
else:
logger.info_rank0("KV cache is disabled.")
......
......@@ -47,6 +47,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "glm4v":
from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama":
......@@ -75,6 +77,12 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
elif model_type == "qwen3_moe":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
elif model_type == "gpt_oss":
try:
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
except ImportError:
logger.warning_rank0("Please install liger-kernel from https://github.com/Comet0322/Liger-Kernel.")
return
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
......
......@@ -14,9 +14,13 @@
from typing import TYPE_CHECKING, Union
import torch
from torch import nn
from torch.nn import functional as F
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
......@@ -25,6 +29,9 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
if is_transformers_version_greater_than("4.57.0"):
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
check_version("deepspeed>=0.13.0")
......@@ -39,6 +46,9 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
return
model_type = getattr(model.config, "model_type", None)
text_config = getattr(model.config, "text_config", None)
text_model_type = getattr(text_config, "model_type", None)
if model_type == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
......@@ -52,11 +62,31 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
# deepseek v3 and kimi vl use custom code
_set_z3_leaf_modules(model, ["DeepseekV3MoE"])
if model_type == "ernie4_5_moe":
from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])
if model_type == "granitemoe":
from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
_set_z3_leaf_modules(model, [GraniteMoeMoE])
if model_type == "glm4_moe":
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE
_set_z3_leaf_modules(model, [Glm4MoeMoE])
if model_type == "glm4v_moe":
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextMoE
_set_z3_leaf_modules(model, [Glm4vMoeTextMoE])
if model_type == "gpt_oss":
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
_set_z3_leaf_modules(model, [GptOssMLP])
if model_type == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
......@@ -92,19 +122,32 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
if model_type == "qwen3_moe":
if model_type == "qwen3_moe" or text_model_type == "qwen3_moe": # internvl 3.5
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
if model_type == "qwen3_vl_moe":
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
if model_type in ("qwen3_omni_moe", "qwen3_omni_moe_thinker"):
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:
return
model_type = getattr(config, "model_type", None)
text_config = getattr(config, "text_config", None) # for multimodal model
if model_type in [
"dbrx",
"ernie4_5_moe",
"granitemoe",
"jamba",
"jetmoe",
......@@ -117,11 +160,93 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
]:
setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
if text_config and getattr(text_config, "model_type", None) in [
"glm4v_moe_text", # glmv4_5
"qwen3_moe", # internvl_3_5
]:
setattr(text_config, "output_router_logits", True)
if model_type in [
"ernie4_5_moe",
"granitemoe",
"jamba",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
setattr(text_config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
# gating
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = nn.ModuleList(
[
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
config, intermediate_size=config.moe_intermediate_size
)
for _ in range(self.num_experts)
]
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
# Calculate the routing weights for all experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# Retain the weight of the top_k and reset the rest of the expert rights to 0 (instead of retaining only top_k experts)
top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
# Initialize the all-zero weight matrix (same shape as all experts)
full_routing_weights = torch.zeros_like(routing_weights)
# Only the weight of top_k experts is retained, and the weight of the rest of the experts remains at 0
full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
# Normalized top_k weights (keep the original logic consistent)
if self.norm_topk_prob:
# Calculate the sum of the weights top_k each row (for normalization)
top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
# Avoid dividing by zero
top_k_sum = torch.clamp(top_k_sum, min=1e-9)
full_routing_weights /= top_k_sum
# Convert back to the input data type
full_routing_weights = full_routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# Go through all the experts (not just the selected ones)
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
# Get the weight of the current expert (inactive expert has a weight of 0 here)
expert_weights = full_routing_weights[:, expert_idx, None] # shape: (batch*seq, 1)
# All samples participate in the calculations of the current expert, the weight may be equal to 0
current_hidden_states = expert_layer(hidden_states) * expert_weights
# Add-up to all expert outputs (experts with a weight of 0 do not affect the result)
final_hidden_states += current_hidden_states
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
......@@ -53,7 +53,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""Get the sequnce lengths in the current batch.
r"""Get the sequence lengths in the current batch.
e.g.
```python
......
......@@ -83,6 +83,7 @@ def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool,
init_kwargs: dict[str, Any],
) -> None:
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
......@@ -90,12 +91,29 @@ def configure_quantization(
if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
):
# mxfp4 will dequant the model weights
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
if quant_method == QuantizationMethod.MXFP4:
from transformers import Mxfp4Config
quant_config = Mxfp4Config(dequantize=True)
init_kwargs["quantization_config"] = quant_config
init_kwargs["ignore_mismatched_sizes"] = True
if quant_method == QuantizationMethod.FP8:
from transformers import FineGrainedFP8Config
quant_config = FineGrainedFP8Config(dequantize=True)
init_kwargs["quantization_config"] = quant_config
init_kwargs["ignore_mismatched_sizes"] = True
if quant_method == QuantizationMethod.GPTQ:
check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args
......
......@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") ->
logger.warning_rank0("Current model does not support RoPE scaling.")
return
if hasattr(config, "max_position_embeddings"):
rope_scaling = getattr(config, "rope_scaling", None)
if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
old_max_length = rope_scaling["original_max_position_embeddings"]
elif hasattr(config, "max_position_embeddings"):
old_max_length = getattr(config, "max_position_embeddings", None)
else:
logger.warning_rank0("Cannot find the max position embeddings in the config.")
......
......@@ -199,6 +199,15 @@ def patch_target_modules(
return target_modules
_register_composite_model(
model_type="dots_ocr",
projector_key="vision_tower.merger",
vision_model_keys=["vision_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["merger"],
)
_register_composite_model(
model_type="gemma3",
)
......@@ -221,10 +230,36 @@ _register_composite_model(
)
_register_composite_model(
model_type="glm4v_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="internvl",
)
_register_composite_model(
model_type="interns1",
)
_register_composite_model(
model_type="Keye",
projector_key="mlp_AR",
vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embedding"],
)
_register_composite_model(
model_type="kimi_vl",
)
_register_composite_model(
model_type="llama4",
......@@ -263,8 +298,10 @@ _register_composite_model(
lora_conflict_keys=["audio_projection_layer"],
)
_register_composite_model(
model_type="mistral3",
projector_key="model.multi_modal_projector",
)
......@@ -316,6 +353,33 @@ _register_composite_model(
)
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="video_llava",
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment