OpenDAS / LLaMA-Factory · Commits

Commit ca625f43, authored Mar 30, 2026 by shihm
Parent: 7164651d

    uodata

Showing 20 changed files with 1241 additions and 161 deletions (+1241, -161)
Changed files:

    src/llamafactory/hparams/data_args.py                  +14   -14
    src/llamafactory/hparams/evaluation_args.py             +2    -2
    src/llamafactory/hparams/finetuning_args.py             +98   -28
    src/llamafactory/hparams/model_args.py                  +182  -29
    src/llamafactory/hparams/parser.py                      +76   -21
    src/llamafactory/hparams/training_args.py               +32   -9
    src/llamafactory/launcher.py                            +165  -3
    src/llamafactory/model/adapter.py                       +82   -22
    src/llamafactory/model/loader.py                        +25   -12
    src/llamafactory/model/model_utils/attention.py         +19   -4
    src/llamafactory/model/model_utils/checkpointing.py     +10   -1
    src/llamafactory/model/model_utils/embedding.py         +154  -6
    src/llamafactory/model/model_utils/ktransformers.py     +154  -0
    src/llamafactory/model/model_utils/kv_cache.py          +3    -3
    src/llamafactory/model/model_utils/liger_kernel.py      +8    -0
    src/llamafactory/model/model_utils/moe.py               +127  -2
    src/llamafactory/model/model_utils/packing.py           +1    -1
    src/llamafactory/model/model_utils/quantization.py      +21   -3
    src/llamafactory/model/model_utils/rope.py              +4    -1
    src/llamafactory/model/model_utils/visual.py            +64   -0
src/llamafactory/hparams/data_args.py (view file @ ca625f43)

@@ -16,22 +16,22 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal


 @dataclass
 class DataArguments:
     r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

-    template: Optional[str] = field(
+    template: str | None = field(
         default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
     )
-    eval_dataset: Optional[str] = field(
+    eval_dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
     )
...
@@ -39,7 +39,7 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    media_dir: Optional[str] = field(
+    media_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
     )
...
@@ -67,7 +67,7 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
     )
-    interleave_probs: Optional[str] = field(
+    interleave_probs: str | None = field(
         default=None,
         metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
     )
...
@@ -79,15 +79,15 @@ class DataArguments:
         default=1000,
         metadata={"help": "The number of examples in one group in pre-processing."},
     )
-    preprocessing_num_workers: Optional[int] = field(
+    preprocessing_num_workers: int | None = field(
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    max_samples: Optional[int] = field(
+    max_samples: int | None = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
     )
-    eval_num_beams: Optional[int] = field(
+    eval_num_beams: int | None = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
     )
...
@@ -103,7 +103,7 @@ class DataArguments:
         default=False,
         metadata={"help": "Whether or not to evaluate on each dataset separately."},
     )
-    packing: Optional[bool] = field(
+    packing: bool | None = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
     )
...
@@ -111,19 +111,19 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
-    tool_format: Optional[str] = field(
+    tool_format: str | None = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
-    default_system: Optional[str] = field(
+    default_system: str | None = field(
         default=None,
         metadata={"help": "Override the default system message in the template."},
     )
-    enable_thinking: Optional[bool] = field(
+    enable_thinking: bool | None = field(
         default=True,
         metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
     )
-    tokenized_path: Optional[str] = field(
+    tokenized_path: str | None = field(
         default=None,
         metadata={
             "help": (
...
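The change repeated throughout this file (and most of the commit) swaps `typing.Optional[X]` for the PEP 604 union syntax `X | None`. A minimal sketch, not part of the commit, of why the two annotations are interchangeable — noting that this syntax requires Python >= 3.10 when the annotations are evaluated at runtime, as HfArgumentParser does for these dataclasses:

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        # `str | None` is evaluated at class-definition time here, so Python >= 3.10 is needed.
        template: str | None = field(default=None)  # equivalent to Optional[str]

    print(Example())  # Example(template=None)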
src/llamafactory/hparams/evaluation_args.py (view file @ ca625f43)

@@ -14,7 +14,7 @@
 import os
 from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import Literal

 from datasets import DownloadMode
...
@@ -46,7 +46,7 @@ class EvaluationArguments:
         default=5,
         metadata={"help": "Number of examplars for few-shot learning."},
     )
-    save_dir: Optional[str] = field(
+    save_dir: str | None = field(
         default=None,
         metadata={"help": "Path to save the evaluation results."},
     )
...
src/llamafactory/hparams/finetuning_args.py (view file @ ca625f43)

@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal


 @dataclass
...
@@ -40,7 +40,7 @@ class FreezeArguments:
             )
         },
     )
-    freeze_extra_modules: Optional[str] = field(
+    freeze_extra_modules: str | None = field(
         default=None,
         metadata={
             "help": (
...
@@ -56,7 +56,7 @@ class FreezeArguments:
 class LoraArguments:
     r"""Arguments pertaining to the LoRA training."""

-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
...
@@ -66,7 +66,7 @@ class LoraArguments:
             )
         },
     )
-    lora_alpha: Optional[int] = field(
+    lora_alpha: int | None = field(
         default=None,
         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
     )
...
@@ -88,7 +88,7 @@ class LoraArguments:
             )
         },
     )
-    loraplus_lr_ratio: Optional[float] = field(
+    loraplus_lr_ratio: float | None = field(
         default=None,
         metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
     )
...
@@ -122,6 +122,48 @@ class LoraArguments:
     )


+@dataclass
+class OFTArguments:
+    r"""Arguments pertaining to the OFT training."""
+
+    additional_target: str | None = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of modules apart from LoRA layers to be set as trainable "
+                "and saved in the final checkpoint. "
+                "Use commas to separate multiple modules."
+            )
+        },
+    )
+    module_dropout: float = field(
+        default=0.0,
+        metadata={"help": "Dropout rate for the OFT fine-tuning."},
+    )
+    oft_rank: int = field(
+        default=0,
+        metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
+    )
+    oft_block_size: int = field(
+        default=32,
+        metadata={"help": "The intrinsic dimension for OFT fine-tuning."},
+    )
+    oft_target: str = field(
+        default="all",
+        metadata={
+            "help": (
+                "Name(s) of target modules to apply OFT. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
+        },
+    )
+    create_new_adapter: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
+    )
+
+
 @dataclass
 class RLHFArguments:
     r"""Arguments pertaining to the PPO, DPO and KTO training."""
...
@@ -134,6 +176,10 @@ class RLHFArguments:
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
+    pref_bco_weight: float = field(
+        default=0.0,
+        metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
+    )
     pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
...
@@ -174,27 +220,27 @@ class RLHFArguments:
         default=False,
         metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
-    ref_model: Optional[str] = field(
+    ref_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reference model used for the PPO or DPO training."},
     )
-    ref_model_adapters: Optional[str] = field(
+    ref_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reference model."},
     )
-    ref_model_quantization_bit: Optional[int] = field(
+    ref_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reference model."},
     )
-    reward_model: Optional[str] = field(
+    reward_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reward model used for the PPO training."},
     )
-    reward_model_adapters: Optional[str] = field(
+    reward_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reward model."},
     )
-    reward_model_quantization_bit: Optional[int] = field(
+    reward_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."},
     )
...
@@ -202,7 +248,7 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
-    ld_alpha: Optional[float] = field(
+    ld_alpha: float | None = field(
         default=None,
         metadata={
             "help": (
...
@@ -315,15 +361,15 @@ class BAdamArgument:
         default="layer",
         metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
     )
-    badam_start_block: Optional[int] = field(
+    badam_start_block: int | None = field(
         default=None,
         metadata={"help": "The starting block index for layer-wise BAdam."},
     )
-    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+    badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
         default="ascending",
         metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
     )
-    badam_switch_interval: Optional[int] = field(
+    badam_switch_interval: int | None = field(
         default=50,
         metadata={
             "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
...
@@ -360,15 +406,15 @@ class SwanLabArguments:
         default=False,
         metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
     )
-    swanlab_project: Optional[str] = field(
+    swanlab_project: str | None = field(
         default="llamafactory",
         metadata={"help": "The project name in SwanLab."},
     )
-    swanlab_workspace: Optional[str] = field(
+    swanlab_workspace: str | None = field(
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
     )
-    swanlab_run_name: Optional[str] = field(
+    swanlab_run_name: str | None = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
...
@@ -376,19 +422,19 @@ class SwanLabArguments:
         default="cloud",
         metadata={"help": "The mode of SwanLab."},
     )
-    swanlab_api_key: Optional[str] = field(
+    swanlab_api_key: str | None = field(
         default=None,
         metadata={"help": "The API key for SwanLab."},
     )
-    swanlab_logdir: Optional[str] = field(
+    swanlab_logdir: str | None = field(
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
-    swanlab_lark_webhook_url: Optional[str] = field(
+    swanlab_lark_webhook_url: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
     )
-    swanlab_lark_secret: Optional[str] = field(
+    swanlab_lark_secret: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) secret for SwanLab."},
     )
...
@@ -396,7 +442,14 @@ class SwanLabArguments:
 @dataclass
 class FinetuningArguments(
-    SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
+    SwanLabArguments,
+    BAdamArgument,
+    ApolloArguments,
+    GaloreArguments,
+    RLHFArguments,
+    LoraArguments,
+    OFTArguments,
+    FreezeArguments,
 ):
     r"""Arguments pertaining to which techniques we are going to fine-tuning with."""
...
@@ -408,7 +461,7 @@ class FinetuningArguments(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
-    finetuning_type: Literal["lora", "freeze", "full"] = field(
+    finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
...
@@ -420,10 +473,23 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_mca: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to use MCA (Megatron Core Adapter) training. "
+                "Controlled by USE_MCA environment variable."
+            )
+        },
+    )
     use_muon: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
     )
+    use_dft_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the DFT loss."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
...
@@ -444,7 +510,7 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
-    early_stopping_steps: Optional[int] = field(
+    early_stopping_steps: int | None = field(
         default=None,
         metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
     )
...
@@ -464,15 +530,16 @@ class FinetuningArguments(
             return arg

         self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
+        self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
-        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
+        self.oft_target: list[str] = split_arg(self.oft_target)
+        self.additional_target: list[str] | None = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
         self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

-        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
...
@@ -482,6 +549,9 @@ class FinetuningArguments(
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

+        if self.stage == "ppo" and self.reward_model_type == "oft" and self.finetuning_type != "oft":
+            raise ValueError("`reward_model_type` cannot be oft for Freeze/Full PPO training.")
+
         if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
             raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
...
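Not part of the commit: a minimal sketch of how the new OFT options defined above might be supplied to LLaMA-Factory's training entry point. Model, dataset, and output paths are placeholders, and since PEFT's OFTConfig expects only one of the rank / block-size knobs to be set, the default `oft_rank=0` is kept here.

    # Hypothetical training arguments exercising finetuning_type="oft" (illustrative values).
    args = {
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        "stage": "sft",
        "do_train": True,
        "finetuning_type": "oft",     # new option introduced by this commit
        "oft_block_size": 32,         # oft_rank stays at its default of 0
        "oft_target": "all",          # apply OFT to all linear modules
        "module_dropout": 0.0,
        "dataset": "alpaca_en_demo",  # placeholder dataset name
        "template": "llama3",
        "output_dir": "saves/llama3-8b-oft",
    }

    from llamafactory.train.tuner import run_exp
    run_exp(args)  # goes through get_train_args(), whose assertion now accepts "oft"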
src/llamafactory/hparams/model_args.py (view file @ ca625f43)

-# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -17,26 +17,30 @@
 import json
 from dataclasses import asdict, dataclass, field, fields
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Self

 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
-from typing_extensions import Self

 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
+from ..extras.logging import get_logger
+
+
+logger = get_logger(__name__)


 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""

-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
         },
     )
-    adapter_name_or_path: Optional[str] = field(
+    adapter_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": (
...
@@ -45,11 +49,11 @@ class BaseModelArguments:
             )
         },
     )
-    adapter_folder: Optional[str] = field(
+    adapter_folder: str | None = field(
         default=None,
         metadata={"help": "The folder containing the adapter weights to load."},
     )
-    cache_dir: Optional[str] = field(
+    cache_dir: str | None = field(
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
...
@@ -65,16 +69,38 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    add_tokens: Optional[str] = field(
+    add_tokens: str | None = field(
         default=None,
         metadata={"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
-    add_special_tokens: Optional[str] = field(
+    add_special_tokens: str | None = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_special_tokens_config: str | None = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to YAML config with special token descriptions for semantic initialization. "
+                "If set, this takes precedence over add_special_tokens. "
+                "YAML format: {'<token>': 'description text', ...}"
+            )
+        },
+    )
+    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
+        default="noise_init",
+        metadata={
+            "help": (
+                "Initialization method for new special tokens: "
+                "'noise_init' (default, random noise around mean), "
+                "'desc_init' (semantic initialization from descriptions), "
+                "'desc_init_w_noise' (semantic + random noise). "
+                "Note: 'desc_init' methods require new_special_tokens_config."
+            )
+        },
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
...
@@ -83,7 +109,7 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[RopeScaling] = field(
+    rope_scaling: RopeScaling | None = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
...
@@ -95,7 +121,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+    mixture_of_depths: Literal["convert", "load"] | None = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
     )
...
@@ -111,7 +137,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    moe_aux_loss_coef: Optional[float] = field(
+    moe_aux_loss_coef: float | None = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
     )
...
@@ -143,23 +169,27 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
+    use_v1_kernels: bool | None = field(
+        default=False,
+        metadata={"help": "Whether or not to use high-performance kernels in training."},
+    )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
         metadata={"help": "Data type for model weights and activations at inference."},
     )
-    hf_hub_token: Optional[str] = field(
+    hf_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
     )
-    ms_hub_token: Optional[str] = field(
+    ms_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
-    om_hub_token: Optional[str] = field(
+    om_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Modelers Hub."},
     )
...
@@ -185,8 +215,63 @@ class BaseModelArguments:
         if self.add_tokens is not None:  # support multiple tokens
             self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]

-        if self.add_special_tokens is not None:  # support multiple special tokens
+        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
+        if self.new_special_tokens_config is not None:
+            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
+            try:
+                cfg = OmegaConf.load(self.new_special_tokens_config)
+                token_descriptions = OmegaConf.to_container(cfg)
+                if not isinstance(token_descriptions, dict):
+                    raise ValueError(
+                        f"YAML config must be a dictionary mapping tokens to descriptions. "
+                        f"Got: {type(token_descriptions)}"
+                    )
+
+                # Extract token list from config keys
+                extracted_tokens = list(token_descriptions.keys())
+
+                # Warn if both are set
+                if self.add_special_tokens is not None:
+                    logger.warning_rank0(
+                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
+                        f"Using tokens from config: {extracted_tokens}"
+                    )
+
+                # Override add_special_tokens with extracted tokens (as list)
+                self.add_special_tokens = extracted_tokens
+
+                # Store descriptions internally for later use (internal attribute)
+                self._special_token_descriptions = token_descriptions
+
+                logger.info_rank0(
+                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
+                    f"{self.new_special_tokens_config}"
+                )
+            except Exception as e:
+                logger.error_rank0(f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}")
+                raise
+        elif self.add_special_tokens is not None:
+            # Priority 2: Use simple comma-separated string (no descriptions)
             self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+            self._special_token_descriptions = None
+        else:
+            # No special tokens to add
+            self._special_token_descriptions = None
+
+        # Validate init method
+        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
+            if self._special_token_descriptions is None:
+                logger.warning_rank0(
+                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
+                    "Falling back to 'noise_init'"
+                )
+                self.init_special_tokens = "noise_init"


 @dataclass
...
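Not part of the commit: a small sketch of the YAML layout the new `new_special_tokens_config` option describes, together with the same OmegaConf round-trip the `__post_init__` block above performs. Token names, descriptions, and the file path are made up.

    from omegaconf import OmegaConf

    # Hypothetical config mapping each new special token to a description used for semantic init.
    yaml_text = (
        '"<tool_call>": "Marks the start of a structured function call emitted by the model."\n'
        '"<observation>": "Wraps tool output that is fed back to the model."\n'
    )
    with open("special_tokens.yaml", "w", encoding="utf-8") as f:
        f.write(yaml_text)

    cfg = OmegaConf.load("special_tokens.yaml")
    token_descriptions = OmegaConf.to_container(cfg)    # dict: token -> description
    extracted_tokens = list(token_descriptions.keys())  # what add_special_tokens is overridden with
    print(extracted_tokens)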
@@ -197,7 +282,7 @@ class QuantizationArguments:
         default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
-    quantization_bit: Optional[int] = field(
+    quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
     )
...
@@ -209,10 +294,27 @@ class QuantizationArguments:
         default=True,
         metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
     )
-    quantization_device_map: Optional[Literal["auto"]] = field(
+    quantization_device_map: Literal["auto"] | None = field(
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
+    fp8: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
+            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
+        },
+    )
+    fp8_backend: str = field(
+        default="auto",
+        metadata={
+            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
+        },
+    )
+    fp8_enable_fsdp_float8_all_gather: bool = field(
+        default=False,
+        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
+    )
...
@@ -272,7 +374,7 @@ class ProcessorArguments:
 class ExportArguments:
     r"""Arguments pertaining to the model export."""

-    export_dir: Optional[str] = field(
+    export_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
...
@@ -284,11 +386,11 @@ class ExportArguments:
         default="cpu",
         metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
     )
-    export_quantization_bit: Optional[int] = field(
+    export_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the exported model."},
     )
-    export_quantization_dataset: Optional[str] = field(
+    export_quantization_dataset: str | None = field(
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
...
@@ -304,7 +406,7 @@ class ExportArguments:
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
-    export_hub_model_id: Optional[str] = field(
+    export_hub_model_id: str | None = field(
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
...
@@ -334,7 +436,7 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
-    vllm_config: Optional[Union[dict, str]] = field(
+    vllm_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
     )
...
@@ -360,7 +462,7 @@ class SGLangArguments:
         default=-1,
         metadata={"help": "Tensor parallel size for the SGLang engine."},
     )
-    sglang_config: Optional[Union[dict, str]] = field(
+    sglang_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
...
@@ -376,26 +478,77 @@ class SGLangArguments:
             self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))


+@dataclass
+class KTransformersArguments:
+    r"""Arguments pertaining to the KT training."""
+
+    use_kt: bool = field(
+        default=False,
+        metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
+    )
+    kt_optimize_rule: str | None = field(
+        default=None,
+        metadata={"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."},
+    )
+    cpu_infer: int | None = field(
+        default=32,
+        metadata={"help": "Number Of CPU Cores Used For Computation."},
+    )
+    chunk_size: int | None = field(
+        default=8192,
+        metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
+    )
+    mode: str | None = field(
+        default="normal",
+        metadata={"help": "Normal Or Long_Context For Llama Models."},
+    )
+    kt_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
+    )
+    kt_use_cuda_graph: bool = field(
+        default=True,
+        metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
+    )
+    kt_mode: str = field(
+        default="normal",
+        metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
+    )
+    kt_force_think: bool = field(
+        default=False,
+        metadata={"help": "Force-Think Toggle For The KT Engine."},
+    )
+
+
 @dataclass
 class ModelArguments(
-    SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+    SGLangArguments,
+    VllmArguments,
+    KTransformersArguments,
+    ExportArguments,
+    ProcessorArguments,
+    QuantizationArguments,
+    BaseModelArguments,
 ):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

     The class on the most right will be displayed first.
     """

-    compute_dtype: Optional[torch.dtype] = field(
+    compute_dtype: torch.dtype | None = field(
         default=None,
         init=False,
         metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
     )
-    device_map: Optional[Union[str, dict[str, Any]]] = field(
+    device_map: str | dict[str, Any] | None = field(
         default=None,
         init=False,
         metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
     )
-    model_max_length: Optional[int] = field(
+    model_max_length: int | None = field(
         default=None,
         init=False,
         metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
...
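Not part of the commit: an illustrative sketch of the KTransformers-related fields added above as they might appear in a training-arguments dict. The optimize-rule path is a placeholder; the rule files themselves ship with ktransformers.

    kt_args = {
        "use_kt": True,                                    # enable KTransformers optimizations for LoRA training
        "kt_optimize_rule": "path/to/optimize_rule.yaml",  # placeholder path to a KTransformers optimize rule
        "cpu_infer": 32,                                   # CPU cores used for the CPU-side compute
        "chunk_size": 8192,                                # chunk size for CPU compute
        "kt_maxlen": 4096,                                 # max prompt + response length for the KT engine
        "kt_use_cuda_graph": True,                         # use CUDA graphs in the KT engine
    }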
src/llamafactory/hparams/parser.py (view file @ ca625f43)

@@ -18,7 +18,7 @@
 import os
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import torch
 import transformers
...
@@ -32,6 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
+from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
...
@@ -52,8 +53,19 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]

+if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
+    from mcore_adapter import TrainingArguments as McaTrainingArguments
+
+    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+else:
+    _TRAIN_MCA_ARGS = []
+    _TRAIN_MCA_CLS = tuple()
+

-def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
+def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:
         return args
...
@@ -71,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
+    parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
 ) -> tuple[Any]:
     args = read_args(args)
     if isinstance(args, dict):
...
@@ -111,8 +123,8 @@ def _verify_model_args(
             raise ValueError("Adapter is only valid for the LoRA method.")

     if model_args.quantization_bit is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantization is only compatible with the LoRA method.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantization is only compatible with the LoRA or OFT method.")

         if finetuning_args.pissa_init:
             raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
...
@@ -130,12 +142,23 @@ def _verify_model_args(
             logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
             model_args.use_fast_tokenizer = False

+    # Validate advanced training features
+    if model_args.fp8 and model_args.quantization_bit is not None:
+        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
+
+    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
+        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
+        model_args.fp8 = True
+

 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
+    if model_args.use_kt:
+        check_version("ktransformers", mandatory=True)
+
     if model_args.use_unsloth:
         check_version("unsloth", mandatory=True)
...
@@ -146,7 +169,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.9.1")
+        check_version("vllm>=0.4.3,<=0.11.0")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.5")
...
@@ -173,7 +196,8 @@ def _check_extra_dependencies(
     if training_args is not None:
         if training_args.deepspeed:
             # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
-            check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
+            check_version("deepspeed", mandatory=True)
+            check_version("deepspeed>=0.10.0,<=0.16.9")

         if training_args.predict_with_generate:
             check_version("jieba", mandatory=True)
...
@@ -181,32 +205,57 @@ def _check_extra_dependencies(
         check_version("rouge_chinese", mandatory=True)


-def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
+    parser = HfArgumentParser(_TRAIN_MCA_ARGS)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
+        parser, args, allow_extra_keys=allow_extra_keys
+    )
+    _configure_mca_training_args(training_args, data_args, finetuning_args)
+    return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
+    """Patch training args to avoid args checking errors and sync MCA settings."""
+    training_args.predict_with_generate = False
+    training_args.generation_max_length = data_args.cutoff_len
+    training_args.generation_num_beams = 1
+    training_args.use_mca = True
+    finetuning_args.use_mca = True
+
+
+def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)


-def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
+def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
     parser = HfArgumentParser(RayArguments)
     (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
     return ray_args


-def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
+    if is_env_enabled("USE_MCA"):
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
+    else:
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+        finetuning_args.use_mca = False

     # Setup logging
     if training_args.should_log:
...
@@ -236,13 +285,16 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.shift_attn:
             raise ValueError("PPO training is incompatible with S^2-Attn.")

+        if finetuning_args.reward_model_type == "lora" and model_args.use_kt:
+            raise ValueError("KTransformers does not support lora reward model.")
+
         if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
             raise ValueError("Unsloth does not support lora reward model.")

         if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
             raise ValueError("PPO only accepts wandb or tensorboard logger.")

-    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+    if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
         raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")

     if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
...
@@ -254,18 +306,15 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.do_train and data_args.dataset is None:
         raise ValueError("Please specify dataset for training.")

-    if (training_args.do_eval or training_args.do_predict) and (
+    if (training_args.do_eval or training_args.do_predict or training_args.predict_with_generate) and (
         data_args.eval_dataset is None and data_args.val_size < 1e-6
     ):
-        raise ValueError("Please specify dataset for evaluation.")
+        raise ValueError("Please make sure eval_dataset be provided or val_size >1e-6")

     if training_args.predict_with_generate:
         if is_deepspeed_zero3_enabled():
             raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

-        if data_args.eval_dataset is None:
-            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
-
         if finetuning_args.compute_accuracy:
             raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
...
@@ -304,6 +353,12 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    if model_args.use_kt and is_deepspeed_zero3_enabled():
+        raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
+
+    if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
+        raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
+
     _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
...
@@ -418,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, training_args, finetuning_args, generating_args


-def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     # Setup logging
...
@@ -453,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, finetuning_args, generating_args


-def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     # Setup logging
...
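Not part of the commit: a sketch of how the USE_MCA switch wired in above selects the mcore_adapter parsing path. It assumes `mcore_adapter` is installed, and the commented-out call stands in for real training arguments.

    import os

    os.environ["USE_MCA"] = "1"  # must be set before llamafactory modules are imported

    from llamafactory.hparams.parser import get_train_args

    # With USE_MCA=1, get_train_args() routes through _parse_train_mca_args(), which
    # swaps in mcore_adapter's TrainingArguments and sets finetuning_args.use_mca = True.
    # model_args, data_args, training_args, finetuning_args, generating_args = get_train_args({...})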
src/llamafactory/hparams/training_args.py (view file @ ca625f43)

@@ -14,19 +14,33 @@
 import json
 from dataclasses import dataclass, field
-from typing import Literal, Optional, Union
+from typing import Literal

 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict

-from ..extras.misc import use_ray
+from ..extras.misc import is_env_enabled, use_ray
+from ..extras.packages import is_mcore_adapter_available
+
+
+if is_env_enabled("USE_MCA"):
+    if not is_mcore_adapter_available():
+        raise ImportError("mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies.")
+
+    from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+    BaseTrainingArguments = McaSeq2SeqTrainingArguments
+else:
+    BaseTrainingArguments = Seq2SeqTrainingArguments


 @dataclass
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""

-    ray_run_name: Optional[str] = field(
+    ray_run_name: str | None = field(
         default=None,
         metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
     )
...
@@ -34,7 +48,7 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
-    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
         default=None,
         metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
     )
...
@@ -42,7 +56,7 @@ class RayArguments:
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: Union[dict, str] = field(
+    resources_per_worker: dict | str = field(
         default_factory=lambda: {"GPU": 1},
         metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
     )
...
@@ -50,7 +64,7 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
-    ray_init_kwargs: Optional[dict] = field(
+    ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
...
@@ -59,10 +73,14 @@ class RayArguments:
         self.use_ray = use_ray()
         if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
             self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))

+        if isinstance(self.ray_init_kwargs, str) and self.ray_init_kwargs.startswith("{"):
+            self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
+
         if self.ray_storage_filesystem is not None:
             if self.ray_storage_filesystem not in ["s3", "gs", "gcs"]:
                 raise ValueError(
-                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
+                    f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}."
                 )

             import pyarrow.fs as fs
...
@@ -74,9 +92,14 @@ class RayArguments:
 @dataclass
-class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+class TrainingArguments(RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""

+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
-        Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
+        BaseTrainingArguments.__post_init__(self)
src/llamafactory/launcher.py (view file @ ca625f43)

@@ -12,12 +12,174 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from llamafactory.train.tuner import run_exp  # use absolute import
+import os
+import subprocess
+import sys
+from copy import deepcopy
+
+from .extras import logging
+from .extras.env import VERSION, print_env
+from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_kt, use_ray
+
+
+logger = logging.get_logger(__name__)
+
+USAGE = (
+    "-" * 70
+    + "\n"
+    + "| Usage:                                                             |\n"
+    + "|   llamafactory-cli api -h: launch an OpenAI-style API server       |\n"
+    + "|   llamafactory-cli chat -h: launch a chat interface in CLI         |\n"
+    + "|   llamafactory-cli export -h: merge LoRA adapters and export model |\n"
+    + "|   llamafactory-cli train -h: train models                          |\n"
+    + "|   llamafactory-cli webchat -h: launch a chat interface in Web UI   |\n"
+    + "|   llamafactory-cli webui: launch LlamaBoard                        |\n"
+    + "|   llamafactory-cli env: show environment info                      |\n"
+    + "|   llamafactory-cli version: show version info                      |\n"
+    + "| Hint: You can use `lmf` as a shortcut for `llamafactory-cli`.      |\n"
+    + "-" * 70
+)
+
+WELCOME = (
+    "-" * 58
+    + "\n"
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
+    + " " * (21 - len(VERSION))
+    + "|\n|"
+    + " " * 56
+    + "|\n"
+    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+    + "-" * 58
+)


 def launch():
-    run_exp()
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if is_env_enabled("USE_MCA"):
+        # force use torchrun
+        os.environ["FORCE_TORCHRUN"] = "1"
+
+    if command == "train" and (
+        is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray() and not use_kt())
+    ):
+        # launch distributed training
+        nnodes = os.getenv("NNODES", "1")
+        node_rank = os.getenv("NODE_RANK", "0")
+        nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
+        master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
+        logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+        if int(nnodes) > 1:
+            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
+        # elastic launch support
+        max_restarts = os.getenv("MAX_RESTARTS", "0")
+        rdzv_id = os.getenv("RDZV_ID")
+        min_nnodes = os.getenv("MIN_NNODES")
+        max_nnodes = os.getenv("MAX_NNODES")
+
+        env = deepcopy(os.environ)
+        if is_env_enabled("OPTIM_TORCH", "1"):
+            # optimize DDP, see https://zhuanlan.zhihu.com/p/671834539
+            env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            env["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+        if rdzv_id is not None:
+            # launch elastic job with fault tolerant support when possible
+            # see also https://docs.pytorch.org/docs/stable/elastic/train_script.html
+            rdzv_nnodes = nnodes
+            # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
+            if min_nnodes is not None and max_nnodes is not None:
+                rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
+
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
+                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
+                    "--max-restarts {max_restarts} {file_name} {args}"
+                )
+                .format(
+                    rdzv_nnodes=rdzv_nnodes,
+                    nproc_per_node=nproc_per_node,
+                    rdzv_id=rdzv_id,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    max_restarts=max_restarts,
+                    file_name=__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
+        else:
+            # NOTE: DO NOT USE shell=True to avoid security risk
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
+                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
+                )
+                .format(
+                    nnodes=nnodes,
+                    node_rank=node_rank,
+                    nproc_per_node=nproc_per_node,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    file_name=__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
+
+        sys.exit(process.returncode)
+    elif command == "api":
+        from .api.app import run_api
+
+        run_api()
+    elif command == "chat":
+        from .chat.chat_model import run_chat
+
+        run_chat()
+    elif command == "eval":
+        raise NotImplementedError("Evaluation will be deprecated in the future.")
+    elif command == "export":
+        from .train.tuner import export_model
+
+        export_model()
+    elif command == "train":
+        from .train.tuner import run_exp
+
+        run_exp()
+    elif command == "webchat":
+        from .webui.interface import run_web_demo
+
+        run_web_demo()
+    elif command == "webui":
+        from .webui.interface import run_web_ui
+
+        run_web_ui()
+    elif command == "env":
+        print_env()
+    elif command == "version":
+        print(WELCOME)
+    elif command == "help":
+        print(USAGE)
+    else:
+        print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
-    launch()
+    from llamafactory.train.tuner import run_exp  # use absolute import
+
+    run_exp()
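Not part of the commit: a sketch of driving the torchrun branch of `launch()` purely through the environment variables it reads. Addresses, counts, and the config path are placeholders.

    import os
    import subprocess

    env = dict(os.environ)
    env.update({
        "FORCE_TORCHRUN": "1",      # always dispatch through torchrun (USE_MCA=1 now forces this too)
        "NNODES": "2",              # total number of nodes
        "NODE_RANK": "0",           # rank of this node
        "NPROC_PER_NODE": "8",      # processes (GPUs) per node
        "MASTER_ADDR": "10.0.0.1",  # rendezvous address
        "MASTER_PORT": "29500",
    })
    # Placeholder config path; any `llamafactory-cli train <config>` invocation works the same way.
    subprocess.run(["llamafactory-cli", "train", "examples/train_lora/llama3_lora_sft.yaml"], env=env, check=True)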
src/llamafactory/model/adapter.py
View file @
ca625f43
...
...
@@ -16,10 +16,12 @@ import re
from
typing
import
TYPE_CHECKING
import
torch
from
peft
import
LoraConfig
,
LoraModel
,
PeftModel
,
TaskType
,
get_peft_model
from
peft
import
LoraConfig
,
LoraModel
,
OFTConfig
,
PeftModel
,
TaskType
,
get_peft_model
from
transformers.integrations
import
is_deepspeed_zero3_enabled
from
..extras
import
logging
from
..extras.constants
import
EngineName
from
.model_utils.ktransformers
import
get_kt_peft_model
,
load_kt_peft_model
from
.model_utils.misc
import
find_all_linear_modules
,
find_expanded_modules
from
.model_utils.quantization
import
QuantizationMethod
from
.model_utils.unsloth
import
get_unsloth_peft_model
,
load_unsloth_peft_model
...
...
@@ -147,7 +149,10 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32
:
bool
,
)
->
"PeftModel"
:
if
is_trainable
:
logger
.
info_rank0
(
"Fine-tuning method: {}"
.
format
(
"DoRA"
if
finetuning_args
.
use_dora
else
"LoRA"
))
if
finetuning_args
.
finetuning_type
==
"oft"
:
logger
.
info_rank0
(
"Fine-tuning method: OFT"
)
else
:
logger
.
info_rank0
(
"Fine-tuning method: {}"
.
format
(
"DoRA"
if
finetuning_args
.
use_dora
else
"LoRA"
))
adapter_to_resume
=
None
...
...
@@ -161,6 +166,10 @@ def _setup_lora_tuning(
assert
len
(
model_args
.
adapter_name_or_path
)
==
1
,
"Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable
=
False
if
model_args
.
use_kt
:
assert
len
(
model_args
.
adapter_name_or_path
)
==
1
,
"KTransformers model only accepts a single adapter"
is_mergeable
=
False
if
model_args
.
use_unsloth
:
assert
len
(
model_args
.
adapter_name_or_path
)
==
1
,
"Unsloth model only accepts a single adapter."
is_mergeable
=
False
...
...
@@ -179,6 +188,12 @@ def _setup_lora_tuning(
"token"
:
model_args
.
hf_hub_token
,
}
if
model_args
.
use_kt
:
if
model_args
.
infer_backend
!=
EngineName
.
KT
:
raise
ValueError
(
"We should use ktransformers as backend to infer the adapter fine-tuned by ktransformers."
)
for
adapter
in
adapter_to_merge
:
model
:
LoraModel
=
PeftModel
.
from_pretrained
(
model
,
adapter
,
**
init_kwargs
)
model
=
model
.
merge_and_unload
()
...
...
@@ -187,7 +202,9 @@ def _setup_lora_tuning(
logger
.
info_rank0
(
f
"Merged
{
len
(
adapter_to_merge
)
}
adapter(s)."
)
if
adapter_to_resume
is
not
None
:
# resume lora training
if
model_args
.
use_unsloth
:
if
model_args
.
use_kt
:
model
=
load_kt_peft_model
(
model_args
,
model
)
elif
model_args
.
use_unsloth
:
model
=
load_unsloth_peft_model
(
config
,
model_args
,
finetuning_args
,
is_trainable
=
is_trainable
)
else
:
model
=
PeftModel
.
from_pretrained
(
model
,
adapter_to_resume
,
is_trainable
=
is_trainable
,
**
init_kwargs
)
...
...
@@ -200,6 +217,16 @@ def _setup_lora_tuning(
else
:
target_modules
=
finetuning_args
.
lora_target
if
model_args
.
use_kt
:
new_list
=
[]
for
m
in
target_modules
:
if
m
in
(
"down_proj"
,
"up_proj"
,
"gate_proj"
):
new_list
.
extend
([
f
"mlp.
{
m
}
"
,
f
"shared_experts.
{
m
}
"
])
elif
m
not
in
(
"generate_linear"
,
"orig_module"
,
"prefill_linear"
):
new_list
.
append
(
m
)
target_modules
[:]
=
new_list
if
finetuning_args
.
use_llama_pro
:
target_modules
=
find_expanded_modules
(
model
,
target_modules
,
finetuning_args
.
freeze_trainable_layers
)
...
...
@@ -223,17 +250,43 @@ def _setup_lora_tuning(
finetuning_args
.
additional_target
=
module_names
logger
.
warning_rank0
(
"Vocab has been resized, add {} to trainable params."
.
format
(
","
.
join
(
module_names
)))
peft_kwargs
=
{
"r"
:
finetuning_args
.
lora_rank
,
"target_modules"
:
target_modules
,
"lora_alpha"
:
finetuning_args
.
lora_alpha
,
"lora_dropout"
:
finetuning_args
.
lora_dropout
,
"use_rslora"
:
finetuning_args
.
use_rslora
,
"use_dora"
:
finetuning_args
.
use_dora
,
"modules_to_save"
:
finetuning_args
.
additional_target
,
}
if
finetuning_args
.
finetuning_type
==
"lora"
:
peft_kwargs
=
{
"r"
:
finetuning_args
.
lora_rank
,
"target_modules"
:
target_modules
,
"lora_alpha"
:
finetuning_args
.
lora_alpha
,
"lora_dropout"
:
finetuning_args
.
lora_dropout
,
"use_rslora"
:
finetuning_args
.
use_rslora
,
"use_dora"
:
finetuning_args
.
use_dora
,
"modules_to_save"
:
finetuning_args
.
additional_target
,
}
elif
finetuning_args
.
finetuning_type
==
"oft"
:
peft_kwargs
=
{
"r"
:
finetuning_args
.
oft_rank
,
"oft_block_size"
:
finetuning_args
.
oft_block_size
,
"target_modules"
:
target_modules
,
"module_dropout"
:
finetuning_args
.
module_dropout
,
"modules_to_save"
:
finetuning_args
.
additional_target
,
}
if
model_args
.
use_kt
:
if
finetuning_args
.
finetuning_type
==
"oft"
:
raise
ValueError
(
"KTransformers is currently not supported for OFT."
)
if
finetuning_args
.
finetuning_type
==
"lora"
:
peft_config
=
LoraConfig
(
task_type
=
TaskType
.
CAUSAL_LM
,
inference_mode
=
False
,
**
peft_kwargs
,
)
else
:
raise
ValueError
(
"KTransformers is currently only supported for LoRA."
)
model
=
get_kt_peft_model
(
model
,
peft_config
)
print
(
f
"KT_model:
{
model
}
"
)
elif
model_args
.
use_unsloth
:
if
finetuning_args
.
finetuning_type
==
"oft"
:
raise
ValueError
(
"Unsloth is currently not supported for OFT."
)
if
model_args
.
use_unsloth
:
model
=
get_unsloth_peft_model
(
model
,
model_args
,
peft_kwargs
)
else
:
if
finetuning_args
.
pissa_init
:
...
...
@@ -244,12 +297,19 @@ def _setup_lora_tuning(
                logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
                peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

-            lora_config = LoraConfig(
-                task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                **peft_kwargs,
-            )
-            model = get_peft_model(model, lora_config)
+            if finetuning_args.finetuning_type == "lora":
+                peft_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            elif finetuning_args.finetuning_type == "oft":
+                peft_config = OFTConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    **peft_kwargs,
+                )
+            model = get_peft_model(model, peft_config)

    if is_trainable and cast_trainable_params_to_fp32:
        for param in filter(lambda p: p.requires_grad, model.parameters()):
...
...
@@ -272,8 +332,8 @@ def init_adapter(
    Note that the trainable parameters must be cast to float32.
    """
    if is_trainable and getattr(model, "quantization_method", None) is not None:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+        if finetuning_args.finetuning_type not in ["lora", "oft"]:
+            raise ValueError("Quantized models can only be used for the LoRA or OFT tuning.")

        if finetuning_args.pissa_init:
            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
...
...
@@ -296,7 +356,7 @@ def init_adapter(
        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type == "freeze":
        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
-    elif finetuning_args.finetuning_type == "lora":
+    elif finetuning_args.finetuning_type in ["lora", "oft"]:
        model = _setup_lora_tuning(
            config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
        )
...
...
src/llamafactory/model/loader.py
View file @ ca625f43
...
...
@@ -15,7 +15,6 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
...
...
@@ -31,6 +30,7 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
+from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
...
...
@@ -143,7 +143,12 @@ def load_model(
    model = None
    lazy_load = False
-    if model_args.use_unsloth:
+    if model_args.use_kt:
+        from ktransformers.sft.monkey_patch_torch_module import install_patch
+
+        install_patch()
+        model = load_kt_pretrained_model(config, model_args)
+    elif model_args.use_unsloth:
        if model_args.adapter_name_or_path is not None:
            lazy_load = True
        elif is_trainable:
...
...
@@ -152,17 +157,18 @@ def load_model(
    if model is None and not lazy_load:
        init_kwargs["config"] = config
        init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
+        init_kwargs["torch_dtype"] = "auto"

        if model_args.mixture_of_depths == "load":
            model = load_mod_pretrained_model(**init_kwargs)
        else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
-                load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+            if type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
+                load_class = AutoModelForVision2Seq
            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                load_class = AutoModelForSeq2SeqLM
-            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
+            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen omni
                load_class = AutoModelForTextToWaveform
            else:
                load_class = AutoModelForCausalLM
...
...
@@ -171,8 +177,8 @@ def load_model(
            model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
        else:
            model = load_class.from_pretrained(**init_kwargs)

-        if getattr(model.config, "model_type", None) == "qwen2_5_omni":
-            model = model.thinker  # use part of Omni model
+        if getattr(model.config, "model_type", None) in ["qwen2_5_omni", "qwen3_omni_moe"]:
+            model = getattr(model, "thinker")  # use part of Omni model

    if model_args.mixture_of_depths == "convert":
        model = convert_pretrained_model_to_mod(model, config, model_args)
...
...
@@ -199,14 +205,21 @@ def load_model(
    if not is_trainable:
        model.requires_grad_(False)
        for param in model.parameters():
            if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
                param.data = param.data.to(model_args.compute_dtype)

        model.eval()
    else:
        model.train()

+    # Borrow the v1 kernel plugins to temporarily apply the NPU fusion operators to v0.
+    # This is off by default and can be dropped once the transition period ends.
+    if model_args.use_v1_kernels and is_trainable:
+        logger.warning_rank0(
+            "You are trying to use an experimental kernels feature; it is not supported for all models. "
+            "If you get any error, please disable this feature or report the issue."
+        )
+        from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
+
+        model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
+
    trainable_params, all_param = count_parameters(model)
    if is_trainable:
        param_stats = (
...
...
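For the inference branch above, a self-contained sketch (using a toy module rather than the actual loader) of how frozen float32 parameters are cast down to the requested compute dtype:

import torch
from torch import nn

# Toy stand-in for the loaded model and for model_args.compute_dtype.
model, compute_dtype = nn.Linear(4, 4), torch.bfloat16
model.requires_grad_(False)
for param in model.parameters():
    if param.data.dtype == torch.float32 and compute_dtype != torch.float32:
        param.data = param.data.to(compute_dtype)

print(next(model.parameters()).dtype)  # torch.bfloat16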
src/llamafactory/model/model_utils/attention.py
View file @ ca625f43
...
...
@@ -14,10 +14,9 @@
from typing import TYPE_CHECKING

-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-
from ...extras import logging
from ...extras.constants import AttentionFunction
+from ...extras.packages import is_torch_version_greater_than

if TYPE_CHECKING:
...
...
@@ -30,6 +29,20 @@ logger = logging.get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    from transformers.utils import is_flash_attn_2_available
+
+    if getattr(config, "model_type", None) == "gpt_oss":
+        from transformers.integrations.hub_kernels import load_and_register_kernel
+
+        flash_attn3_kernel = "kernels-community/vllm-flash-attn3"
+        load_and_register_kernel(flash_attn3_kernel)
+        setattr(config, "_attn_implementation", flash_attn3_kernel)
+        setattr(config, "_attn_implementation_internal", flash_attn3_kernel)
+        model_args.flash_attn = AttentionFunction.FA3
+        logger.info_rank0("Using FlashAttention-3 with attention sink for the gpt-oss model.")
+        return
+
    if getattr(config, "model_type", None) == "gemma2":
        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
            if is_flash_attn_2_available():
...
...
@@ -51,13 +64,15 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
        requested_attn_implementation = "eager"
    elif model_args.flash_attn == AttentionFunction.SDPA:
-        if not is_torch_sdpa_available():
+        if not is_torch_version_greater_than("2.1.1"):
            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == AttentionFunction.FA2:
-        if not is_flash_attn_2_available():
+        from transformers import is_torch_npu_available
+
+        if not (is_flash_attn_2_available() or is_torch_npu_available()):
            logger.warning_rank0("FlashAttention-2 is not installed.")
            return
...
...
src/llamafactory/model/model_utils/checkpointing.py
View file @ ca625f43
...
...
@@ -19,9 +19,11 @@
# limitations under the License.

import inspect
+import os
+from collections.abc import Callable
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

import torch
...
...
@@ -152,6 +154,13 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

+    if (
+        os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        and int(os.environ.get("FSDP_VERSION", "1")) == 2
+    ):
+        model_args.use_reentrant_gc = False
+        logger.warning_rank0("You are using fsdp2, `use_reentrant_gc` has been set to False.")
+
    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning_rank0("Current model does not support gradient checkpointing.")
...
...
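The new guard above keys off accelerate's environment variables; a small sketch (the values below are assumed for illustration, not taken from the commit) of when it fires:

import os

# Normally set by `accelerate launch` with an FSDP config; hard-coded here only to illustrate.
os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_VERSION"] = "2"

is_fsdp2 = (
    os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    and int(os.environ.get("FSDP_VERSION", "1")) == 2
)
print(is_fsdp2)  # True -> use_reentrant_gc is forced to False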
src/llamafactory/model/model_utils/embedding.py
View file @ ca625f43
...
...
@@ -14,7 +14,7 @@
import math
from contextlib import nullcontext
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
...
...
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
+    """Initialize new token embeddings with mean + Gaussian noise.
+
+    This is the default initialization method used by LlamaFactory.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added at the end of the embedding matrix
+    """
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
...
...
@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight


-def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""Resize token embeddings."""
+def _description_based_initialization(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    descriptions: dict[str, str],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+    add_noise: bool = False,
+) -> None:
+    """Initialize new token embeddings based on textual descriptions.
+
+    For each new token, this function:
+    1. Tokenizes its description text
+    2. Gets embeddings of the description tokens
+    3. Averages them to initialize the new token's embedding
+    4. Optionally adds Gaussian noise
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added
+        descriptions: Dict mapping token string to its description text,
+            e.g., {"<think>": "A token representing reasoning process"}
+        tokenizer: The tokenizer instance
+        model: The model instance (used to get input embeddings)
+        add_noise: Whether to add Gaussian noise to the initialization
+
+    Example:
+        descriptions = {
+            "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
+            "<|END_OF_SVG|>": "Marks the end of an SVG document"
+        }
+    """
+    embedding_dim = embed_weight.size(1)
+    for i, desc in enumerate(descriptions.values()):
+        # Tokenize description text
+        tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
+        with torch.no_grad():
+            token_ids = tokens["input_ids"][0]
+            # Move to the same device as embed_weight
+            device = embed_weight.device
+            token_ids = token_ids.to(device)
+            # Filter out new tokens (they don't have valid embeddings yet)
+            valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
+            if len(valid_token_ids) == 0:
+                # Fallback: use mean of all existing embeddings
+                logger.warning_rank0(
+                    f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
+                    "Using mean of existing embeddings."
+                )
+                base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
+            else:
+                # Get embeddings of description tokens and average them
+                token_embeds = model.get_input_embeddings()(valid_token_ids)
+                base_embedding = token_embeds.mean(dim=0)
+
+            # Add noise if requested (ensure correct device and dtype)
+            if add_noise:
+                noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
+                embed_weight[-num_new_tokens + i] = base_embedding + noise
+            else:
+                embed_weight[-num_new_tokens + i] = base_embedding
+
+
+def _initialize_embeddings(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    init_method: str,
+    new_special_tokens_config: Optional[dict],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+) -> None:
+    """Single source of truth for embedding initialization.
+
+    This function selects the appropriate initialization method and applies it.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize
+        num_new_tokens: Number of new tokens added
+        init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+        new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
+        tokenizer: The tokenizer instance
+        model: The model instance
+    """
+    if init_method == "desc_init" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
+        )
+    elif init_method == "desc_init_w_noise" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
+        )
+    else:
+        if init_method != "noise_init":
+            logger.warning_rank0(
+                f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
+            )
+        logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
+        _noisy_mean_initialization(embed_weight, num_new_tokens)
+
+
+def resize_embedding_layer(
+    model: "PreTrainedModel",
+    tokenizer: "PreTrainedTokenizer",
+    new_special_tokens_config: Optional[dict] = None,
+    init_special_tokens: str = "noise_init",
+) -> None:
+    r"""Resize token embeddings and initialize new tokens.
+
+    Args:
+        model: The model to resize
+        tokenizer: The tokenizer (used to get target vocab size)
+        new_special_tokens_config: Optional dict with token descriptions for semantic initialization
+        init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+    """
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore
...
...
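A hedged usage sketch of the new signature (the token names and descriptions below are made up, and it is assumed the new tokens were already added to the tokenizer earlier in the pipeline):

# Assuming `model` and `tokenizer` are already loaded and the special tokens exist in the tokenizer,
# semantic initialization of the newly added embeddings could be requested like this:
new_special_tokens_config = {
    "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
    "<|END_OF_SVG|>": "Marks the end of an SVG document",
}
resize_embedding_layer(
    model,
    tokenizer,
    new_special_tokens_config=new_special_tokens_config,
    init_special_tokens="desc_init_w_noise",  # or "desc_init" / "noise_init"
)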
@@ -64,7 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
    with context_maybe_zero3:
        new_embedding_size = model.get_input_embeddings().weight.size(0)
        num_new_tokens = new_embedding_size - current_embedding_size
-        _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
-        _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
+        logger.info_rank0(
+            f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
+        )
+
+        # Initialize input embeddings
+        _initialize_embeddings(
+            model.get_input_embeddings().weight.data,
+            num_new_tokens,
+            init_special_tokens,
+            new_special_tokens_config,
+            tokenizer,
+            model,
+        )
+
+        # Initialize output embeddings if not tied
+        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
+            _initialize_embeddings(
+                model.get_output_embeddings().weight.data,
+                num_new_tokens,
+                init_special_tokens,
+                new_special_tokens_config,
+                tokenizer,
+                model,
+            )

    model.config.vocab_size = new_embedding_size
    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
src/llamafactory/model/model_utils/ktransformers.py
0 → 100644
View file @ ca625f43
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util as _u
from typing import TYPE_CHECKING, Any

import torch

from ...extras import logging
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from ...hparams import FinetuningArguments, ModelArguments

from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

KT_AVAILABLE = _u.find_spec("ktransformers") is not None

if KT_AVAILABLE:
    from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
    from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
    from ktransformers.models.modeling_llama import LlamaForCausalLM
    from ktransformers.models.modeling_mixtral import MixtralForCausalLM
    from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
    from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM
    from ktransformers.optimize.optimize import optimize_and_load_gguf
    from ktransformers.server.config.config import Config
    from ktransformers.sft.lora import inject_lora_layer
    from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
    from ktransformers.util.globals import GLOBAL_CONFIG
    from ktransformers.util.utils import load_weights


logger = logging.get_logger(__name__)


def _get_kt_kwargs(
    config: "PretrainedConfig",
    model_name_or_path: str,
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
    return {
        "model_name": model_name_or_path,
        "max_seq_length": model_args.model_max_length or 4096,
        "dtype": model_args.compute_dtype,
        "load_in_4bit": model_args.quantization_bit == 4,
        "token": model_args.hf_hub_token,
        "full_finetuning": finetuning_args.finetuning_type == "full",
        "device_map": {"": get_current_device()},
        "rope_scaling": getattr(config, "rope_scaling", None),
        "fix_tokenizer": False,
        "trust_remote_code": model_args.trust_remote_code,
        "use_gradient_checkpointing": "ktransformers",
    }


def load_kt_pretrained_model(config: "PretrainedConfig", model_args: "ModelArguments") -> "PreTrainedModel":
    r"""Optionally load pretrained model with KTransformers. Used in training."""
    custom_models = {
        "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
        "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
        "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
        "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
        "LlamaForCausalLM": LlamaForCausalLM,
        "MixtralForCausalLM": MixtralForCausalLM,
    }
    Config().cpu_infer = model_args.cpu_infer
    Config().chunk_size = model_args.chunk_size
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
    if model_args.mode == "long_context":
        assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
        torch.set_default_dtype(torch.float16)
    else:
        torch.set_default_dtype(config.torch_dtype)

    with torch.device("meta"):
        if config.architectures[0] in custom_models:
            print("using custom modeling_xxx.py.")
            if "Qwen2Moe" in config.architectures[0]:
                # Qwen2Moe must use flash_attention_2 to avoid overflow.
                config._attn_implementation = "flash_attention_2"
            if "Llama" in config.architectures[0]:
                config._attn_implementation = "eager"
            if "Mixtral" in config.architectures[0]:
                config._attn_implementation = "flash_attention_2"

            model = custom_models[config.architectures[0]](config)
        else:
            attn_implementation = "flash_attention_2"
            model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=True, attn_implementation=attn_implementation
            )

    optimize_config_path = model_args.kt_optimize_rule
    gguf_path = model_args.model_name_or_path
    assert optimize_config_path is not None, "optimize_config_path must be provided (path to YAML rules file)."
    assert gguf_path is not None, "gguf_path must be provided (path to a folder or .gguf file)."
    GLOBAL_CONFIG._config["mod"] = "infer"
    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)
    return model


def get_kt_peft_model(model: "PreTrainedModel", peft_kwargs: dict[str, Any]) -> "PreTrainedModel":
    r"""Get the peft model for the pretrained model with KTransformers. Used in training."""
    from ktransformers.sft.peft_utils.mapping import get_peft_model

    return get_peft_model(model, peft_kwargs)


def load_kt_peft_model(model_args: "ModelArguments", model: "PreTrainedModel") -> "PreTrainedModel":
    r"""Load peft model with KTransformers. Used in both training and inference."""
    load_adapter_name_or_path = model_args.adapter_name_or_path[0]
    if load_adapter_name_or_path.endswith(".gguf"):
        inject_lora_layer(model, load_adapter_name_or_path)
        adapter_gguf_loader = GGUFLoader(load_adapter_name_or_path)
        load_weights(model, adapter_gguf_loader, adapter_gguf=True)
        model.train()
    else:
        inject_lora_layer(model, load_adapter_name_or_path)
        adapter_loader = SafeTensorLoader(load_adapter_name_or_path)
        device = next(model.parameters()).device
        for key in adapter_loader.tensor_file_map.keys():
            try:
                tensor = adapter_loader.load_tensor(key, device=device)
                model_key = key.replace("base_model.model.", "")
                model_key = model_key.replace(".weight", ".default.weight")
                model_key = model_key.replace(".default.default.weight", ".default.weight")
                param = model.get_parameter(model_key)
                param.data.copy_(tensor.data)
                print(f"Loaded adapter weight: {key} -> {model_key}")
            except AttributeError:
                print(f"Skipping {key}: not a model parameter")
            except KeyError:
                print(f"Key not found in model: {model_key} (original: {key})")

    return model
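A small standalone sketch (not part of the new file) of the key remapping performed when safetensors adapters are loaded into the injected LoRA layers; the key below is illustrative:

# Illustrative key only; real keys come from the adapter's tensor_file_map.
key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
model_key = key.replace("base_model.model.", "")
model_key = model_key.replace(".weight", ".default.weight")
model_key = model_key.replace(".default.default.weight", ".default.weight")
print(model_key)  # model.layers.0.self_attn.q_proj.lora_A.default.weight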
src/llamafactory/model/model_utils/kv_cache.py
View file @ ca625f43
...
...
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
        if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)

-        if model_args.use_cache:
+        if model_args.use_kv_cache:
            logger.info_rank0("KV cache is enabled for faster generation.")
        else:
            logger.info_rank0("KV cache is disabled.")
...
...
src/llamafactory/model/model_utils/liger_kernel.py
View file @ ca625f43
...
...
@@ -47,6 +47,8 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
    elif model_type == "glm4":
        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
+    elif model_type == "glm4v":
+        from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
    elif model_type == "granite":
        from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
    elif model_type == "llama":
...
...
@@ -75,6 +77,12 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
    elif model_type == "qwen3_moe":
        from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
+    elif model_type == "gpt_oss":
+        try:
+            from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
+        except ImportError:
+            logger.warning_rank0("Please install liger-kernel from https://github.com/Comet0322/Liger-Kernel.")
+            return
    else:
        logger.warning_rank0("Current model does not support liger kernel.")
        return
...
...
src/llamafactory/model/model_utils/moe.py
View file @ ca625f43
...
...
@@ -14,9 +14,13 @@
from typing import TYPE_CHECKING, Union

+import torch
+from torch import nn
+from torch.nn import functional as F
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.misc import check_version
+from ...extras.packages import is_transformers_version_greater_than

if TYPE_CHECKING:
...
...
@@ -25,6 +29,9 @@ if TYPE_CHECKING:
    from ...hparams import ModelArguments


+if is_transformers_version_greater_than("4.57.0"):
+    from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
+
+
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
    check_version("deepspeed>=0.13.0")
...
...
@@ -39,6 +46,9 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        return

    model_type = getattr(model.config, "model_type", None)
+    text_config = getattr(model.config, "text_config", None)
+    text_model_type = getattr(text_config, "model_type", None)
+
    if model_type == "dbrx":
        from transformers.models.dbrx.modeling_dbrx import DbrxFFN
...
...
@@ -52,11 +62,31 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        # deepseek v3 and kimi vl use custom code
        _set_z3_leaf_modules(model, ["DeepseekV3MoE"])

+    if model_type == "ernie4_5_moe":
+        from transformers.models.ernie4_5_moe.modeling_ernie4_5_moe import Ernie4_5_MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Ernie4_5_MoeSparseMoeBlock])
+
    if model_type == "granitemoe":
        from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE

        _set_z3_leaf_modules(model, [GraniteMoeMoE])

+    if model_type == "glm4_moe":
+        from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE
+
+        _set_z3_leaf_modules(model, [Glm4MoeMoE])
+
+    if model_type == "glm4v_moe":
+        from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextMoE
+
+        _set_z3_leaf_modules(model, [Glm4vMoeTextMoE])
+
+    if model_type == "gpt_oss":
+        from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
+
+        _set_z3_leaf_modules(model, [GptOssMLP])
+
    if model_type == "jamba":
        from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
...
...
@@ -92,19 +122,32 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

-    if model_type == "qwen3_moe":
+    if model_type == "qwen3_moe" or text_model_type == "qwen3_moe":  # internvl 3.5
        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])

+    if model_type == "qwen3_vl_moe":
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3VLMoeTextSparseMoeBlock])
+
+    if model_type in ("qwen3_omni_moe", "qwen3_omni_moe_thinker"):
+        from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
+

def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.moe_aux_loss_coef:
        return

    model_type = getattr(config, "model_type", None)
+    text_config = getattr(config, "text_config", None)  # for multimodal model
    if model_type in [
        "dbrx",
        "ernie4_5_moe",
        "granitemoe",
        "jamba",
        "jetmoe",
...
...
@@ -117,11 +160,93 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
    ]:
        setattr(config, "output_router_logits", True)

-    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if text_config and getattr(text_config, "model_type", None) in [
+        "glm4v_moe_text",  # glmv4_5
+        "qwen3_moe",  # internvl_3_5
+    ]:
+        setattr(text_config, "output_router_logits", True)
+
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+    elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
+        setattr(text_config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
    elif model_type == "deepseek":
        setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
    elif model_type == "jetmoe":
        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+
+
+class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+
+        # gating
+        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
+        self.experts = nn.ModuleList(
+            [
+                modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
+                    config, intermediate_size=config.moe_intermediate_size
+                )
+                for _ in range(self.num_experts)
+            ]
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        # Calculate the routing weights for all experts
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        # Keep the top_k weights and reset the remaining experts' weights to 0
+        # (instead of retaining only the top_k experts themselves)
+        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
+        # Initialize an all-zero weight matrix (same shape as the full expert weights)
+        full_routing_weights = torch.zeros_like(routing_weights)
+        # Only the weights of the top_k experts are retained; the rest stay at 0
+        full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
+        # Normalize the top_k weights (keeping the original logic)
+        if self.norm_topk_prob:
+            # Sum of the top_k weights in each row (for normalization)
+            top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
+            # Avoid dividing by zero
+            top_k_sum = torch.clamp(top_k_sum, min=1e-9)
+            full_routing_weights /= top_k_sum
+
+        # Convert back to the input data type
+        full_routing_weights = full_routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        # Go through all the experts (not just the selected ones)
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            # Get the weight of the current expert (inactive experts have a weight of 0 here)
+            expert_weights = full_routing_weights[:, expert_idx, None]  # shape: (batch*seq, 1)
+            # All tokens participate in the current expert's computation; the weight may be 0
+            current_hidden_states = expert_layer(hidden_states) * expert_weights
+            # Sum over all expert outputs (experts with a weight of 0 do not affect the result)
+            final_hidden_states += current_hidden_states
+
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
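The replacement block keeps every expert in the computation and zeroes the routing weights of the non-selected experts; because zero-weighted experts contribute nothing to the sum, the result matches ordinary top-k routing. A standalone sketch (illustrative shapes, not taken from the commit) of the weight construction:

import torch
import torch.nn.functional as F

router_logits = torch.randn(4, 8)  # (tokens, experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_w, top_i = torch.topk(routing_weights, k=2, dim=-1)

# Scatter the top-k weights into a dense matrix; all other experts keep weight 0.
full = torch.zeros_like(routing_weights).scatter_(1, top_i, top_w)
full = full / full.sum(dim=-1, keepdim=True).clamp(min=1e-9)
print(full.count_nonzero(dim=-1))  # tensor([2, 2, 2, 2]): only the top-k weights are nonzero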
src/llamafactory/model/model_utils/packing.py
View file @ ca625f43
...
...
@@ -53,7 +53,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
-    r"""Get the sequnce lengths in the current batch.
+    r"""Get the sequence lengths in the current batch.

    e.g.
    ```python
...
...
src/llamafactory/model/model_utils/quantization.py
View file @ ca625f43
...
...
@@ -83,6 +83,7 @@ def configure_quantization(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
+    is_trainable: bool,
    init_kwargs: dict[str, Any],
) -> None:
    r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
...
...
@@ -90,12 +91,29 @@ def configure_quantization(
        if model_args.quantization_bit is not None:
            logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

-        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
-
        quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")
+        if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
+            is_deepspeed_zero3_enabled() or is_fsdp_enabled()
+        ):  # mxfp4 will dequant the model weights
+            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
+
+        if quant_method == QuantizationMethod.MXFP4:
+            from transformers import Mxfp4Config
+
+            quant_config = Mxfp4Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True
+
+        if quant_method == QuantizationMethod.FP8:
+            from transformers import FineGrainedFP8Config
+
+            quant_config = FineGrainedFP8Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+            init_kwargs["ignore_mismatched_sizes"] = True

        if quant_method == QuantizationMethod.GPTQ:
            check_version("gptqmodel>=2.0.0", mandatory=True)
            quantization_config.pop("disable_exllama", None)  # remove deprecated args
...
...
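A minimal sketch of the new branch's effect on the loader kwargs (assumes a transformers release that ships Mxfp4Config; not a complete loading example):

from transformers import Mxfp4Config

# Pre-quantized MXFP4 checkpoints are dequantized on load, which is why they are
# exempted from the ZeRO-3/FSDP incompatibility check above.
init_kwargs = {}
init_kwargs["quantization_config"] = Mxfp4Config(dequantize=True)
init_kwargs["ignore_mismatched_sizes"] = True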
src/llamafactory/model/model_utils/rope.py
View file @ ca625f43
...
...
@@ -40,7 +40,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") ->
        logger.warning_rank0("Current model does not support RoPE scaling.")
        return

-    if hasattr(config, "max_position_embeddings"):
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
+        old_max_length = rope_scaling["original_max_position_embeddings"]
+    elif hasattr(config, "max_position_embeddings"):
        old_max_length = getattr(config, "max_position_embeddings", None)
    else:
        logger.warning_rank0("Cannot find the max position embeddings in the config.")
...
...
src/llamafactory/model/model_utils/visual.py
View file @ ca625f43
...
...
@@ -199,6 +199,15 @@ def patch_target_modules(
    return target_modules


+_register_composite_model(
+    model_type="dots_ocr",
+    projector_key="vision_tower.merger",
+    vision_model_keys=["vision_tower"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["merger"],
+)
+
+
_register_composite_model(
    model_type="gemma3",
)
...
...
@@ -221,10 +230,36 @@ _register_composite_model(
)


+_register_composite_model(
+    model_type="glm4v_moe",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
_register_composite_model(
    model_type="internvl",
)


+_register_composite_model(
+    model_type="interns1",
+)
+
+
+_register_composite_model(
+    model_type="Keye",
+    projector_key="mlp_AR",
+    vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embedding"],
+)
+
+
_register_composite_model(
    model_type="kimi_vl",
)


_register_composite_model(
    model_type="llama4",
...
...
@@ -263,8 +298,10 @@ _register_composite_model(
    lora_conflict_keys=["audio_projection_layer"],
)


_register_composite_model(
    model_type="mistral3",
+    projector_key="model.multi_modal_projector",
)
...
...
@@ -316,6 +353,33 @@ _register_composite_model(
)


+_register_composite_model(
+    model_type="qwen3_vl",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
+_register_composite_model(
+    model_type="qwen3_vl_moe",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
+_register_composite_model(
+    model_type="qwen3_omni_moe_thinker",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list", "audio_tower"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
_register_composite_model(
    model_type="video_llava",
)