OpenDAS / LLaMA-Factory · Commits

Commit 2778a3d0, authored Jan 16, 2025 by luopl

    updata to v0.9.1_stable

Parent: e92143e3

Changes: 172 · Showing 20 changed files with 354 additions and 186 deletions (+354 −186)
src/llamafactory/model/model_utils/liger_kernel.py    +5   −5
src/llamafactory/model/model_utils/longlora.py        +12  −12
src/llamafactory/model/model_utils/misc.py            +10  −8
src/llamafactory/model/model_utils/packing.py         +6   −6
src/llamafactory/model/model_utils/quantization.py    +8   −8
src/llamafactory/model/model_utils/rope.py            +8   −10
src/llamafactory/model/model_utils/unsloth.py         +3   −3
src/llamafactory/model/model_utils/valuehead.py       +4   −4
src/llamafactory/model/model_utils/visual.py          +21  −14
src/llamafactory/model/patcher.py                     +6   −6
src/llamafactory/train/callbacks.py                   +28  −28
src/llamafactory/train/dpo/trainer.py                 +62  −14
src/llamafactory/train/dpo/workflow.py                +13  −0
src/llamafactory/train/kto/trainer.py                 +96  −30
src/llamafactory/train/kto/workflow.py                +1   −1
src/llamafactory/train/ppo/ppo_utils.py               +2   −2
src/llamafactory/train/ppo/trainer.py                 +20  −21
src/llamafactory/train/pt/trainer.py                  +18  −5
src/llamafactory/train/rm/trainer.py                  +10  −5
src/llamafactory/train/sft/trainer.py                 +21  −4
src/llamafactory/model/model_utils/liger_kernel.py

@@ -15,7 +15,7 @@
 import inspect
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging

 if TYPE_CHECKING:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def apply_liger_kernel(
@@ -54,14 +54,14 @@ def apply_liger_kernel(
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     else:
-        logger.warning("Current model does not support liger kernel.")
+        logger.warning_rank0("Current model does not support liger kernel.")
         return

     if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
-        logger.info("Current training stage does not support chunked cross entropy.")
+        logger.info_rank0("Current training stage does not support chunked cross entropy.")
         kwargs = {"fused_linear_cross_entropy": False}
     else:
         kwargs = {}

     apply_liger_kernel(**kwargs)
-    logger.info("Liger kernel has been applied to the model.")
+    logger.info_rank0("Liger kernel has been applied to the model.")
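Across this commit, module-level loggers move from `get_logger` to a project-owned `logging` wrapper, and most log calls become `info_rank0`/`warning_rank0`, which emit only on the main process of a distributed run. The wrapper itself (`src/llamafactory/extras/logging.py`) is not part of this diff; the following is a minimal sketch, under the assumption that it attaches rank-aware helpers to the stdlib logger class — the real implementation may differ.

import logging
import os

def info_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # Only the main process (LOCAL_RANK 0 or unset) emits the record, so
    # multi-GPU runs do not print the same message once per device.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.info(*args, **kwargs)

def warning_rank0(self: logging.Logger, *args, **kwargs) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)

def get_logger(name: str) -> logging.Logger:
    # Attach the rank-aware helpers to the stdlib Logger class once, then
    # hand out ordinary loggers that now understand info_rank0 and friends.
    logging.Logger.info_rank0 = info_rank0
    logging.Logger.warning_rank0 = warning_rank0
    return logging.getLogger(name)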
src/llamafactory/model/model_utils/longlora.py

@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 import torch
 import torch.nn as nn
+import transformers
 from transformers.models.llama.modeling_llama import (
     Cache,
     LlamaAttention,
@@ -30,12 +31,11 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than

 if TYPE_CHECKING:
@@ -44,7 +44,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)

 # Modified from:
@@ -86,7 +86,7 @@ def llama_attention_forward(
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward

         attn_output: "torch.Tensor" = _flash_attention_forward(
@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz

         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return

-    logger = get_logger(__name__)
+    logger = logging.get_logger(__name__)
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
-        logger.info("Using shift short attention with group_size_ratio=1/4.")
+        logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
     else:
-        logger.warning("Current model does not support shift short attention.")
+        logger.warning_rank0("Current model does not support shift short attention.")
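The commit also retires the one-off helper `is_transformers_version_greater_than_4_43()` in favor of a parameterized `is_transformers_version_greater_than(...)`, so new version gates (here "4.43.0") no longer require new helper functions. A plausible implementation — sketched here as an assumption, since `extras/packages.py` is outside this diff — compares the installed version with `packaging`:

import functools

from packaging import version
from transformers import __version__ as transformers_version

@functools.lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    # Cache the comparison: the installed version cannot change at runtime.
    # Call sites treat the name as "at least this version", hence >=.
    return version.parse(transformers_version) >= version.parse(content)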
src/llamafactory/model/model_utils/misc.py

@@ -14,14 +14,14 @@
 from typing import TYPE_CHECKING, List

-from ...extras.logging import get_logger
+from ...extras import logging

 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
     elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")

     if freeze_vision_tower:
-        if model_type == "qwen2_vl":
+        if model_type == "mllama":
+            forbidden_modules.add("vision_model")
+        elif model_type == "qwen2_vl":
             forbidden_modules.add("visual")
         else:
             forbidden_modules.add("vision_tower")
@@ -53,7 +55,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])

-    logger.info("Found linear modules: {}".format(",".join(module_names)))
+    logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)
@@ -67,12 +69,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     if num_layers % num_layer_trainable != 0:
         raise ValueError(
-            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+            f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
         )

     stride = num_layers // num_layer_trainable
     trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
-    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
     module_names = []
     for name, _ in model.named_modules():
         if any(target_module in name for target_module in target_modules) and any(
@@ -80,7 +82,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)

-    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
     return module_names
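For reference, the layer-selection arithmetic in `find_expanded_modules` is untouched here — only its string formatting changed. It picks every `stride`-th layer, counting from the end of each block; a quick check of the values it produces:

num_layers, num_layer_trainable = 32, 8
stride = num_layers // num_layer_trainable  # 4
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
print(list(trainable_layer_ids))            # [3, 7, 11, 15, 19, 23, 27, 31]
print([f".{idx:d}." for idx in trainable_layer_ids][:2])  # ['.3.', '.7.']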
src/llamafactory/model/model_utils/packing.py

@@ -43,9 +43,9 @@ import torch
 import torch.nn.functional as F
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.logging import get_logger
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than

 if TYPE_CHECKING:
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@@ -114,8 +114,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
-    if is_transformers_version_greater_than_4_43():
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    if is_transformers_version_greater_than("4.43.0"):
         import transformers.modeling_flash_attention_utils

         transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
         _patch_for_block_diag_attn(model_type)
-        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")
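For context on what the swapped-in `get_unpad_data` operates on: with sequence packing, the attention mask numbers each packed sequence (1, 2, …) instead of holding plain 0/1 values, and per-sequence lengths can be recovered from those labels. A toy illustration of that recovery — values and shapes are illustrative, not the actual implementation:

import torch

# One packed row holding two sequences of lengths 3 and 2, plus padding (0).
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
seqlens = torch.bincount(attention_mask.flatten())[1:]  # drop the padding bucket
print(seqlens)  # tensor([3, 2])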
src/llamafactory/model/model_utils/quantization.py

@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.logging import get_logger
 from ...extras.misc import get_current_device
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 @unique
@@ -109,7 +109,7 @@ def configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
-            logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
+            logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
@@ -130,7 +130,7 @@ def configure_quantization(
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
-        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
+        logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@@ -149,7 +149,7 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@@ -179,7 +179,7 @@ def configure_quantization(
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

-            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
         elif model_args.quantization_method == QuantizationMethod.HQQ.value:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -191,7 +191,7 @@ def configure_quantization(
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
-            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
         elif model_args.quantization_method == QuantizationMethod.EETQ.value:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(
             require_version("eetq", "To fix: pip install eetq")
             init_kwargs["quantization_config"] = EetqConfig()
-            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
src/llamafactory/model/model_utils/rope.py

@@ -19,7 +19,7 @@
 import math
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging

 if TYPE_CHECKING:
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         return

     if not hasattr(config, "rope_scaling"):
-        logger.warning("Current model does not support RoPE scaling.")
+        logger.warning_rank0("Current model does not support RoPE scaling.")
         return

     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
-            logger.warning(
+            logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )

         current_max_length = getattr(config, "max_position_embeddings", None)
         if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info("Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length))
+            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
             scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
-            logger.warning("Input length is smaller than max length. Consider increase input length.")
+            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
             scaling_factor = 1.0
     else:
         scaling_factor = 2.0

     setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-    logger.info(
-        "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
-    )
+    logger.info_rank0(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")
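As a concrete check of the scaling arithmetic above: with `max_position_embeddings = 4096` and a requested `model_max_length = 10000`, the factor is `ceil(10000 / 4096) = 3.0`, and the config ends up with a scaling entry like the one below ("linear" is just an example value for `model_args.rope_scaling`):

import math

current_max_length, model_max_length = 4096, 10000
scaling_factor = float(math.ceil(model_max_length / current_max_length))  # 3.0
rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # {'type': 'linear', 'factor': 3.0}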
src/llamafactory/model/model_utils/unsloth.py

@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Optional

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.misc import get_current_device
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _get_unsloth_kwargs(
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
-        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+        logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
         model = None
         model_args.use_unsloth = False
src/llamafactory/model/model_utils/valuehead.py

@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
 import torch
 from transformers.utils import cached_file

+from ...extras import logging
 from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ...extras.logging import get_logger

 if TYPE_CHECKING:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     except Exception as err:
         err_text = str(err)

-    logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
-    logger.info("Ignore the above message if you are not resuming the training of a value head model.")
+    logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
+    logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
     return None
src/llamafactory/model/model_utils/visual.py

@@ -18,21 +18,21 @@
 from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
 import transformers
 import transformers.models
 from transformers.activations import ACT2FN
-from transformers.utils import logging

-from ...extras.logging import get_logger
+from ...extras import logging

 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments, ModelArguments

-logger = get_logger(__name__)
-transformers_logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)

 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
         elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
         else:
             return

-        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
+        logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
@@ -113,12 +113,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         "llava_next",
         "llava_next_video",
         "paligemma",
+        "pixtral",
         "video_llava",
     ]:  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

     if getattr(config, "is_yi_vl_derived_model", None):
-        logger.info("Detected Yi-VL model, applying projector patch.")
+        logger.info_rank0("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")
@@ -162,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
     return image_seqlen

-def get_patch_size(config: "PretrainedConfig") -> int:
+def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Computes the patch size of the vit.
     """
-    patch_size = getattr(config.vision_config, "patch_size", -1)
+    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
     return patch_size

-def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Get the vision_feature_select_strategy.
     """
-    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    vision_feature_select_strategy = getattr(
+        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
+    )
     return vision_feature_select_strategy
@@ -186,8 +189,10 @@ def patch_target_modules(
     """
     model_type = getattr(config, "model_type", None)
     if finetuning_args.freeze_vision_tower:
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "mllama":
+            return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
         else:
@@ -195,5 +200,7 @@ def patch_target_modules(
     else:
         if model_type == "qwen2_vl":
             return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
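The `patch_target_modules` patterns rely on a negative lookahead to keep LoRA off the vision stack; the new mllama branch, for instance, builds `^(?!.*vision_model).*(?:q_proj|v_proj).*`. A small demonstration of how such a pattern filters module names (the module names are illustrative):

import re

target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))

# Matches language-model projections, skips anything under vision_model.
print(bool(re.match(pattern, "language_model.layers.0.self_attn.q_proj")))  # True
print(bool(re.match(pattern, "vision_model.layers.0.self_attn.q_proj")))    # False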
src/llamafactory/model/patcher.py

@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -66,11 +66,11 @@ def patch_processor(
     setattr(processor, "tokenizer", tokenizer)
     setattr(processor, "image_seqlen", get_image_seqlen(config))
     setattr(processor, "image_resolution", model_args.image_resolution)
-    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "patch_size", get_patch_size(config, processor))
     setattr(processor, "video_resolution", model_args.video_resolution)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))

 def patch_config(
@@ -100,7 +100,7 @@ def patch_config(
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
-        logger.info("Using KV cache for faster generation.")
+        logger.info_rank0("Using KV cache for faster generation.")

     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
@@ -165,7 +165,7 @@ def patch_model(
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
-        logger.warning("Cannot properly tag the model.")
+        logger.warning_rank0("Cannot properly tag the model.")

 def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
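After this change, `patch_size` and `vision_feature_select_strategy` are resolved with a chained `getattr`: the model config wins, the processor is the fallback, and a sentinel default closes the chain. In isolation, with hypothetical stand-in objects:

class Cfg:  # hypothetical stand-ins for the HF config/processor objects
    pass

vision_config, processor = Cfg(), Cfg()
processor.patch_size = 14  # only the processor knows the patch size here

# The inner getattr supplies the default for the outer one, so the lookup
# falls through config -> processor -> -1.
patch_size = getattr(vision_config, "patch_size", getattr(processor, "patch_size", -1))
print(patch_size)  # 14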
src/llamafactory/train/callbacks.py

@@ -13,7 +13,6 @@
 # limitations under the License.

 import json
-import logging
 import os
 import signal
 import sys
@@ -34,8 +33,8 @@ from transformers.utils import (
 )
 from typing_extensions import override

+from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.logging import LoggerHandler, get_logger
 from ..extras.misc import get_peak_memory
@@ -48,7 +47,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def fix_valuehead_checkpoint(
@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

-    logger.info("Value head model saved at: {}".format(output_dir))
+    logger.info_rank0(f"Value head model saved at: {output_dir}")

 class FixValueHeadModelCallback(TrainerCallback):
@@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
         Event called after a checkpoint save.
         """
         if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
             fix_valuehead_checkpoint(
                 model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
             )
@@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
     @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+            self.processor.save_pretrained(output_dir)

     @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         if args.should_save:
-            getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
+            self.processor.save_pretrained(args.output_dir)

 class PissaConvertCallback(TrainerCallback):
@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
-            logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
+            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
             pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
             pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
-            logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
+            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
             # 1. save a pissa backup with init_lora_weights: True
             # 2. save a converted lora with init_lora_weights: pissa
             # 3. load the pissa backup with init_lora_weights: True
@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
-            self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
-            logging.root.addHandler(self.logger_handler)
+            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
+            logging.add_handler(self.logger_handler)
             transformers.logging.add_handler(self.logger_handler)

     def _set_abort(self, signum, frame) -> None:
@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning("Previous trainer log in this folder will be deleted.")
+            logger.warning_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

     @override
@@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
         logs = dict(
             current_steps=self.cur_steps,
             total_steps=self.max_steps,
-            loss=state.log_history[-1].get("loss", None),
-            eval_loss=state.log_history[-1].get("eval_loss", None),
-            predict_loss=state.log_history[-1].get("predict_loss", None),
-            reward=state.log_history[-1].get("reward", None),
-            accuracy=state.log_history[-1].get("rewards/accuracies", None),
-            learning_rate=state.log_history[-1].get("learning_rate", None),
-            epoch=state.log_history[-1].get("epoch", None),
+            loss=state.log_history[-1].get("loss"),
+            eval_loss=state.log_history[-1].get("eval_loss"),
+            predict_loss=state.log_history[-1].get("predict_loss"),
+            reward=state.log_history[-1].get("reward"),
+            accuracy=state.log_history[-1].get("rewards/accuracies"),
+            lr=state.log_history[-1].get("learning_rate"),
+            epoch=state.log_history[-1].get("epoch"),
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
@@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
         if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
             vram_allocated, vram_reserved = get_peak_memory()
-            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
-            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
+            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

         logs = {k: v for k, v in logs.items() if v is not None}
-        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
-            logger.info(
-                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
-                )
-            )
+        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
+            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
+            for extra_key in ("reward", "accuracy", "throughput"):
+                if logs.get(extra_key):
+                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
+
+            logger.info_rank0("{" + log_str + "}")

         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
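The rewritten web-UI logging path builds the status line incrementally instead of using one fixed format string, so optional keys (reward, accuracy, throughput) only appear when present. Roughly what the emitted line looks like for a run with a reward signal (values invented for illustration):

logs = {"loss": 0.4132, "lr": 5e-5, "epoch": 1.25, "reward": 0.87}
log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
for extra_key in ("reward", "accuracy", "throughput"):
    if logs.get(extra_key):
        log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

print("{" + log_str + "}")
# {'loss': 0.4132, 'learning_rate': 5.0000e-05, 'epoch': 1.25, 'reward': 0.87}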
src/llamafactory/train/dpo/trainer.py

@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
             self.callback_handler.add_callback(PissaConvertCallback)

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        r"""
+        Replaces the method of KTO Trainer with the one of the standard Trainer.
+        """
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""
         Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
         elif self.loss_type == "simpo":
             losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
         else:
-            raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
+            raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")

         chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
         rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
@@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss

-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
-        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
-        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
-        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
-        metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
-        metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
-        metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
         if self.loss_type == "orpo":
-            metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
-            metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
+            metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
+            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()

         return losses.mean(), metrics

+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs)
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
+
+        return loss
+
+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())
+
+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 10:  # pad to for all reduce
+            for i in range(10 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
+        for key, metric in zip(key_list, metric_list):  # add remaining items
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
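The new `log` override averages stored metrics across processes with a single `accelerator.reduce`; since collective ops require identical tensor shapes on every rank, the key list is padded with `dummy_{i}` entries up to a fixed length before the reduce. A minimal sketch of the padding invariant, independent of `accelerate`:

import torch

key_list = ["rewards/chosen", "rewards/rejected"]
metric_list = [0.8, 0.2]

PAD_TO = 10  # every rank pads to the same length before the all-reduce
for i in range(PAD_TO - len(metric_list)):
    key_list.append(f"dummy_{i}")
    metric_list.append(0.0)

metrics = torch.tensor(metric_list, dtype=torch.float)
assert metrics.numel() == PAD_TO  # identical shape on all ranks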
src/llamafactory/train/dpo/workflow.py

@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import cal_effective_tokens
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
@@ -64,6 +65,12 @@ def run_dpo(
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

+    effective_token_num = 0.0
+    if finetuning_args.include_effective_tokens_per_second:
+        for data in dataset_module["train_dataset"]:
+            effective_token_num += len(data["chosen_input_ids"])
+            effective_token_num += len(data["rejected_input_ids"])
+
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
@@ -79,6 +86,12 @@ def run_dpo(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        if finetuning_args.include_effective_tokens_per_second:
+            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
+                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            )
+
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
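`cal_effective_tokens` is imported from `...extras.misc` but its definition is not part of this diff; from the call site, it turns a dataset-wide token count, the epochs completed, and the wall-clock runtime into an effective tokens-per-second figure. A sketch consistent with that call site — the body is an assumption:

def cal_effective_tokens(effective_token_num: float, epoch: float, train_runtime: float) -> float:
    # Tokens actually supervised per wall-clock second, assuming the
    # dataset-wide count is traversed once per epoch.
    return effective_token_num * epoch / train_runtime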
src/llamafactory/train/kto/trainer.py

@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
             self.add_callback(SaveProcessorCallback(processor))

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
         """
         return Trainer._get_train_sampler(self)

+    @override
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        r"""
+        Replaces the method of KTO Trainer with the one of the standard Trainer.
+        """
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+
     @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Runs forward pass and computes the log probabilities.
         """
         batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
         model_inputs = {
-            "input_ids": batch["{}input_ids".format(prefix)],
-            "attention_mask": batch["{}attention_mask".format(prefix)],
+            "input_ids": batch[f"{prefix}input_ids"],
+            "attention_mask": batch[f"{prefix}attention_mask"],
         }
-        if "{}token_type_ids".format(prefix) in batch:
-            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+        if f"{prefix}token_type_ids" in batch:
+            model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]

         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]
@@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]

         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
-        logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
-        return logps, logps / valid_length
+        logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
+        return logits, logps, logps / valid_length

     @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        target_logps, target_logps_avg = self.forward(model, batch)
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
-            kl_logps, _ = self.forward(model, batch, prefix="kl_")
+            _, kl_logps, _ = self.forward(model, batch, prefix="kl_")

         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")

+        chosen_logits = target_logits[batch["kto_tags"]]
         chosen_logps = target_logps[batch["kto_tags"]]
+        rejected_logits = target_logits[~batch["kto_tags"]]
         rejected_logps = target_logps[~batch["kto_tags"]]
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
-        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg

     @override
     def compute_reference_log_probs(
@@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+            reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
                 ref_model, batch
             )
@@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
         metrics = {}
-        policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
-            self.concatenated_forward(model, batch)
-        )
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_kl_logps,
+            policy_chosen_logps_avg,
+        ) = self.concatenated_forward(model, batch)

         reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
             model, batch
         )
@@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
             sft_loss = -policy_chosen_logps_avg
             losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

-        num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
-        num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
-
-        if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["count/chosen"] = all_num_chosen
-
-        if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["count/rejected"] = all_num_rejected
+        num_chosen = len(chosen_rewards)
+        num_rejected = len(rejected_rewards)
+        if num_chosen > 0:
+            metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
+            metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
+            metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
+            metrics["count/chosen"] = float(num_chosen)
+
+        if num_rejected > 0:
+            metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
+            metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
+            metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
+            metrics["count/rejected"] = float(num_rejected)

         metrics["kl"] = kl.item()

         return losses, metrics

+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs)
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
+
+        return loss
+
+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        prefix = "eval_" if train_eval == "eval" else ""
+        # Add averaged stored metrics to logs
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())
+
+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 9:  # pad to for all reduce
+            for i in range(9 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
+        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
+            if f"count/{split}" in metric_dict:
+                for key in ("rewards", "logps", "logits"):
+                    logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
+                    del metric_dict[f"{key}/{split}_sum"]
+
+                del metric_dict[f"count/{split}"]
+
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
+            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+
+        for key, metric in metric_dict.items():  # add remaining items
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
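The KTO metric plumbing changes from gathering per-batch means to storing per-rank sums plus counts (`rewards/chosen_sum`, `count/chosen`, …) and dividing only at log time, which weights every example equally even when ranks see different numbers of chosen/rejected samples. The arithmetic, in isolation:

# Two ranks with unequal chosen counts; averaging per-rank means would
# over-weight the smaller rank, while sum/count weights every example equally.
rank_sums = [4.0, 3.0]   # sum of chosen rewards on each rank
rank_counts = [4, 1]     # number of chosen examples on each rank

global_mean = sum(rank_sums) / sum(rank_counts)                          # 1.4
mean_of_means = sum(s / c for s, c in zip(rank_sums, rank_counts)) / 2   # 2.0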
src/llamafactory/train/kto/workflow.py

@@ -81,7 +81,7 @@ def run_kto(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])

     # Evaluation
     if training_args.do_eval:
src/llamafactory/train/ppo/ppo_utils.py

@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())

     device = v_head_layer.weight.device
-    v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
-    v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
+    v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
+    v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)

 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
src/llamafactory/train/ppo/trainer.py
View file @
2778a3d0
...
...
@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from
trl.models.utils
import
unwrap_model_for_generation
from
typing_extensions
import
override
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
from
...extras.misc
import
AverageMeter
,
count_parameters
,
get_current_device
,
get_logits_processor
from
..callbacks
import
FixValueHeadModelCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
...
...
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
from
...hparams
import
FinetuningArguments
,
GeneratingArguments
,
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
class
CustomPPOTrainer
(
PPOTrainer
,
Trainer
):
...
...
@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
]
ppo_config
.
accelerator_kwargs
[
"deepspeed_plugin"
]
=
training_args
.
deepspeed_plugin
if
ppo_config
.
log_with
is
not
None
:
logger
.
warning
(
"PPOTrainer cannot use external logger when DeepSpeed is enabled."
)
logger
.
warning
_rank0
(
"PPOTrainer cannot use external logger when DeepSpeed is enabled."
)
ppo_config
.
log_with
=
None
# Create optimizer and scheduler
...
...
@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
callbacks
,
self
.
accelerator
.
unwrap_model
(
self
.
model
),
self
.
tokenizer
,
self
.
optimizer
,
self
.
lr_scheduler
)
if
self
.
args
.
max_steps
>
0
:
logger
.
info
(
"max_steps is given, it will override any value given in num_train_epochs"
)
logger
.
info
_rank0
(
"max_steps is given, it will override any value given in num_train_epochs"
)
self
.
amp_context
=
torch
.
autocast
(
self
.
current_device
.
type
)
warnings
.
simplefilter
(
"ignore"
)
# remove gc warnings on ref model
...
...
@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
add_callback
(
SaveProcessorCallback
(
processor
))
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
...
...
@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

-        if self.is_world_process_zero():
-            logger.info("***** Running training *****")
-            logger.info("  Num examples = {:,}".format(num_examples))
-            logger.info("  Num Epochs = {:,}".format(num_train_epochs))
-            logger.info("  Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
-            logger.info(
-                "  Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
-                    total_train_batch_size
-                )
-            )
-            logger.info("  Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
-            logger.info("  Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
-            logger.info("  Total training steps = {:,}".format(max_steps))
-            logger.info("  Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
+        logger.info_rank0("***** Running training *****")
+        logger.info_rank0(f"  Num examples = {num_examples:,}")
+        logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
+        logger.info_rank0(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+        logger.info_rank0(
+            "  Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
+                total_train_batch_size
+            )
+        )
+        logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
+        logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
+        logger.info_rank0(f"  Total training steps = {max_steps:,}")
+        logger.info_rank0(f"  Number of trainable parameters = {count_parameters(self.model)[0]:,}")

        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
...
...
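Just below this block the training loop tracks loss and reward with AverageMeter from ...extras.misc. A minimal sketch of such a running-average helper (attribute names here are assumptions, not necessarily the library's exact API):

class AverageMeter:
    # keeps a running sum and count to expose the mean so far
    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val: float, n: int = 1) -> None:
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


loss_meter = AverageMeter()
loss_meter.update(0.52, n=8)  # mean loss of an 8-sample batch
loss_meter.update(0.48, n=8)
print(f"avg loss = {loss_meter.avg:.3f}")  # avg loss = 0.500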
@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:
-                    logger.warning("Failed to save stats due to unknown errors.")
+                    logger.warning_rank0("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.callback_handler.on_step_end(self.args, self.state, self.control)
...
...
@@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                self.save_model(
-                    os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
+                    os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                )
                self.callback_handler.on_save(self.args, self.state, self.control)
...
...
@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
-                logger.warning(
+                logger.warning_rank0(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                    " use zero_to_fp32.py to recover weights"
                )
...
...
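The ValueError branch above fires under DeepSpeed ZeRO-3 when full weights are not gathered at save time. For reference, the relevant DeepSpeed option, shown here as an illustrative Python config fragment (values are examples):

# Illustrative ZeRO-3 config fragment.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        # False => each rank saves only its own shard, and zero_to_fp32.py
        # must be run afterwards to reassemble a full checkpoint.
        "stage3_gather_16bit_weights_on_model_save": True,
    }
}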
src/llamafactory/train/pt/trainer.py
View file @
2778a3d0
...
...
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from typing_extensions import override

-from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...
...
@@ -30,9 +30,6 @@ if TYPE_CHECKING:
    from ...hparams import FinetuningArguments

-logger = get_logger(__name__)
-
-
class CustomTrainer(Trainer):
    r"""
    Inherits Trainer for custom optimizer.
...
...
@@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)
...
...
@@ -68,3 +65,19 @@
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            # other model should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
+
+        return loss
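The new compute_loss override works around transformers 4.46.0, where Trainer stopped dividing the returned loss by gradient_accumulation_steps for models that do not accept loss kwargs. A toy worked example of the correction (the numbers and the accumulation factor are assumptions):

gradient_accumulation_steps = 4
micro_batch_losses = [2.0, 1.6, 2.4, 2.0]

# what the override returns per micro-batch under transformers 4.46.0
scaled = [loss / gradient_accumulation_steps for loss in micro_batch_losses]

# gradients add up across micro-batches, so one optimizer step sees:
effective_loss = sum(scaled)
print(effective_loss)  # 2.0, the mean of the micro-batch losses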
src/llamafactory/train/rm/trainer.py
View file @
2778a3d0
...
...
@@ -24,7 +24,8 @@ import torch
from transformers import Trainer
from typing_extensions import override

-from ...extras.logging import get_logger
+from ...extras import logging
+from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...
...
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
    from ...hparams import FinetuningArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


class PairwiseTrainer(Trainer):
...
...
@@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)
...
...
@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
    @override
    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
...
...
@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+
        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
        else:
...
...
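The loss in the hunk above is the Bradley-Terry pairwise objective, -log sigmoid(s_chosen - s_rejected), averaged over the batch. A self-contained sketch with toy scores:

import torch
import torch.nn.functional as F

chosen_scores = torch.tensor([1.2, 0.3, 2.1])
rejected_scores = torch.tensor([0.4, 0.9, 1.0])

# same formula as the trainer: mean over the batch
loss = -F.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss)  # small when chosen scores exceed rejected scores

The loss shrinks as the margin between chosen and rejected scores grows, which is exactly what a reward model is trained to maximize.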
@@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
-        logger.info(f"Saving prediction results to {output_prediction_file}")
+        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
        chosen_scores, rejected_scores = predict_results.predictions
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
...
...
src/llamafactory/train/sft/trainer.py
View file @
2778a3d0
...
...
@@ -25,8 +25,9 @@ import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

+from ...extras import logging
from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...
...
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
    from ...hparams import FinetuningArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
...
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            self.add_callback(PissaConvertCallback)

        if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)
...
...
@@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            # other model should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
+
+        return loss
+
    @override
    def prediction_step(
        self,
...
...
@@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
-        logger.info(f"Saving prediction results to {output_prediction_file}")
+        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
...
...
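save_predictions first maps IGNORE_INDEX (the -100 sentinel that masks prompt tokens out of the loss) back to the pad token so the label ids can be decoded. A toy sketch of that substitution (array contents are made up):

import numpy as np

IGNORE_INDEX = -100
pad_token_id = 0
label_ids = np.array([[-100, -100, 523, 1062, 2]])

labels = np.where(label_ids != IGNORE_INDEX, label_ids, pad_token_id)
print(labels)  # [[   0    0  523 1062    2]]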