zhougaofeng / internlm2-math-7B · Commits · 9d32bb9c

Commit 9d32bb9c, authored Jun 11, 2024 by zhougaofeng
Upload New File
Parent: 56c62116
Pipeline #1156: canceled with stages
Showing 1 changed file with 363 additions and 0 deletions (+363 −0)
src/llmfactory/hparams/finetuning_args.py  0 → 100644  (view file @ 9d32bb9c)
from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class FreezeArguments:
    r"""
    Arguments pertaining to the freeze (partial-parameter) training.
    """

    freeze_trainable_layers: int = field(
        default=2,
        metadata={
            "help": (
                "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
                "Positive numbers mean the last n layers are set as trainable, "
                "negative numbers mean the first n layers are set as trainable."
            )
        },
    )
    freeze_trainable_modules: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
                "Use commas to separate multiple modules. "
                "Use `all` to specify all the available modules. "
                "LLaMA choices: [`mlp`, `self_attn`], "
                "BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
                "Qwen choices: [`mlp`, `attn`], "
                "InternLM2 choices: [`feed_forward`, `attention`], "
                "Other choices: the same as LLaMA."
            )
        },
    )
    freeze_extra_modules: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Name(s) of modules apart from hidden layers to be set as trainable "
                "for freeze (partial-parameter) fine-tuning. "
                "Use commas to separate multiple modules."
            )
        },
    )


@dataclass
class LoraArguments:
    r"""
    Arguments pertaining to the LoRA training.
    """

    additional_target: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Name(s) of modules apart from LoRA layers to be set as trainable "
                "and saved in the final checkpoint. "
                "Use commas to separate multiple modules."
            )
        },
    )
    lora_alpha: Optional[int] = field(
        default=None,
        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
    )
    lora_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."},
    )
    lora_rank: int = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
    )
    lora_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of target modules to apply LoRA. "
                "Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules. "
                "LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
                "BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
                "Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
                "Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
                "InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
                "Other choices: the same as LLaMA."
            )
        },
    )
    loraplus_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
    )
    loraplus_lr_embedding: float = field(
        default=1e-6,
        metadata={"help": "LoRA plus learning rate for lora embedding layers."},
    )
    use_rslora: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
    )
    use_dora: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
    )
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
    )


@dataclass
class RLHFArguments:
    r"""
    Arguments pertaining to the PPO and DPO training.
    """

    dpo_beta: float = field(
        default=0.1,
        metadata={"help": "The beta parameter for the DPO loss."},
    )
    dpo_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = field(
        default="sigmoid",
        metadata={"help": "The type of DPO loss to use."},
    )
    dpo_label_smoothing: float = field(
        default=0.0,
        metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
    )
    dpo_ftx: float = field(
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
    )
    kto_beta: float = field(
        default=0.1,
        metadata={"help": "The beta parameter for the KTO loss."},
    )
    kto_chosen_weight: float = field(
        default=1.0,
        metadata={"help": "The weight factor of the desirable losses in KTO training."},
    )
    kto_rejected_weight: float = field(
        default=1.0,
        metadata={"help": "The weight factor of the undesirable losses in KTO training."},
    )
    kto_ftx: float = field(
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
    )
    orpo_beta: float = field(
        default=0.1,
        metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
    )
    ppo_buffer_size: int = field(
        default=1,
        metadata={"help": "The number of mini-batches to make the experience buffer in a PPO optimization step."},
    )
    ppo_epochs: int = field(
        default=4,
        metadata={"help": "The number of epochs to perform in a PPO optimization step."},
    )
    ppo_score_norm: bool = field(
        default=False,
        metadata={"help": "Use score normalization in PPO training."},
    )
    ppo_target: float = field(
        default=6.0,
        metadata={"help": "Target KL value for adaptive KL control in PPO training."},
    )
    ppo_whiten_rewards: bool = field(
        default=False,
        metadata={"help": "Whiten the rewards before computing advantages in PPO training."},
    )
    ref_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reference model used for the PPO or DPO training."},
    )
    ref_model_adapters: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the adapters of the reference model."},
    )
    ref_model_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reference model."},
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reward model used for the PPO training."},
    )
    reward_model_adapters: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the adapters of the reward model."},
    )
    reward_model_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reward model."},
    )
    reward_model_type: Literal["lora", "full", "api"] = field(
        default="lora",
        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
    )


@dataclass
class GaloreArguments:
    r"""
    Arguments pertaining to the GaLore algorithm.
    """

    use_galore: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the gradient low-rank projection (GaLore)."},
    )
    galore_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    galore_rank: int = field(
        default=16,
        metadata={"help": "The rank of GaLore gradients."},
    )
    galore_update_interval: int = field(
        default=200,
        metadata={"help": "Number of steps to update the GaLore projection."},
    )
    galore_scale: float = field(
        default=0.25,
        metadata={"help": "GaLore scaling coefficient."},
    )
    galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
        default="std",
        metadata={"help": "Type of GaLore projection."},
    )
    galore_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )


@dataclass
class BAdamArgument:
    r"""
    Arguments pertaining to the BAdam optimizer.
    """

    use_badam: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the BAdam optimizer."},
    )
    badam_mode: Literal["layer", "ratio"] = field(
        default="layer",
        metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
    )
    badam_start_block: Optional[int] = field(
        default=None,
        metadata={"help": "The starting block index for layer-wise BAdam."},
    )
    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
        default="ascending",
        metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
    )
    badam_switch_interval: Optional[int] = field(
        default=50,
        metadata={
            "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
        },
    )
    badam_update_ratio: float = field(
        default=0.05,
        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
    )
    badam_mask_mode: Literal["adjacent", "scatter"] = field(
        default="adjacent",
        metadata={
            "help": (
                "The mode of the mask for BAdam optimizer. "
                "`adjacent` means that the trainable parameters are adjacent to each other, "
                "`scatter` means that trainable parameters are randomly chosen from the weight."
            )
        },
    )
    badam_verbose: int = field(
        default=0,
        metadata={
            "help": (
                "The verbosity level of BAdam optimizer. "
                "0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
            )
        },
    )


@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
    r"""
    Arguments pertaining to which fine-tuning techniques we are going to use.
    """

    pure_bf16: bool = field(
        default=False,
        metadata={"help": "Whether or not to train the model in purely bf16 precision (without AMP)."},
    )
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )
    finetuning_type: Literal["lora", "freeze", "full"] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."},
    )
    use_llama_pro: bool = field(
        default=False,
        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
    )
    plot_loss: bool = field(
        default=False,
        metadata={"help": "Whether or not to save the training loss curves."},
    )

    def __post_init__(self):
        def split_arg(arg):
            if isinstance(arg, str):
                return [item.strip() for item in arg.split(",")]
            return arg

        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
        self.lora_target = split_arg(self.lora_target)
        self.additional_target = split_arg(self.additional_target)
        self.galore_target = split_arg(self.galore_target)

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("`reward_model` is necessary for PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

        if self.stage == "dpo" and self.dpo_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")

        if self.use_llama_pro and self.finetuning_type == "full":
            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")

        if self.use_galore and self.finetuning_type == "lora":
            raise ValueError("Cannot use LoRA with GaLore together.")

        if self.use_galore and self.use_badam:
            raise ValueError("Cannot use GaLore with BAdam together.")

        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
            raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
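
For readers skimming the diff, here is a minimal usage sketch; it is not part of the commit. It shows what `__post_init__` does at construction time: comma-separated module names are split into lists and `lora_alpha` defaults to `lora_rank * 2` when left unset. The import path `llmfactory.hparams.finetuning_args` (assuming `src/` is on the Python path) and the use of `transformers.HfArgumentParser` for command-line parsing are assumptions about how the project is laid out and consumed, not something this file establishes.

# Usage sketch (hypothetical; not part of the commit).
# Assumes `src/` is on the Python path and that the project consumes this
# dataclass via transformers' HfArgumentParser, as HF-style trainers usually do.
from transformers import HfArgumentParser

from llmfactory.hparams.finetuning_args import FinetuningArguments

# Direct construction: __post_init__ splits comma-separated module names into
# lists and derives lora_alpha from lora_rank when it is left unset.
args = FinetuningArguments(
    stage="sft",
    finetuning_type="lora",
    lora_rank=16,
    lora_target="q_proj,v_proj",
)
print(args.lora_alpha)   # 32, i.e. lora_rank * 2
print(args.lora_target)  # ['q_proj', 'v_proj']

# Command-line style parsing with HfArgumentParser (hypothetical invocation):
parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(
    args=["--stage", "dpo", "--dpo_beta", "0.2", "--lora_target", "wqkv,wo"]
)
print(finetuning_args.dpo_beta)     # 0.2
print(finetuning_args.lora_target)  # ['wqkv', 'wo']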