Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
2778a3d0
Commit
2778a3d0
authored
Jan 16, 2025
by
luopl
Browse files
updata to v0.9.1_stable
parent
e92143e3
Changes
172
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
354 additions
and
186 deletions
+354
-186
src/llamafactory/model/model_utils/liger_kernel.py
src/llamafactory/model/model_utils/liger_kernel.py
+5
-5
src/llamafactory/model/model_utils/longlora.py
src/llamafactory/model/model_utils/longlora.py
+12
-12
src/llamafactory/model/model_utils/misc.py
src/llamafactory/model/model_utils/misc.py
+10
-8
src/llamafactory/model/model_utils/packing.py
src/llamafactory/model/model_utils/packing.py
+6
-6
src/llamafactory/model/model_utils/quantization.py
src/llamafactory/model/model_utils/quantization.py
+8
-8
src/llamafactory/model/model_utils/rope.py
src/llamafactory/model/model_utils/rope.py
+8
-10
src/llamafactory/model/model_utils/unsloth.py
src/llamafactory/model/model_utils/unsloth.py
+3
-3
src/llamafactory/model/model_utils/valuehead.py
src/llamafactory/model/model_utils/valuehead.py
+4
-4
src/llamafactory/model/model_utils/visual.py
src/llamafactory/model/model_utils/visual.py
+21
-14
src/llamafactory/model/patcher.py
src/llamafactory/model/patcher.py
+6
-6
src/llamafactory/train/callbacks.py
src/llamafactory/train/callbacks.py
+28
-28
src/llamafactory/train/dpo/trainer.py
src/llamafactory/train/dpo/trainer.py
+62
-14
src/llamafactory/train/dpo/workflow.py
src/llamafactory/train/dpo/workflow.py
+13
-0
src/llamafactory/train/kto/trainer.py
src/llamafactory/train/kto/trainer.py
+96
-30
src/llamafactory/train/kto/workflow.py
src/llamafactory/train/kto/workflow.py
+1
-1
src/llamafactory/train/ppo/ppo_utils.py
src/llamafactory/train/ppo/ppo_utils.py
+2
-2
src/llamafactory/train/ppo/trainer.py
src/llamafactory/train/ppo/trainer.py
+20
-21
src/llamafactory/train/pt/trainer.py
src/llamafactory/train/pt/trainer.py
+18
-5
src/llamafactory/train/rm/trainer.py
src/llamafactory/train/rm/trainer.py
+10
-5
src/llamafactory/train/sft/trainer.py
src/llamafactory/train/sft/trainer.py
+21
-4
No files found.
src/llamafactory/model/model_utils/liger_kernel.py
View file @
2778a3d0
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
inspect
import
inspect
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
...
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
apply_liger_kernel
(
def
apply_liger_kernel
(
...
@@ -54,14 +54,14 @@ def apply_liger_kernel(
...
@@ -54,14 +54,14 @@ def apply_liger_kernel(
elif
model_type
==
"qwen2_vl"
:
elif
model_type
==
"qwen2_vl"
:
from
liger_kernel.transformers
import
apply_liger_kernel_to_qwen2_vl
as
apply_liger_kernel
from
liger_kernel.transformers
import
apply_liger_kernel_to_qwen2_vl
as
apply_liger_kernel
else
:
else
:
logger
.
warning
(
"Current model does not support liger kernel."
)
logger
.
warning
_rank0
(
"Current model does not support liger kernel."
)
return
return
if
require_logits
and
"fused_linear_cross_entropy"
in
inspect
.
signature
(
apply_liger_kernel
).
parameters
:
if
require_logits
and
"fused_linear_cross_entropy"
in
inspect
.
signature
(
apply_liger_kernel
).
parameters
:
logger
.
info
(
"Current training stage does not support chunked cross entropy."
)
logger
.
info
_rank0
(
"Current training stage does not support chunked cross entropy."
)
kwargs
=
{
"fused_linear_cross_entropy"
:
False
}
kwargs
=
{
"fused_linear_cross_entropy"
:
False
}
else
:
else
:
kwargs
=
{}
kwargs
=
{}
apply_liger_kernel
(
**
kwargs
)
apply_liger_kernel
(
**
kwargs
)
logger
.
info
(
"Liger kernel has been applied to the model."
)
logger
.
info
_rank0
(
"Liger kernel has been applied to the model."
)
src/llamafactory/model/model_utils/longlora.py
View file @
2778a3d0
...
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
...
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
transformers
from
transformers.models.llama.modeling_llama
import
(
from
transformers.models.llama.modeling_llama
import
(
Cache
,
Cache
,
LlamaAttention
,
LlamaAttention
,
...
@@ -30,12 +31,11 @@ from transformers.models.llama.modeling_llama import (
...
@@ -30,12 +31,11 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb
,
apply_rotary_pos_emb
,
repeat_kv
,
repeat_kv
,
)
)
from
transformers.utils
import
logging
from
transformers.utils.versions
import
require_version
from
transformers.utils.versions
import
require_version
from
...extras
import
logging
from
...extras.constants
import
SUPPORTED_CLASS_FOR_S2ATTN
from
...extras.constants
import
SUPPORTED_CLASS_FOR_S2ATTN
from
...extras.logging
import
get_logger
from
...extras.packages
import
is_transformers_version_greater_than
from
...extras.packages
import
is_transformers_version_greater_than_4_43
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -44,7 +44,7 @@ if TYPE_CHECKING:
...
@@ -44,7 +44,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
transformers_logger
=
logging
.
get_logger
(
__name__
)
transformers_logger
=
transformers
.
utils
.
logging
.
get_logger
(
__name__
)
# Modified from:
# Modified from:
...
@@ -86,7 +86,7 @@ def llama_attention_forward(
...
@@ -86,7 +86,7 @@ def llama_attention_forward(
if
getattr
(
self
.
config
,
"group_size_ratio"
,
None
)
and
self
.
training
:
# shift
if
getattr
(
self
.
config
,
"group_size_ratio"
,
None
)
and
self
.
training
:
# shift
groupsz
=
int
(
q_len
*
getattr
(
self
.
config
,
"group_size_ratio"
))
groupsz
=
int
(
q_len
*
getattr
(
self
.
config
,
"group_size_ratio"
))
assert
q_len
%
groupsz
==
0
,
"q_len {} should be divisible by group size {
}."
.
format
(
q_len
,
groupsz
)
assert
q_len
%
groupsz
==
0
,
f
"q_len
{
q_len
}
should be divisible by group size
{
groupsz
}
."
num_groups
=
q_len
//
groupsz
num_groups
=
q_len
//
groupsz
def
shift
(
state
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
shift
(
state
:
"torch.Tensor"
)
->
"torch.Tensor"
:
...
@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
...
@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
if
getattr
(
self
.
config
,
"group_size_ratio"
,
None
)
and
self
.
training
:
# shift
if
getattr
(
self
.
config
,
"group_size_ratio"
,
None
)
and
self
.
training
:
# shift
groupsz
=
int
(
q_len
*
getattr
(
self
.
config
,
"group_size_ratio"
))
groupsz
=
int
(
q_len
*
getattr
(
self
.
config
,
"group_size_ratio"
))
assert
q_len
%
groupsz
==
0
,
"q_len {} should be divisible by group size {
}."
.
format
(
q_len
,
groupsz
)
assert
q_len
%
groupsz
==
0
,
f
"q_len
{
q_len
}
should be divisible by group size
{
groupsz
}
."
num_groups
=
q_len
//
groupsz
num_groups
=
q_len
//
groupsz
def
shift
(
state
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
shift
(
state
:
"torch.Tensor"
)
->
"torch.Tensor"
:
...
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
...
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
:
groupsz
].
repeat
(
num_groups
,
1
)
attention_mask
=
attention_mask
[:,
:
groupsz
].
repeat
(
num_groups
,
1
)
if
is_transformers_version_greater_than
_4_43
(
):
if
is_transformers_version_greater_than
(
"4.43.0"
):
from
transformers.modeling_flash_attention_utils
import
_flash_attention_forward
from
transformers.modeling_flash_attention_utils
import
_flash_attention_forward
attn_output
:
"torch.Tensor"
=
_flash_attention_forward
(
attn_output
:
"torch.Tensor"
=
_flash_attention_forward
(
...
@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
...
@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
if
getattr
(
self
.
config
,
"group_size_ratio"
,
None
)
and
self
.
training
:
# shift
if
getattr
(
self
.
config
,
"group_size_ratio"
,
None
)
and
self
.
training
:
# shift
groupsz
=
int
(
q_len
*
getattr
(
self
.
config
,
"group_size_ratio"
))
groupsz
=
int
(
q_len
*
getattr
(
self
.
config
,
"group_size_ratio"
))
assert
q_len
%
groupsz
==
0
,
"q_len {} should be divisible by group size {
}."
.
format
(
q_len
,
groupsz
)
assert
q_len
%
groupsz
==
0
,
f
"q_len
{
q_len
}
should be divisible by group size
{
groupsz
}
."
num_groups
=
q_len
//
groupsz
num_groups
=
q_len
//
groupsz
def
shift
(
state
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
shift
(
state
:
"torch.Tensor"
)
->
"torch.Tensor"
:
...
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
...
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def
_apply_llama_patch
()
->
None
:
def
_apply_llama_patch
()
->
None
:
require_version
(
"transformers>=4.41.2,<=4.4
5.2
"
,
"To fix: pip install transformers>=4.41.2,<=4.4
5.2
"
)
require_version
(
"transformers>=4.41.2,<=4.4
6.1
"
,
"To fix: pip install transformers>=4.41.2,<=4.4
6.1
"
)
LlamaAttention
.
forward
=
llama_attention_forward
LlamaAttention
.
forward
=
llama_attention_forward
LlamaFlashAttention2
.
forward
=
llama_flash_attention_2_forward
LlamaFlashAttention2
.
forward
=
llama_flash_attention_2_forward
LlamaSdpaAttention
.
forward
=
llama_sdpa_attention_forward
LlamaSdpaAttention
.
forward
=
llama_sdpa_attention_forward
...
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
...
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
if
not
is_trainable
or
not
model_args
.
shift_attn
:
if
not
is_trainable
or
not
model_args
.
shift_attn
:
return
return
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
if
getattr
(
config
,
"model_type"
,
None
)
in
SUPPORTED_CLASS_FOR_S2ATTN
:
if
getattr
(
config
,
"model_type"
,
None
)
in
SUPPORTED_CLASS_FOR_S2ATTN
:
setattr
(
config
,
"group_size_ratio"
,
0.25
)
setattr
(
config
,
"group_size_ratio"
,
0.25
)
_apply_llama_patch
()
_apply_llama_patch
()
logger
.
info
(
"Using shift short attention with group_size_ratio=1/4."
)
logger
.
info
_rank0
(
"Using shift short attention with group_size_ratio=1/4."
)
else
:
else
:
logger
.
warning
(
"Current model does not support shift short attention."
)
logger
.
warning
_rank0
(
"Current model does not support shift short attention."
)
src/llamafactory/model/model_utils/misc.py
View file @
2778a3d0
...
@@ -14,14 +14,14 @@
...
@@ -14,14 +14,14 @@
from
typing
import
TYPE_CHECKING
,
List
from
typing
import
TYPE_CHECKING
,
List
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
PretrainedConfig
,
PreTrainedModel
,
PreTrainedTokenizer
from
transformers
import
PretrainedConfig
,
PreTrainedModel
,
PreTrainedTokenizer
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
find_all_linear_modules
(
model
:
"PreTrainedModel"
,
freeze_vision_tower
:
bool
)
->
List
[
str
]:
def
find_all_linear_modules
(
model
:
"PreTrainedModel"
,
freeze_vision_tower
:
bool
)
->
List
[
str
]:
...
@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
...
@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules
.
add
(
"output_layer"
)
forbidden_modules
.
add
(
"output_layer"
)
elif
model_type
==
"internlm2"
:
elif
model_type
==
"internlm2"
:
forbidden_modules
.
add
(
"output"
)
forbidden_modules
.
add
(
"output"
)
elif
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
elif
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"mllama"
,
"paligemma"
,
"video_llava"
]:
forbidden_modules
.
add
(
"multi_modal_projector"
)
forbidden_modules
.
add
(
"multi_modal_projector"
)
elif
model_type
==
"qwen2_vl"
:
elif
model_type
==
"qwen2_vl"
:
forbidden_modules
.
add
(
"merger"
)
forbidden_modules
.
add
(
"merger"
)
if
freeze_vision_tower
:
if
freeze_vision_tower
:
if
model_type
==
"qwen2_vl"
:
if
model_type
==
"mllama"
:
forbidden_modules
.
add
(
"vision_model"
)
elif
model_type
==
"qwen2_vl"
:
forbidden_modules
.
add
(
"visual"
)
forbidden_modules
.
add
(
"visual"
)
else
:
else
:
forbidden_modules
.
add
(
"vision_tower"
)
forbidden_modules
.
add
(
"vision_tower"
)
...
@@ -53,7 +55,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
...
@@ -53,7 +55,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
if
"Linear"
in
module
.
__class__
.
__name__
and
"Embedding"
not
in
module
.
__class__
.
__name__
:
if
"Linear"
in
module
.
__class__
.
__name__
and
"Embedding"
not
in
module
.
__class__
.
__name__
:
module_names
.
add
(
name
.
split
(
"."
)[
-
1
])
module_names
.
add
(
name
.
split
(
"."
)[
-
1
])
logger
.
info
(
"Found linear modules: {}"
.
format
(
","
.
join
(
module_names
)))
logger
.
info
_rank0
(
"Found linear modules: {}"
.
format
(
","
.
join
(
module_names
)))
return
list
(
module_names
)
return
list
(
module_names
)
...
@@ -67,12 +69,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
...
@@ -67,12 +69,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
if
num_layers
%
num_layer_trainable
!=
0
:
if
num_layers
%
num_layer_trainable
!=
0
:
raise
ValueError
(
raise
ValueError
(
"`num_layers` {} should be divisible by `num_layer_trainable` {
}."
.
format
(
num_layers
,
num_layer_trainable
)
f
"`num_layers`
{
num_layers
}
should be divisible by `num_layer_trainable`
{
num_layer_trainable
}
."
)
)
stride
=
num_layers
//
num_layer_trainable
stride
=
num_layers
//
num_layer_trainable
trainable_layer_ids
=
range
(
stride
-
1
,
num_layers
+
stride
-
1
,
stride
)
trainable_layer_ids
=
range
(
stride
-
1
,
num_layers
+
stride
-
1
,
stride
)
trainable_layers
=
[
".{:d}."
.
format
(
idx
)
for
idx
in
trainable_layer_ids
]
trainable_layers
=
[
f
".
{
idx
:
d
}
."
for
idx
in
trainable_layer_ids
]
module_names
=
[]
module_names
=
[]
for
name
,
_
in
model
.
named_modules
():
for
name
,
_
in
model
.
named_modules
():
if
any
(
target_module
in
name
for
target_module
in
target_modules
)
and
any
(
if
any
(
target_module
in
name
for
target_module
in
target_modules
)
and
any
(
...
@@ -80,7 +82,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
...
@@ -80,7 +82,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
):
):
module_names
.
append
(
name
)
module_names
.
append
(
name
)
logger
.
info
(
"Apply lora to layers: {}"
.
format
(
","
.
join
(
map
(
str
,
trainable_layer_ids
))))
logger
.
info
_rank0
(
"Apply lora to layers: {}"
.
format
(
","
.
join
(
map
(
str
,
trainable_layer_ids
))))
return
module_names
return
module_names
...
...
src/llamafactory/model/model_utils/packing.py
View file @
2778a3d0
...
@@ -43,9 +43,9 @@ import torch
...
@@ -43,9 +43,9 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
transformers.utils.versions
import
require_version
from
transformers.utils.versions
import
require_version
from
...extras
import
logging
from
...extras.constants
import
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from
...extras.constants
import
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from
...extras.logging
import
get_logger
from
...extras.packages
import
is_transformers_version_greater_than
from
...extras.packages
import
is_transformers_version_greater_than_4_43
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
...
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
get_seqlens_in_batch
(
attention_mask
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
get_seqlens_in_batch
(
attention_mask
:
"torch.Tensor"
)
->
"torch.Tensor"
:
...
@@ -114,8 +114,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
...
@@ -114,8 +114,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def
_patch_for_block_diag_attn
(
model_type
:
str
)
->
None
:
def
_patch_for_block_diag_attn
(
model_type
:
str
)
->
None
:
require_version
(
"transformers>=4.41.2,<=4.4
5.2
"
,
"To fix: pip install transformers>=4.41.2,<=4.4
5.2
"
)
require_version
(
"transformers>=4.41.2,<=4.4
6.1
"
,
"To fix: pip install transformers>=4.41.2,<=4.4
6.1
"
)
if
is_transformers_version_greater_than
_4_43
(
):
if
is_transformers_version_greater_than
(
"4.43.0"
):
import
transformers.modeling_flash_attention_utils
import
transformers.modeling_flash_attention_utils
transformers
.
modeling_flash_attention_utils
.
_get_unpad_data
=
get_unpad_data
transformers
.
modeling_flash_attention_utils
.
_get_unpad_data
=
get_unpad_data
...
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
...
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
model_type
=
getattr
(
config
,
"model_type"
,
None
)
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
model_type
in
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
:
if
model_type
in
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
:
_patch_for_block_diag_attn
(
model_type
)
_patch_for_block_diag_attn
(
model_type
)
logger
.
info
(
"Using block diagonal attention for sequence packing without cross-attention."
)
logger
.
info
_rank0
(
"Using block diagonal attention for sequence packing without cross-attention."
)
else
:
else
:
raise
ValueError
(
"Current model does not support block diagonal attention."
)
raise
ValueError
(
"Current model does not support block diagonal attention."
)
src/llamafactory/model/model_utils/quantization.py
View file @
2778a3d0
...
@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
...
@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from
transformers.modeling_utils
import
is_fsdp_enabled
from
transformers.modeling_utils
import
is_fsdp_enabled
from
transformers.utils.versions
import
require_version
from
transformers.utils.versions
import
require_version
from
...extras
import
logging
from
...extras.constants
import
FILEEXT2TYPE
from
...extras.constants
import
FILEEXT2TYPE
from
...extras.logging
import
get_logger
from
...extras.misc
import
get_current_device
from
...extras.misc
import
get_current_device
...
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
...
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
@
unique
@
unique
...
@@ -109,7 +109,7 @@ def configure_quantization(
...
@@ -109,7 +109,7 @@ def configure_quantization(
"""
"""
if
getattr
(
config
,
"quantization_config"
,
None
):
# ptq
if
getattr
(
config
,
"quantization_config"
,
None
):
# ptq
if
model_args
.
quantization_bit
is
not
None
:
if
model_args
.
quantization_bit
is
not
None
:
logger
.
warning
(
"`quantization_bit` will not affect on the PTQ-quantized models."
)
logger
.
warning
_rank0
(
"`quantization_bit` will not affect on the PTQ-quantized models."
)
if
is_deepspeed_zero3_enabled
()
or
is_fsdp_enabled
():
if
is_deepspeed_zero3_enabled
()
or
is_fsdp_enabled
():
raise
ValueError
(
"DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models."
)
raise
ValueError
(
"DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models."
)
...
@@ -130,7 +130,7 @@ def configure_quantization(
...
@@ -130,7 +130,7 @@ def configure_quantization(
quantization_config
[
"bits"
]
=
2
quantization_config
[
"bits"
]
=
2
quant_bits
=
quantization_config
.
get
(
"bits"
,
"?"
)
quant_bits
=
quantization_config
.
get
(
"bits"
,
"?"
)
logger
.
info
(
"Loading {
}-bit {}-quantized model."
.
format
(
quant_bits
,
quant_method
.
upper
()
)
)
logger
.
info
_rank0
(
f
"Loading
{
quant_bits
}
-bit
{
quant_method
.
upper
()
}
-quantized model."
)
elif
model_args
.
export_quantization_bit
is
not
None
:
# auto-gptq
elif
model_args
.
export_quantization_bit
is
not
None
:
# auto-gptq
if
model_args
.
export_quantization_bit
not
in
[
8
,
4
,
3
,
2
]:
if
model_args
.
export_quantization_bit
not
in
[
8
,
4
,
3
,
2
]:
...
@@ -149,7 +149,7 @@ def configure_quantization(
...
@@ -149,7 +149,7 @@ def configure_quantization(
)
)
init_kwargs
[
"device_map"
]
=
"auto"
init_kwargs
[
"device_map"
]
=
"auto"
init_kwargs
[
"max_memory"
]
=
get_max_memory
()
init_kwargs
[
"max_memory"
]
=
get_max_memory
()
logger
.
info
(
"Quantizing model to {
} bit with AutoGPTQ."
.
format
(
model_args
.
export_quantization_bit
)
)
logger
.
info
_rank0
(
f
"Quantizing model to
{
model_args
.
export_quantization_bit
}
bit with AutoGPTQ."
)
elif
model_args
.
quantization_bit
is
not
None
:
# on-the-fly
elif
model_args
.
quantization_bit
is
not
None
:
# on-the-fly
if
model_args
.
quantization_method
==
QuantizationMethod
.
BITS_AND_BYTES
.
value
:
if
model_args
.
quantization_method
==
QuantizationMethod
.
BITS_AND_BYTES
.
value
:
...
@@ -179,7 +179,7 @@ def configure_quantization(
...
@@ -179,7 +179,7 @@ def configure_quantization(
else
:
else
:
init_kwargs
[
"device_map"
]
=
{
""
:
get_current_device
()}
# change auto device map for inference
init_kwargs
[
"device_map"
]
=
{
""
:
get_current_device
()}
# change auto device map for inference
logger
.
info
(
"Quantizing model to {
} bit with bitsandbytes."
.
format
(
model_args
.
quantization_bit
)
)
logger
.
info
_rank0
(
f
"Quantizing model to
{
model_args
.
quantization_bit
}
bit with bitsandbytes."
)
elif
model_args
.
quantization_method
==
QuantizationMethod
.
HQQ
.
value
:
elif
model_args
.
quantization_method
==
QuantizationMethod
.
HQQ
.
value
:
if
model_args
.
quantization_bit
not
in
[
8
,
6
,
5
,
4
,
3
,
2
,
1
]:
if
model_args
.
quantization_bit
not
in
[
8
,
6
,
5
,
4
,
3
,
2
,
1
]:
raise
ValueError
(
"HQQ only accepts 1/2/3/4/5/6/8-bit quantization."
)
raise
ValueError
(
"HQQ only accepts 1/2/3/4/5/6/8-bit quantization."
)
...
@@ -191,7 +191,7 @@ def configure_quantization(
...
@@ -191,7 +191,7 @@ def configure_quantization(
init_kwargs
[
"quantization_config"
]
=
HqqConfig
(
init_kwargs
[
"quantization_config"
]
=
HqqConfig
(
nbits
=
model_args
.
quantization_bit
,
quant_zero
=
False
,
quant_scale
=
False
,
axis
=
0
nbits
=
model_args
.
quantization_bit
,
quant_zero
=
False
,
quant_scale
=
False
,
axis
=
0
)
# use ATEN kernel (axis=0) for performance
)
# use ATEN kernel (axis=0) for performance
logger
.
info
(
"Quantizing model to {
} bit with HQQ."
.
format
(
model_args
.
quantization_bit
)
)
logger
.
info
_rank0
(
f
"Quantizing model to
{
model_args
.
quantization_bit
}
bit with HQQ."
)
elif
model_args
.
quantization_method
==
QuantizationMethod
.
EETQ
.
value
:
elif
model_args
.
quantization_method
==
QuantizationMethod
.
EETQ
.
value
:
if
model_args
.
quantization_bit
!=
8
:
if
model_args
.
quantization_bit
!=
8
:
raise
ValueError
(
"EETQ only accepts 8-bit quantization."
)
raise
ValueError
(
"EETQ only accepts 8-bit quantization."
)
...
@@ -201,4 +201,4 @@ def configure_quantization(
...
@@ -201,4 +201,4 @@ def configure_quantization(
require_version
(
"eetq"
,
"To fix: pip install eetq"
)
require_version
(
"eetq"
,
"To fix: pip install eetq"
)
init_kwargs
[
"quantization_config"
]
=
EetqConfig
()
init_kwargs
[
"quantization_config"
]
=
EetqConfig
()
logger
.
info
(
"Quantizing model to {
} bit with EETQ."
.
format
(
model_args
.
quantization_bit
)
)
logger
.
info
_rank0
(
f
"Quantizing model to
{
model_args
.
quantization_bit
}
bit with EETQ."
)
src/llamafactory/model/model_utils/rope.py
View file @
2778a3d0
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
import
math
import
math
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
configure_rope
(
config
:
"PretrainedConfig"
,
model_args
:
"ModelArguments"
,
is_trainable
:
bool
)
->
None
:
def
configure_rope
(
config
:
"PretrainedConfig"
,
model_args
:
"ModelArguments"
,
is_trainable
:
bool
)
->
None
:
...
@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
...
@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
return
return
if
not
hasattr
(
config
,
"rope_scaling"
):
if
not
hasattr
(
config
,
"rope_scaling"
):
logger
.
warning
(
"Current model does not support RoPE scaling."
)
logger
.
warning
_rank0
(
"Current model does not support RoPE scaling."
)
return
return
if
model_args
.
model_max_length
is
not
None
:
if
model_args
.
model_max_length
is
not
None
:
if
is_trainable
and
model_args
.
rope_scaling
==
"dynamic"
:
if
is_trainable
and
model_args
.
rope_scaling
==
"dynamic"
:
logger
.
warning
(
logger
.
warning
_rank0
(
"Dynamic NTK scaling may not work well with fine-tuning. "
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
"See: https://github.com/huggingface/transformers/pull/24653"
)
)
current_max_length
=
getattr
(
config
,
"max_position_embeddings"
,
None
)
current_max_length
=
getattr
(
config
,
"max_position_embeddings"
,
None
)
if
current_max_length
and
model_args
.
model_max_length
>
current_max_length
:
if
current_max_length
and
model_args
.
model_max_length
>
current_max_length
:
logger
.
info
(
logger
.
info_rank0
(
f
"Enlarge max model length from
{
current_max_length
}
to
{
model_args
.
model_max_length
}
."
)
"Enlarge max model length from {} to {}."
.
format
(
current_max_length
,
model_args
.
model_max_length
)
)
setattr
(
config
,
"max_position_embeddings"
,
model_args
.
model_max_length
)
setattr
(
config
,
"max_position_embeddings"
,
model_args
.
model_max_length
)
scaling_factor
=
float
(
math
.
ceil
(
model_args
.
model_max_length
/
current_max_length
))
scaling_factor
=
float
(
math
.
ceil
(
model_args
.
model_max_length
/
current_max_length
))
else
:
else
:
logger
.
warning
(
"Input length is smaller than max length. Consider increase input length."
)
logger
.
warning
_rank0
(
"Input length is smaller than max length. Consider increase input length."
)
scaling_factor
=
1.0
scaling_factor
=
1.0
else
:
else
:
scaling_factor
=
2.0
scaling_factor
=
2.0
setattr
(
config
,
"rope_scaling"
,
{
"type"
:
model_args
.
rope_scaling
,
"factor"
:
scaling_factor
})
setattr
(
config
,
"rope_scaling"
,
{
"type"
:
model_args
.
rope_scaling
,
"factor"
:
scaling_factor
})
logger
.
info
(
logger
.
info
_rank0
(
"Using {} scaling strategy and setting scaling factor to {
}"
.
format
(
model_args
.
rope_scaling
,
scaling_factor
)
f
"Using
{
model_args
.
rope_scaling
}
scaling strategy and setting scaling factor to
{
scaling_factor
}
"
)
)
src/llamafactory/model/model_utils/unsloth.py
View file @
2778a3d0
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
from
...extras.misc
import
get_current_device
from
...extras.misc
import
get_current_device
...
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
...
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
_get_unsloth_kwargs
(
def
_get_unsloth_kwargs
(
...
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
...
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
try
:
try
:
model
,
_
=
FastLanguageModel
.
from_pretrained
(
**
unsloth_kwargs
)
model
,
_
=
FastLanguageModel
.
from_pretrained
(
**
unsloth_kwargs
)
except
NotImplementedError
:
except
NotImplementedError
:
logger
.
warning
(
"Unsloth does not support model type {}."
.
format
(
getattr
(
config
,
"model_type"
,
None
)))
logger
.
warning
_rank0
(
"Unsloth does not support model type {}."
.
format
(
getattr
(
config
,
"model_type"
,
None
)))
model
=
None
model
=
None
model_args
.
use_unsloth
=
False
model_args
.
use_unsloth
=
False
...
...
src/llamafactory/model/model_utils/valuehead.py
View file @
2778a3d0
...
@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
...
@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
import
torch
import
torch
from
transformers.utils
import
cached_file
from
transformers.utils
import
cached_file
from
...extras
import
logging
from
...extras.constants
import
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
...extras.constants
import
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
...extras.logging
import
get_logger
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
...
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
load_valuehead_params
(
path_or_repo_id
:
str
,
model_args
:
"ModelArguments"
)
->
Dict
[
str
,
torch
.
Tensor
]:
def
load_valuehead_params
(
path_or_repo_id
:
str
,
model_args
:
"ModelArguments"
)
->
Dict
[
str
,
torch
.
Tensor
]:
...
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
...
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except
Exception
as
err
:
except
Exception
as
err
:
err_text
=
str
(
err
)
err_text
=
str
(
err
)
logger
.
info
(
"Provided path ({}) does not contain value head weights: {
}."
.
format
(
path_or_repo_id
,
err_text
)
)
logger
.
info
_rank0
(
f
"Provided path (
{
path_or_repo_id
}
) does not contain value head weights:
{
err_text
}
."
)
logger
.
info
(
"Ignore the above message if you are not resuming the training of a value head model."
)
logger
.
info
_rank0
(
"Ignore the above message if you are not resuming the training of a value head model."
)
return
None
return
None
...
...
src/llamafactory/model/model_utils/visual.py
View file @
2778a3d0
...
@@ -18,21 +18,21 @@
...
@@ -18,21 +18,21 @@
from
typing
import
TYPE_CHECKING
,
List
,
Sequence
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Sequence
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
transformers
import
transformers.models
import
transformers.models
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
transformers.utils
import
logging
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
LlavaConfig
,
PretrainedConfig
,
PreTrainedModel
from
transformers
import
LlavaConfig
,
PretrainedConfig
,
PreTrainedModel
,
ProcessorMixin
from
...hparams
import
FinetuningArguments
,
ModelArguments
from
...hparams
import
FinetuningArguments
,
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
transformers_logger
=
logging
.
get_logger
(
__name__
)
transformers_logger
=
transformers
.
utils
.
logging
.
get_logger
(
__name__
)
class
LlavaMultiModalProjectorForYiVL
(
torch
.
nn
.
Module
):
class
LlavaMultiModalProjectorForYiVL
(
torch
.
nn
.
Module
):
...
@@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
...
@@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if
getattr
(
model
,
"quantization_method"
,
None
):
if
getattr
(
model
,
"quantization_method"
,
None
):
model_type
=
getattr
(
model
.
config
,
"model_type"
,
None
)
model_type
=
getattr
(
model
.
config
,
"model_type"
,
None
)
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"pixtral"
,
"video_llava"
]:
mm_projector
:
"torch.nn.Module"
=
getattr
(
model
,
"multi_modal_projector"
)
mm_projector
:
"torch.nn.Module"
=
getattr
(
model
,
"multi_modal_projector"
)
elif
model_type
==
"qwen2_vl"
:
elif
model_type
==
"qwen2_vl"
:
mm_projector
:
"torch.nn.Module"
=
getattr
(
getattr
(
model
,
"visual"
),
"merger"
)
mm_projector
:
"torch.nn.Module"
=
getattr
(
getattr
(
model
,
"visual"
),
"merger"
)
else
:
else
:
return
return
logger
.
info
(
"Casting multimodal projector outputs in {
}."
.
format
(
model_args
.
compute_dtype
)
)
logger
.
info
_rank0
(
f
"Casting multimodal projector outputs in
{
model_args
.
compute_dtype
}
."
)
mm_projector
.
register_forward_hook
(
_mm_projector_forward_post_hook
)
mm_projector
.
register_forward_hook
(
_mm_projector_forward_post_hook
)
...
@@ -113,12 +113,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
...
@@ -113,12 +113,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
"llava_next"
,
"llava_next"
,
"llava_next_video"
,
"llava_next_video"
,
"paligemma"
,
"paligemma"
,
"pixtral"
,
"video_llava"
,
"video_llava"
,
]:
# required for ds zero3 and valuehead models
]:
# required for ds zero3 and valuehead models
setattr
(
config
,
"hidden_size"
,
getattr
(
config
.
text_config
,
"hidden_size"
,
None
))
setattr
(
config
,
"hidden_size"
,
getattr
(
config
.
text_config
,
"hidden_size"
,
None
))
if
getattr
(
config
,
"is_yi_vl_derived_model"
,
None
):
if
getattr
(
config
,
"is_yi_vl_derived_model"
,
None
):
logger
.
info
(
"Detected Yi-VL model, applying projector patch."
)
logger
.
info
_rank0
(
"Detected Yi-VL model, applying projector patch."
)
transformers
.
models
.
llava
.
modeling_llava
.
LlavaMultiModalProjector
=
LlavaMultiModalProjectorForYiVL
transformers
.
models
.
llava
.
modeling_llava
.
LlavaMultiModalProjector
=
LlavaMultiModalProjectorForYiVL
...
@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
...
@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
"""
"""
model_type
=
getattr
(
config
,
"model_type"
,
None
)
model_type
=
getattr
(
config
,
"model_type"
,
None
)
forbidden_modules
=
set
()
forbidden_modules
=
set
()
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"pixtral"
,
"video_llava"
]:
if
finetuning_args
.
freeze_vision_tower
:
if
finetuning_args
.
freeze_vision_tower
:
forbidden_modules
.
add
(
"vision_tower"
)
forbidden_modules
.
add
(
"vision_tower"
)
...
@@ -162,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
...
@@ -162,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
return
image_seqlen
return
image_seqlen
def
get_patch_size
(
config
:
"PretrainedConfig"
)
->
int
:
def
get_patch_size
(
config
:
"PretrainedConfig"
,
processor
:
"ProcessorMixin"
)
->
int
:
r
"""
r
"""
Computes the patch size of the vit.
Computes the patch size of the vit.
"""
"""
patch_size
=
getattr
(
config
.
vision_config
,
"patch_size"
,
-
1
)
patch_size
=
getattr
(
config
.
vision_config
,
"patch_size"
,
getattr
(
processor
,
"patch_size"
,
-
1
)
)
return
patch_size
return
patch_size
def
get_vision_feature_select_strategy
(
config
:
"PretrainedConfig"
)
->
int
:
def
get_vision_feature_select_strategy
(
config
:
"PretrainedConfig"
,
processor
:
"ProcessorMixin"
)
->
int
:
r
"""
r
"""
Get the vision_feature_select_strategy.
Get the vision_feature_select_strategy.
"""
"""
vision_feature_select_strategy
=
getattr
(
config
,
"vision_feature_select_strategy"
,
"default"
)
vision_feature_select_strategy
=
getattr
(
config
,
"vision_feature_select_strategy"
,
getattr
(
processor
,
"vision_feature_select_strategy"
,
"default"
)
)
return
vision_feature_select_strategy
return
vision_feature_select_strategy
...
@@ -186,8 +189,10 @@ def patch_target_modules(
...
@@ -186,8 +189,10 @@ def patch_target_modules(
"""
"""
model_type
=
getattr
(
config
,
"model_type"
,
None
)
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
finetuning_args
.
freeze_vision_tower
:
if
finetuning_args
.
freeze_vision_tower
:
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"pixtral"
,
"video_llava"
]:
return
"^(?!.*vision_tower).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
return
"^(?!.*vision_tower).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
elif
model_type
==
"mllama"
:
return
"^(?!.*vision_model).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
elif
model_type
==
"qwen2_vl"
:
elif
model_type
==
"qwen2_vl"
:
return
"^(?!.*visual).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
return
"^(?!.*visual).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
else
:
else
:
...
@@ -195,5 +200,7 @@ def patch_target_modules(
...
@@ -195,5 +200,7 @@ def patch_target_modules(
else
:
else
:
if
model_type
==
"qwen2_vl"
:
if
model_type
==
"qwen2_vl"
:
return
"^(?!.*patch_embed).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
return
"^(?!.*patch_embed).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
elif
model_type
==
"pixtral"
:
return
"^(?!.*patch_conv).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
else
:
else
:
return
target_modules
return
target_modules
src/llamafactory/model/patcher.py
View file @
2778a3d0
...
@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
...
@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
from
transformers.integrations
import
is_deepspeed_zero3_enabled
from
transformers.integrations
import
is_deepspeed_zero3_enabled
from
transformers.modeling_utils
import
is_fsdp_enabled
from
transformers.modeling_utils
import
is_fsdp_enabled
from
..extras
.logging
import
get_
logg
er
from
..extras
import
logg
ing
from
..extras.misc
import
infer_optim_dtype
from
..extras.misc
import
infer_optim_dtype
from
.model_utils.attention
import
configure_attn_implementation
,
print_attn_implementation
from
.model_utils.attention
import
configure_attn_implementation
,
print_attn_implementation
from
.model_utils.checkpointing
import
prepare_model_for_training
from
.model_utils.checkpointing
import
prepare_model_for_training
...
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
...
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
from
..hparams
import
ModelArguments
from
..hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
patch_tokenizer
(
tokenizer
:
"PreTrainedTokenizer"
)
->
None
:
def
patch_tokenizer
(
tokenizer
:
"PreTrainedTokenizer"
)
->
None
:
...
@@ -66,11 +66,11 @@ def patch_processor(
...
@@ -66,11 +66,11 @@ def patch_processor(
setattr
(
processor
,
"tokenizer"
,
tokenizer
)
setattr
(
processor
,
"tokenizer"
,
tokenizer
)
setattr
(
processor
,
"image_seqlen"
,
get_image_seqlen
(
config
))
setattr
(
processor
,
"image_seqlen"
,
get_image_seqlen
(
config
))
setattr
(
processor
,
"image_resolution"
,
model_args
.
image_resolution
)
setattr
(
processor
,
"image_resolution"
,
model_args
.
image_resolution
)
setattr
(
processor
,
"patch_size"
,
get_patch_size
(
config
))
setattr
(
processor
,
"patch_size"
,
get_patch_size
(
config
,
processor
))
setattr
(
processor
,
"video_resolution"
,
model_args
.
video_resolution
)
setattr
(
processor
,
"video_resolution"
,
model_args
.
video_resolution
)
setattr
(
processor
,
"video_fps"
,
model_args
.
video_fps
)
setattr
(
processor
,
"video_fps"
,
model_args
.
video_fps
)
setattr
(
processor
,
"video_maxlen"
,
model_args
.
video_maxlen
)
setattr
(
processor
,
"video_maxlen"
,
model_args
.
video_maxlen
)
setattr
(
processor
,
"vision_feature_select_strategy"
,
get_vision_feature_select_strategy
(
config
))
setattr
(
processor
,
"vision_feature_select_strategy"
,
get_vision_feature_select_strategy
(
config
,
processor
))
def
patch_config
(
def
patch_config
(
...
@@ -100,7 +100,7 @@ def patch_config(
...
@@ -100,7 +100,7 @@ def patch_config(
if
model_args
.
use_cache
and
not
is_trainable
:
if
model_args
.
use_cache
and
not
is_trainable
:
setattr
(
config
,
"use_cache"
,
True
)
setattr
(
config
,
"use_cache"
,
True
)
logger
.
info
(
"Using KV cache for faster generation."
)
logger
.
info
_rank0
(
"Using KV cache for faster generation."
)
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen"
:
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen"
:
setattr
(
config
,
"use_flash_attn"
,
model_args
.
flash_attn
==
"fa2"
)
setattr
(
config
,
"use_flash_attn"
,
model_args
.
flash_attn
==
"fa2"
)
...
@@ -165,7 +165,7 @@ def patch_model(
...
@@ -165,7 +165,7 @@ def patch_model(
try
:
try
:
model
.
add_model_tags
([
"llama-factory"
])
model
.
add_model_tags
([
"llama-factory"
])
except
Exception
:
except
Exception
:
logger
.
warning
(
"Cannot properly tag the model."
)
logger
.
warning
_rank0
(
"Cannot properly tag the model."
)
def
patch_valuehead_model
(
model
:
"AutoModelForCausalLMWithValueHead"
)
->
None
:
def
patch_valuehead_model
(
model
:
"AutoModelForCausalLMWithValueHead"
)
->
None
:
...
...
src/llamafactory/train/callbacks.py
View file @
2778a3d0
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
json
import
json
import
logging
import
os
import
os
import
signal
import
signal
import
sys
import
sys
...
@@ -34,8 +33,8 @@ from transformers.utils import (
...
@@ -34,8 +33,8 @@ from transformers.utils import (
)
)
from
typing_extensions
import
override
from
typing_extensions
import
override
from
..extras
import
logging
from
..extras.constants
import
TRAINER_LOG
,
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..extras.constants
import
TRAINER_LOG
,
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..extras.logging
import
LoggerHandler
,
get_logger
from
..extras.misc
import
get_peak_memory
from
..extras.misc
import
get_peak_memory
...
@@ -48,7 +47,7 @@ if TYPE_CHECKING:
...
@@ -48,7 +47,7 @@ if TYPE_CHECKING:
from
trl
import
AutoModelForCausalLMWithValueHead
from
trl
import
AutoModelForCausalLMWithValueHead
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
def
fix_valuehead_checkpoint
(
def
fix_valuehead_checkpoint
(
...
@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
...
@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
else
:
else
:
torch
.
save
(
v_head_state_dict
,
os
.
path
.
join
(
output_dir
,
V_HEAD_WEIGHTS_NAME
))
torch
.
save
(
v_head_state_dict
,
os
.
path
.
join
(
output_dir
,
V_HEAD_WEIGHTS_NAME
))
logger
.
info
(
"Value head model saved at: {
}"
.
format
(
output_dir
)
)
logger
.
info
_rank0
(
f
"Value head model saved at:
{
output_dir
}
"
)
class
FixValueHeadModelCallback
(
TrainerCallback
):
class
FixValueHeadModelCallback
(
TrainerCallback
):
...
@@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
...
@@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
Event called after a checkpoint save.
Event called after a checkpoint save.
"""
"""
if
args
.
should_save
:
if
args
.
should_save
:
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"{
}-{}"
.
format
(
PREFIX_CHECKPOINT_DIR
,
state
.
global_step
)
)
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
f
"
{
PREFIX_CHECKPOINT_DIR
}
-
{
state
.
global_step
}
"
)
fix_valuehead_checkpoint
(
fix_valuehead_checkpoint
(
model
=
kwargs
.
pop
(
"model"
),
output_dir
=
output_dir
,
safe_serialization
=
args
.
save_safetensors
model
=
kwargs
.
pop
(
"model"
),
output_dir
=
output_dir
,
safe_serialization
=
args
.
save_safetensors
)
)
...
@@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
...
@@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
@
override
@
override
def
on_save
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_save
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
args
.
should_save
:
if
args
.
should_save
:
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"{
}-{}"
.
format
(
PREFIX_CHECKPOINT_DIR
,
state
.
global_step
)
)
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
f
"
{
PREFIX_CHECKPOINT_DIR
}
-
{
state
.
global_step
}
"
)
getattr
(
self
.
processor
,
"image_processor"
)
.
save_pretrained
(
output_dir
)
self
.
processor
.
save_pretrained
(
output_dir
)
@
override
@
override
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
args
.
should_save
:
if
args
.
should_save
:
getattr
(
self
.
processor
,
"image_processor"
)
.
save_pretrained
(
args
.
output_dir
)
self
.
processor
.
save_pretrained
(
args
.
output_dir
)
class
PissaConvertCallback
(
TrainerCallback
):
class
PissaConvertCallback
(
TrainerCallback
):
...
@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
...
@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
if
args
.
should_save
:
if
args
.
should_save
:
model
=
kwargs
.
pop
(
"model"
)
model
=
kwargs
.
pop
(
"model"
)
pissa_init_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_init"
)
pissa_init_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_init"
)
logger
.
info
(
"Initial PiSSA adapter will be saved at: {
}."
.
format
(
pissa_init_dir
)
)
logger
.
info
_rank0
(
f
"Initial PiSSA adapter will be saved at:
{
pissa_init_dir
}
."
)
if
isinstance
(
model
,
PeftModel
):
if
isinstance
(
model
,
PeftModel
):
init_lora_weights
=
getattr
(
model
.
peft_config
[
"default"
],
"init_lora_weights"
)
init_lora_weights
=
getattr
(
model
.
peft_config
[
"default"
],
"init_lora_weights"
)
setattr
(
model
.
peft_config
[
"default"
],
"init_lora_weights"
,
True
)
setattr
(
model
.
peft_config
[
"default"
],
"init_lora_weights"
,
True
)
...
@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
...
@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
pissa_init_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_init"
)
pissa_init_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_init"
)
pissa_backup_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_backup"
)
pissa_backup_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_backup"
)
pissa_convert_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_converted"
)
pissa_convert_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_converted"
)
logger
.
info
(
"Converted PiSSA adapter will be saved at: {
}."
.
format
(
pissa_convert_dir
)
)
logger
.
info
_rank0
(
f
"Converted PiSSA adapter will be saved at:
{
pissa_convert_dir
}
."
)
# 1. save a pissa backup with init_lora_weights: True
# 1. save a pissa backup with init_lora_weights: True
# 2. save a converted lora with init_lora_weights: pissa
# 2. save a converted lora with init_lora_weights: pissa
# 3. load the pissa backup with init_lora_weights: True
# 3. load the pissa backup with init_lora_weights: True
...
@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
...
@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
self
.
webui_mode
=
os
.
environ
.
get
(
"LLAMABOARD_ENABLED"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]
self
.
webui_mode
=
os
.
environ
.
get
(
"LLAMABOARD_ENABLED"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]
if
self
.
webui_mode
:
if
self
.
webui_mode
:
signal
.
signal
(
signal
.
SIGABRT
,
self
.
_set_abort
)
signal
.
signal
(
signal
.
SIGABRT
,
self
.
_set_abort
)
self
.
logger_handler
=
LoggerHandler
(
os
.
environ
.
get
(
"LLAMABOARD_WORKDIR"
))
self
.
logger_handler
=
logging
.
LoggerHandler
(
os
.
environ
.
get
(
"LLAMABOARD_WORKDIR"
))
logging
.
root
.
add
H
andler
(
self
.
logger_handler
)
logging
.
add
_h
andler
(
self
.
logger_handler
)
transformers
.
logging
.
add_handler
(
self
.
logger_handler
)
transformers
.
logging
.
add_handler
(
self
.
logger_handler
)
def
_set_abort
(
self
,
signum
,
frame
)
->
None
:
def
_set_abort
(
self
,
signum
,
frame
)
->
None
:
...
@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
...
@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
and
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
and
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
and
args
.
overwrite_output_dir
and
args
.
overwrite_output_dir
):
):
logger
.
warning
(
"Previous trainer log in this folder will be deleted."
)
logger
.
warning
_once
(
"Previous trainer log in this folder will be deleted."
)
os
.
remove
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
os
.
remove
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
@
override
@
override
...
@@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
...
@@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
logs
=
dict
(
logs
=
dict
(
current_steps
=
self
.
cur_steps
,
current_steps
=
self
.
cur_steps
,
total_steps
=
self
.
max_steps
,
total_steps
=
self
.
max_steps
,
loss
=
state
.
log_history
[
-
1
].
get
(
"loss"
,
None
),
loss
=
state
.
log_history
[
-
1
].
get
(
"loss"
),
eval_loss
=
state
.
log_history
[
-
1
].
get
(
"eval_loss"
,
None
),
eval_loss
=
state
.
log_history
[
-
1
].
get
(
"eval_loss"
),
predict_loss
=
state
.
log_history
[
-
1
].
get
(
"predict_loss"
,
None
),
predict_loss
=
state
.
log_history
[
-
1
].
get
(
"predict_loss"
),
reward
=
state
.
log_history
[
-
1
].
get
(
"reward"
,
None
),
reward
=
state
.
log_history
[
-
1
].
get
(
"reward"
),
accuracy
=
state
.
log_history
[
-
1
].
get
(
"rewards/accuracies"
,
None
),
accuracy
=
state
.
log_history
[
-
1
].
get
(
"rewards/accuracies"
),
l
earning_rate
=
state
.
log_history
[
-
1
].
get
(
"learning_rate"
,
None
),
l
r
=
state
.
log_history
[
-
1
].
get
(
"learning_rate"
),
epoch
=
state
.
log_history
[
-
1
].
get
(
"epoch"
,
None
),
epoch
=
state
.
log_history
[
-
1
].
get
(
"epoch"
),
percentage
=
round
(
self
.
cur_steps
/
self
.
max_steps
*
100
,
2
)
if
self
.
max_steps
!=
0
else
100
,
percentage
=
round
(
self
.
cur_steps
/
self
.
max_steps
*
100
,
2
)
if
self
.
max_steps
!=
0
else
100
,
elapsed_time
=
self
.
elapsed_time
,
elapsed_time
=
self
.
elapsed_time
,
remaining_time
=
self
.
remaining_time
,
remaining_time
=
self
.
remaining_time
,
...
@@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
...
@@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
if
os
.
environ
.
get
(
"RECORD_VRAM"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]:
if
os
.
environ
.
get
(
"RECORD_VRAM"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]:
vram_allocated
,
vram_reserved
=
get_peak_memory
()
vram_allocated
,
vram_reserved
=
get_peak_memory
()
logs
[
"vram_allocated"
]
=
round
(
vram_allocated
/
1024
/
1024
/
1024
,
2
)
logs
[
"vram_allocated"
]
=
round
(
vram_allocated
/
(
1024
**
3
)
,
2
)
logs
[
"vram_reserved"
]
=
round
(
vram_reserved
/
1024
/
1024
/
1024
,
2
)
logs
[
"vram_reserved"
]
=
round
(
vram_reserved
/
(
1024
**
3
)
,
2
)
logs
=
{
k
:
v
for
k
,
v
in
logs
.
items
()
if
v
is
not
None
}
logs
=
{
k
:
v
for
k
,
v
in
logs
.
items
()
if
v
is
not
None
}
if
self
.
webui_mode
and
all
(
key
in
logs
for
key
in
[
"loss"
,
"learning_rate"
,
"epoch"
]):
if
self
.
webui_mode
and
all
(
key
in
logs
for
key
in
(
"loss"
,
"lr"
,
"epoch"
)):
logger
.
info
(
log_str
=
f
"'loss':
{
logs
[
'loss'
]:.
4
f
}
, 'learning_rate':
{
logs
[
'lr'
]:
2.4
e
}
, 'epoch':
{
logs
[
'epoch'
]:.
2
f
}
"
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}"
.
format
(
for
extra_key
in
(
"reward"
,
"accuracy"
,
"throughput"
):
logs
[
"loss"
],
logs
[
"learning_rate"
],
logs
[
"epoch"
],
logs
.
get
(
"throughput"
,
"N/A"
)
if
logs
.
get
(
extra_key
):
)
log_str
+=
f
", '
{
extra_key
}
':
{
logs
[
extra_key
]:.
2
f
}
"
)
logger
.
info_rank0
(
"{"
+
log_str
+
"}"
)
if
self
.
thread_pool
is
not
None
:
if
self
.
thread_pool
is
not
None
:
self
.
thread_pool
.
submit
(
self
.
_write_log
,
args
.
output_dir
,
logs
)
self
.
thread_pool
.
submit
(
self
.
_write_log
,
args
.
output_dir
,
logs
)
...
...
src/llamafactory/train/dpo/trainer.py
View file @
2778a3d0
...
@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
...
@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
from
typing_extensions
import
override
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.packages
import
is_transformers_version_equal_to_4_46
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
,
get_batch_logps
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
,
get_batch_logps
...
@@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
self
.
callback_handler
.
add_callback
(
PissaConvertCallback
)
self
.
callback_handler
.
add_callback
(
PissaConvertCallback
)
if
finetuning_args
.
use_badam
:
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
...
@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
get_batch_samples
(
self
,
epoch_iterator
,
num_batches
):
r
"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return
Trainer
.
get_batch_samples
(
self
,
epoch_iterator
,
num_batches
)
def
odds_ratio_loss
(
self
,
chosen_logps
:
"torch.Tensor"
,
rejected_logps
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
odds_ratio_loss
(
self
,
chosen_logps
:
"torch.Tensor"
,
rejected_logps
:
"torch.Tensor"
)
->
"torch.Tensor"
:
r
"""
r
"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
...
@@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
elif
self
.
loss_type
==
"simpo"
:
elif
self
.
loss_type
==
"simpo"
:
losses
=
self
.
simpo_loss
(
policy_chosen_logps
,
policy_rejected_logps
)
losses
=
self
.
simpo_loss
(
policy_chosen_logps
,
policy_rejected_logps
)
else
:
else
:
raise
NotImplementedError
(
"Unknown loss type: {
}."
.
format
(
self
.
loss_type
)
)
raise
NotImplementedError
(
f
"Unknown loss type:
{
self
.
loss_type
}
."
)
chosen_rewards
=
self
.
beta
*
policy_chosen_logps
.
to
(
self
.
accelerator
.
device
).
detach
()
chosen_rewards
=
self
.
beta
*
policy_chosen_logps
.
to
(
self
.
accelerator
.
device
).
detach
()
rejected_rewards
=
self
.
beta
*
policy_rejected_logps
.
to
(
self
.
accelerator
.
device
).
detach
()
rejected_rewards
=
self
.
beta
*
policy_rejected_logps
.
to
(
self
.
accelerator
.
device
).
detach
()
...
@@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
if
self
.
ftx_gamma
>
1e-6
:
if
self
.
ftx_gamma
>
1e-6
:
losses
+=
self
.
ftx_gamma
*
sft_loss
losses
+=
self
.
ftx_gamma
*
sft_loss
reward_accuracies
=
(
chosen_rewards
>
rejected_rewards
).
float
()
prefix
=
"eval_"
if
train_eval
==
"eval"
else
""
prefix
=
"eval_"
if
train_eval
==
"eval"
else
""
metrics
[
"{}rewards/chosen"
.
format
(
prefix
)
]
=
chosen_rewards
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
rewards/chosen"
]
=
chosen_rewards
.
mean
().
item
()
metrics
[
"{}rewards/rejected"
.
format
(
prefix
)
]
=
rejected_rewards
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
rewards/rejected"
]
=
rejected_rewards
.
mean
().
item
()
metrics
[
"{}rewards/accuracies"
.
format
(
prefix
)]
=
reward_accuracies
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
rewards/accuracies"
]
=
(
chosen_rewards
>
rejected_rewards
).
float
()
.
mean
().
item
()
metrics
[
"{}rewards/margins"
.
format
(
prefix
)
]
=
(
chosen_rewards
-
rejected_rewards
).
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
rewards/margins"
]
=
(
chosen_rewards
-
rejected_rewards
).
mean
().
item
()
metrics
[
"{}logps/
rejected"
.
format
(
prefix
)]
=
policy_rejected_logps
.
detach
()
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
logps/
chosen"
]
=
policy_chosen_logps
.
mean
().
item
()
metrics
[
"{}logps/
chosen"
.
format
(
prefix
)]
=
policy_chosen_logps
.
detach
()
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
logps/
rejected"
]
=
policy_rejected_logps
.
mean
().
item
()
metrics
[
"{}logits/
rejected"
.
format
(
prefix
)]
=
policy_rejected_logits
.
detach
()
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
logits/
chosen"
]
=
policy_chosen_logits
.
mean
().
item
()
metrics
[
"{}logits/
chosen"
.
format
(
prefix
)]
=
policy_chosen_logits
.
detach
()
.
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
logits/
rejected"
]
=
policy_rejected_logits
.
mean
().
item
()
if
self
.
loss_type
==
"orpo"
:
if
self
.
loss_type
==
"orpo"
:
metrics
[
"{}sft_loss"
.
format
(
prefix
)
]
=
sft_loss
.
detach
().
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
sft_loss"
]
=
sft_loss
.
mean
().
item
()
metrics
[
"{}odds_ratio_loss"
.
format
(
prefix
)
]
=
((
losses
-
sft_loss
)
/
self
.
beta
).
detach
().
mean
().
cpu
()
metrics
[
f
"
{
prefix
}
odds_ratio_loss"
]
=
((
losses
-
sft_loss
)
/
self
.
beta
).
mean
().
item
()
return
losses
.
mean
(),
metrics
return
losses
.
mean
(),
metrics
@
override
def
compute_loss
(
self
,
model
,
inputs
,
return_outputs
=
False
,
**
kwargs
):
r
"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss
=
super
().
compute_loss
(
model
,
inputs
,
return_outputs
)
if
is_transformers_version_equal_to_4_46
()
and
kwargs
.
pop
(
"num_items_in_batch"
,
False
):
if
return_outputs
:
return
(
loss
[
0
]
/
self
.
args
.
gradient_accumulation_steps
,
*
loss
[
1
:])
else
:
return
loss
/
self
.
args
.
gradient_accumulation_steps
return
loss
@
override
def
log
(
self
,
logs
:
Dict
[
str
,
float
])
->
None
:
r
"""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss"
train_eval
=
"train"
if
"loss"
in
logs
else
"eval"
# Add averaged stored metrics to logs
key_list
,
metric_list
=
[],
[]
for
key
,
metrics
in
self
.
_stored_metrics
[
train_eval
].
items
():
key_list
.
append
(
key
)
metric_list
.
append
(
torch
.
tensor
(
metrics
,
dtype
=
torch
.
float
).
to
(
self
.
accelerator
.
device
).
mean
().
item
())
del
self
.
_stored_metrics
[
train_eval
]
if
len
(
metric_list
)
<
10
:
# pad to for all reduce
for
i
in
range
(
10
-
len
(
metric_list
)):
key_list
.
append
(
f
"dummy_
{
i
}
"
)
metric_list
.
append
(
0.0
)
metric_list
=
torch
.
tensor
(
metric_list
,
dtype
=
torch
.
float
).
to
(
self
.
accelerator
.
device
)
metric_list
=
self
.
accelerator
.
reduce
(
metric_list
,
"mean"
).
tolist
()
for
key
,
metric
in
zip
(
key_list
,
metric_list
):
# add remaining items
if
not
key
.
startswith
(
"dummy_"
):
logs
[
key
]
=
metric
return
Trainer
.
log
(
self
,
logs
)
src/llamafactory/train/dpo/workflow.py
View file @
2778a3d0
...
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
...
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.misc
import
cal_effective_tokens
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
from
...model
import
load_model
,
load_tokenizer
from
...model
import
load_model
,
load_tokenizer
...
@@ -64,6 +65,12 @@ def run_dpo(
...
@@ -64,6 +65,12 @@ def run_dpo(
# Update arguments
# Update arguments
training_args
.
remove_unused_columns
=
False
# important for multimodal and pairwise dataset
training_args
.
remove_unused_columns
=
False
# important for multimodal and pairwise dataset
effective_token_num
=
0.0
if
finetuning_args
.
include_effective_tokens_per_second
:
for
data
in
dataset_module
[
"train_dataset"
]:
effective_token_num
+=
len
(
data
[
"chosen_input_ids"
])
effective_token_num
+=
len
(
data
[
"rejected_input_ids"
])
# Initialize our Trainer
# Initialize our Trainer
trainer
=
CustomDPOTrainer
(
trainer
=
CustomDPOTrainer
(
model
=
model
,
model
=
model
,
...
@@ -79,6 +86,12 @@ def run_dpo(
...
@@ -79,6 +86,12 @@ def run_dpo(
# Training
# Training
if
training_args
.
do_train
:
if
training_args
.
do_train
:
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
training_args
.
resume_from_checkpoint
)
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
training_args
.
resume_from_checkpoint
)
if
finetuning_args
.
include_effective_tokens_per_second
:
train_result
.
metrics
[
"effective_tokens_per_sec"
]
=
cal_effective_tokens
(
effective_token_num
,
train_result
.
metrics
[
"epoch"
],
train_result
.
metrics
[
"train_runtime"
]
)
trainer
.
save_model
()
trainer
.
save_model
()
trainer
.
log_metrics
(
"train"
,
train_result
.
metrics
)
trainer
.
log_metrics
(
"train"
,
train_result
.
metrics
)
trainer
.
save_metrics
(
"train"
,
train_result
.
metrics
)
trainer
.
save_metrics
(
"train"
,
train_result
.
metrics
)
...
...
src/llamafactory/train/kto/trainer.py
View file @
2778a3d0
...
@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
...
@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
from
typing_extensions
import
override
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.packages
import
is_transformers_version_equal_to_4_46
from
..callbacks
import
SaveProcessorCallback
from
..callbacks
import
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
,
get_batch_logps
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
,
get_batch_logps
...
@@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
self
.
add_callback
(
SaveProcessorCallback
(
processor
))
self
.
add_callback
(
SaveProcessorCallback
(
processor
))
if
finetuning_args
.
use_badam
:
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
...
@@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
"""
"""
return
Trainer
.
_get_train_sampler
(
self
)
return
Trainer
.
_get_train_sampler
(
self
)
@
override
def
get_batch_samples
(
self
,
epoch_iterator
,
num_batches
):
r
"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return
Trainer
.
get_batch_samples
(
self
,
epoch_iterator
,
num_batches
)
@
override
@
override
def
forward
(
def
forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
],
prefix
:
Literal
[
""
,
"kl_"
]
=
""
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
],
prefix
:
Literal
[
""
,
"kl_"
]
=
""
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""
r
"""
Runs forward pass and computes the log probabilities.
Runs forward pass and computes the log probabilities.
"""
"""
batch
=
{
k
:
v
.
detach
().
clone
()
for
k
,
v
in
batch
.
items
()}
# avoid error
batch
=
{
k
:
v
.
detach
().
clone
()
for
k
,
v
in
batch
.
items
()}
# avoid error
model_inputs
=
{
model_inputs
=
{
"input_ids"
:
batch
[
"{}input_ids"
.
format
(
prefix
)
],
"input_ids"
:
batch
[
f
"
{
prefix
}
input_ids"
],
"attention_mask"
:
batch
[
"{}attention_mask"
.
format
(
prefix
)
],
"attention_mask"
:
batch
[
f
"
{
prefix
}
attention_mask"
],
}
}
if
"{}token_type_ids"
.
format
(
prefix
)
in
batch
:
if
f
"
{
prefix
}
token_type_ids"
in
batch
:
model_inputs
[
"token_type_ids"
]
=
batch
[
"{}token_type_ids"
.
format
(
prefix
)
]
model_inputs
[
"token_type_ids"
]
=
batch
[
f
"
{
prefix
}
token_type_ids"
]
if
"pixel_values"
in
batch
:
if
"pixel_values"
in
batch
:
model_inputs
[
"pixel_values"
]
=
batch
[
"pixel_values"
]
model_inputs
[
"pixel_values"
]
=
batch
[
"pixel_values"
]
...
@@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
model_inputs
[
"image_grid_thw"
]
=
batch
[
"image_grid_thw"
]
model_inputs
[
"image_grid_thw"
]
=
batch
[
"image_grid_thw"
]
logits
=
model
(
**
model_inputs
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
logits
=
model
(
**
model_inputs
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
logps
,
valid_length
=
get_batch_logps
(
logits
=
logits
,
labels
=
batch
[
"{}labels"
.
format
(
prefix
)
])
logps
,
valid_length
=
get_batch_logps
(
logits
=
logits
,
labels
=
batch
[
f
"
{
prefix
}
labels"
])
return
logps
,
logps
/
valid_length
return
logits
,
logps
,
logps
/
valid_length
@
override
@
override
def
concatenated_forward
(
def
concatenated_forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
target_logps
,
target_logps_avg
=
self
.
forward
(
model
,
batch
)
target_logits
,
target_logps
,
target_logps_avg
=
self
.
forward
(
model
,
batch
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
kl_logps
,
_
=
self
.
forward
(
model
,
batch
,
prefix
=
"kl_"
)
_
,
kl_logps
,
_
=
self
.
forward
(
model
,
batch
,
prefix
=
"kl_"
)
if
len
(
target_logps
)
!=
len
(
batch
[
"kto_tags"
]):
if
len
(
target_logps
)
!=
len
(
batch
[
"kto_tags"
]):
raise
ValueError
(
"Mismatched shape of inputs and labels."
)
raise
ValueError
(
"Mismatched shape of inputs and labels."
)
chosen_logits
=
target_logits
[
batch
[
"kto_tags"
]]
chosen_logps
=
target_logps
[
batch
[
"kto_tags"
]]
chosen_logps
=
target_logps
[
batch
[
"kto_tags"
]]
rejected_logits
=
target_logits
[
~
batch
[
"kto_tags"
]]
rejected_logps
=
target_logps
[
~
batch
[
"kto_tags"
]]
rejected_logps
=
target_logps
[
~
batch
[
"kto_tags"
]]
chosen_logps_avg
=
target_logps_avg
[
batch
[
"kto_tags"
]]
chosen_logps_avg
=
target_logps_avg
[
batch
[
"kto_tags"
]]
return
chosen_logps
,
rejected_logps
,
kl_logps
,
chosen_logps_avg
return
chosen_logps
,
rejected_logps
,
chosen_logits
,
rejected_logits
,
kl_logps
,
chosen_logps_avg
@
override
@
override
def
compute_reference_log_probs
(
def
compute_reference_log_probs
(
...
@@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
ref_context
=
nullcontext
()
ref_context
=
nullcontext
()
with
torch
.
no_grad
(),
ref_context
:
with
torch
.
no_grad
(),
ref_context
:
reference_chosen_logps
,
reference_rejected_logps
,
reference_kl_logps
,
_
=
self
.
concatenated_forward
(
reference_chosen_logps
,
reference_rejected_logps
,
_
,
_
,
reference_kl_logps
,
_
=
self
.
concatenated_forward
(
ref_model
,
batch
ref_model
,
batch
)
)
...
@@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
"""
metrics
=
{}
metrics
=
{}
policy_chosen_logps
,
policy_rejected_logps
,
policy_kl_logps
,
policy_chosen_logps_avg
=
(
(
self
.
concatenated_forward
(
model
,
batch
)
policy_chosen_logps
,
)
policy_rejected_logps
,
policy_chosen_logits
,
policy_rejected_logits
,
policy_kl_logps
,
policy_chosen_logps_avg
,
)
=
self
.
concatenated_forward
(
model
,
batch
)
reference_chosen_logps
,
reference_rejected_logps
,
reference_kl_logps
=
self
.
compute_reference_log_probs
(
reference_chosen_logps
,
reference_rejected_logps
,
reference_kl_logps
=
self
.
compute_reference_log_probs
(
model
,
batch
model
,
batch
)
)
...
@@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
sft_loss
=
-
policy_chosen_logps_avg
sft_loss
=
-
policy_chosen_logps_avg
losses
+=
self
.
ftx_gamma
*
sft_loss
.
nanmean
()
/
len
(
policy_chosen_logps
)
*
len
(
batch
[
"labels"
])
losses
+=
self
.
ftx_gamma
*
sft_loss
.
nanmean
()
/
len
(
policy_chosen_logps
)
*
len
(
batch
[
"labels"
])
num_chosen
=
torch
.
Tensor
([
len
(
chosen_rewards
)]).
to
(
self
.
accelerator
.
device
)
num_chosen
=
len
(
chosen_rewards
)
num_rejected
=
torch
.
Tensor
([
len
(
rejected_rewards
)]).
to
(
self
.
accelerator
.
device
)
num_rejected
=
len
(
rejected_rewards
)
if
num_chosen
>
0
:
metrics
[
"rewards/chosen_sum"
]
=
chosen_rewards
.
nansum
().
item
()
metrics
[
"logps/chosen_sum"
]
=
policy_chosen_logps
.
nansum
().
item
()
metrics
[
"logits/chosen_sum"
]
=
policy_chosen_logits
.
nansum
().
item
()
metrics
[
"count/chosen"
]
=
float
(
num_chosen
)
all_num_chosen
=
self
.
accelerator
.
gather
(
num_chosen
).
sum
().
item
()
if
num_rejected
>
0
:
all_num_rejected
=
self
.
accelerator
.
gather
(
num_rejected
).
sum
().
item
()
metrics
[
"rewards/rejected_sum"
]
=
rejected_rewards
.
nansum
().
item
()
metrics
[
"logps/rejected_sum"
]
=
policy_rejected_logps
.
nansum
().
item
()
metrics
[
"logits/rejected_sum"
]
=
policy_rejected_logits
.
nansum
().
item
()
metrics
[
"count/rejected"
]
=
float
(
num_rejected
)
if
all_num_chosen
>
0
:
metrics
[
"kl"
]
=
kl
.
item
()
metrics
[
"rewards/chosen_sum"
]
=
self
.
accelerator
.
gather
(
chosen_rewards
.
nansum
()).
nansum
().
item
()
return
losses
,
metrics
metrics
[
"logps/chosen_sum"
]
=
self
.
accelerator
.
gather
(
policy_chosen_logps
.
nansum
()).
nansum
().
item
()
metrics
[
"count/chosen"
]
=
all_num_chosen
if
all_num_rejected
>
0
:
@
override
metrics
[
"rewards/rejected_sum"
]
=
self
.
accelerator
.
gather
(
rejected_rewards
.
nansum
()).
nansum
().
item
()
def
compute_loss
(
self
,
model
,
inputs
,
return_outputs
=
False
,
**
kwargs
):
metrics
[
"logps/rejected_sum"
]
=
self
.
accelerator
.
gather
(
policy_rejected_logps
.
nansum
()).
nansum
().
item
()
r
"""
metrics
[
"count/rejected"
]
=
all_num_rejected
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss
=
super
().
compute_loss
(
model
,
inputs
,
return_outputs
)
if
is_transformers_version_equal_to_4_46
()
and
kwargs
.
pop
(
"num_items_in_batch"
,
False
):
if
return_outputs
:
return
(
loss
[
0
]
/
self
.
args
.
gradient_accumulation_steps
,
*
loss
[
1
:])
else
:
return
loss
/
self
.
args
.
gradient_accumulation_steps
m
et
rics
[
"kl"
]
=
kl
.
item
()
r
et
urn
loss
return
losses
,
metrics
@
override
def
log
(
self
,
logs
:
Dict
[
str
,
float
])
->
None
:
r
"""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss"
train_eval
=
"train"
if
"loss"
in
logs
else
"eval"
prefix
=
"eval_"
if
train_eval
==
"eval"
else
""
# Add averaged stored metrics to logs
key_list
,
metric_list
=
[],
[]
for
key
,
metrics
in
self
.
_stored_metrics
[
train_eval
].
items
():
key_list
.
append
(
key
)
metric_list
.
append
(
torch
.
tensor
(
metrics
,
dtype
=
torch
.
float
).
to
(
self
.
accelerator
.
device
).
sum
().
item
())
del
self
.
_stored_metrics
[
train_eval
]
if
len
(
metric_list
)
<
9
:
# pad to for all reduce
for
i
in
range
(
9
-
len
(
metric_list
)):
key_list
.
append
(
f
"dummy_
{
i
}
"
)
metric_list
.
append
(
0.0
)
metric_list
=
torch
.
tensor
(
metric_list
,
dtype
=
torch
.
float
).
to
(
self
.
accelerator
.
device
)
metric_list
=
self
.
accelerator
.
reduce
(
metric_list
,
"sum"
).
tolist
()
metric_dict
:
Dict
[
str
,
float
]
=
dict
(
zip
(
key_list
,
metric_list
))
for
split
in
[
"chosen"
,
"rejected"
]:
# accumulate average metrics from sums and lengths
if
f
"count/
{
split
}
"
in
metric_dict
:
for
key
in
(
"rewards"
,
"logps"
,
"logits"
):
logs
[
f
"
{
prefix
}{
key
}
/
{
split
}
"
]
=
metric_dict
[
f
"
{
key
}
/
{
split
}
_sum"
]
/
metric_dict
[
f
"count/
{
split
}
"
]
del
metric_dict
[
f
"
{
key
}
/
{
split
}
_sum"
]
del
metric_dict
[
f
"count/
{
split
}
"
]
if
f
"
{
prefix
}
rewards/chosen"
in
logs
and
f
"
{
prefix
}
rewards/rejected"
in
logs
:
# calculate reward margin
logs
[
f
"
{
prefix
}
rewards/margins"
]
=
logs
[
f
"
{
prefix
}
rewards/chosen"
]
-
logs
[
f
"
{
prefix
}
rewards/rejected"
]
for
key
,
metric
in
metric_dict
.
items
():
# add remaining items
if
not
key
.
startswith
(
"dummy_"
):
logs
[
key
]
=
metric
return
Trainer
.
log
(
self
,
logs
)
src/llamafactory/train/kto/workflow.py
View file @
2778a3d0
...
@@ -81,7 +81,7 @@ def run_kto(
...
@@ -81,7 +81,7 @@ def run_kto(
trainer
.
save_metrics
(
"train"
,
train_result
.
metrics
)
trainer
.
save_metrics
(
"train"
,
train_result
.
metrics
)
trainer
.
save_state
()
trainer
.
save_state
()
if
trainer
.
is_world_process_zero
()
and
finetuning_args
.
plot_loss
:
if
trainer
.
is_world_process_zero
()
and
finetuning_args
.
plot_loss
:
plot_loss
(
training_args
.
output_dir
,
keys
=
[
"loss"
,
"eval_loss"
,
"
train/
rewards/chosen"
])
plot_loss
(
training_args
.
output_dir
,
keys
=
[
"loss"
,
"eval_loss"
,
"rewards/chosen"
])
# Evaluation
# Evaluation
if
training_args
.
do_eval
:
if
training_args
.
do_eval
:
...
...
src/llamafactory/train/ppo/ppo_utils.py
View file @
2778a3d0
...
@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
...
@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
setattr
(
model
,
"default_head_bias"
,
v_head_layer
.
bias
.
data
.
detach
().
clone
())
setattr
(
model
,
"default_head_bias"
,
v_head_layer
.
bias
.
data
.
detach
().
clone
())
device
=
v_head_layer
.
weight
.
device
device
=
v_head_layer
.
weight
.
device
v_head_layer
.
weight
.
data
=
model
.
get_buffer
(
"{}_head_weight"
.
format
(
target
)
).
detach
().
clone
().
to
(
device
)
v_head_layer
.
weight
.
data
=
model
.
get_buffer
(
f
"
{
target
}
_head_weight"
).
detach
().
clone
().
to
(
device
)
v_head_layer
.
bias
.
data
=
model
.
get_buffer
(
"{}_head_bias"
.
format
(
target
)
).
detach
().
clone
().
to
(
device
)
v_head_layer
.
bias
.
data
=
model
.
get_buffer
(
f
"
{
target
}
_head_bias"
).
detach
().
clone
().
to
(
device
)
def
dump_layernorm
(
model
:
"PreTrainedModel"
)
->
Dict
[
str
,
"torch.Tensor"
]:
def
dump_layernorm
(
model
:
"PreTrainedModel"
)
->
Dict
[
str
,
"torch.Tensor"
]:
...
...
src/llamafactory/train/ppo/trainer.py
View file @
2778a3d0
...
@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
...
@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from
trl.models.utils
import
unwrap_model_for_generation
from
trl.models.utils
import
unwrap_model_for_generation
from
typing_extensions
import
override
from
typing_extensions
import
override
from
...extras
.logging
import
get_
logg
er
from
...extras
import
logg
ing
from
...extras.misc
import
AverageMeter
,
count_parameters
,
get_current_device
,
get_logits_processor
from
...extras.misc
import
AverageMeter
,
count_parameters
,
get_current_device
,
get_logits_processor
from
..callbacks
import
FixValueHeadModelCallback
,
SaveProcessorCallback
from
..callbacks
import
FixValueHeadModelCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
...
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
...
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
from
...hparams
import
FinetuningArguments
,
GeneratingArguments
,
ModelArguments
from
...hparams
import
FinetuningArguments
,
GeneratingArguments
,
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
class
CustomPPOTrainer
(
PPOTrainer
,
Trainer
):
class
CustomPPOTrainer
(
PPOTrainer
,
Trainer
):
...
@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
]
]
ppo_config
.
accelerator_kwargs
[
"deepspeed_plugin"
]
=
training_args
.
deepspeed_plugin
ppo_config
.
accelerator_kwargs
[
"deepspeed_plugin"
]
=
training_args
.
deepspeed_plugin
if
ppo_config
.
log_with
is
not
None
:
if
ppo_config
.
log_with
is
not
None
:
logger
.
warning
(
"PPOTrainer cannot use external logger when DeepSpeed is enabled."
)
logger
.
warning
_rank0
(
"PPOTrainer cannot use external logger when DeepSpeed is enabled."
)
ppo_config
.
log_with
=
None
ppo_config
.
log_with
=
None
# Create optimizer and scheduler
# Create optimizer and scheduler
...
@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
callbacks
,
self
.
accelerator
.
unwrap_model
(
self
.
model
),
self
.
tokenizer
,
self
.
optimizer
,
self
.
lr_scheduler
callbacks
,
self
.
accelerator
.
unwrap_model
(
self
.
model
),
self
.
tokenizer
,
self
.
optimizer
,
self
.
lr_scheduler
)
)
if
self
.
args
.
max_steps
>
0
:
if
self
.
args
.
max_steps
>
0
:
logger
.
info
(
"max_steps is given, it will override any value given in num_train_epochs"
)
logger
.
info
_rank0
(
"max_steps is given, it will override any value given in num_train_epochs"
)
self
.
amp_context
=
torch
.
autocast
(
self
.
current_device
.
type
)
self
.
amp_context
=
torch
.
autocast
(
self
.
current_device
.
type
)
warnings
.
simplefilter
(
"ignore"
)
# remove gc warnings on ref model
warnings
.
simplefilter
(
"ignore"
)
# remove gc warnings on ref model
...
@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
add_callback
(
SaveProcessorCallback
(
processor
))
self
.
add_callback
(
SaveProcessorCallback
(
processor
))
if
finetuning_args
.
use_badam
:
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
...
@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
state
.
is_local_process_zero
=
self
.
is_local_process_zero
()
self
.
state
.
is_local_process_zero
=
self
.
is_local_process_zero
()
self
.
state
.
is_world_process_zero
=
self
.
is_world_process_zero
()
self
.
state
.
is_world_process_zero
=
self
.
is_world_process_zero
()
if
self
.
is_world_process_zero
():
logger
.
info_rank0
(
"***** Running training *****"
)
logger
.
info
(
"***** Running training *****"
)
logger
.
info_rank0
(
f
" Num examples =
{
num_examples
:,
}
"
)
logger
.
info
(
" Num examples = {:,}"
.
format
(
num_examples
))
logger
.
info_rank0
(
f
" Num Epochs =
{
num_train_epochs
:,
}
"
)
logger
.
info
(
" Num Epochs = {:,}"
.
format
(
num_train_epochs
))
logger
.
info_rank0
(
f
" Instantaneous batch size per device =
{
self
.
args
.
per_device_train_batch_size
:,
}
"
)
logger
.
info
(
" Instantaneous batch size per device = {:,}"
.
format
(
self
.
args
.
per_device_train_batch_size
))
logger
.
info_rank0
(
logger
.
info
(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}"
.
format
(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}"
.
format
(
total_train_batch_size
total_train_batch_size
)
)
)
logger
.
info
(
" Gradient Accumulation steps = {:,}"
.
format
(
self
.
args
.
gradient_accumulation_steps
))
)
logger
.
info
(
" Num optimization epochs per batch = {:,}"
.
format
(
self
.
finetuning_args
.
ppo_epochs
))
logger
.
info_rank0
(
f
" Gradient Accumulation steps =
{
self
.
args
.
gradient_accumulation_steps
:,
}
"
)
logger
.
info
(
" Total training steps = {:,}"
.
format
(
max_steps
))
logger
.
info_rank0
(
f
" Num optimization epochs per batch =
{
self
.
finetuning_args
.
ppo_epochs
:,
}
"
)
logger
.
info
(
" Number of trainable parameters = {:,}"
.
format
(
count_parameters
(
self
.
model
)[
0
]))
logger
.
info_rank0
(
f
" Total training steps =
{
max_steps
:,
}
"
)
logger
.
info_rank0
(
f
" Number of trainable parameters =
{
count_parameters
(
self
.
model
)[
0
]:,
}
"
)
dataiter
=
iter
(
self
.
dataloader
)
dataiter
=
iter
(
self
.
dataloader
)
loss_meter
=
AverageMeter
()
loss_meter
=
AverageMeter
()
...
@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch
[
"response"
]
=
self
.
tokenizer
.
batch_decode
(
responses
,
skip_special_tokens
=
True
)
batch
[
"response"
]
=
self
.
tokenizer
.
batch_decode
(
responses
,
skip_special_tokens
=
True
)
self
.
log_stats
(
stats
,
batch
,
rewards
)
self
.
log_stats
(
stats
,
batch
,
rewards
)
except
Exception
:
except
Exception
:
logger
.
warning
(
"Failed to save stats due to unknown errors."
)
logger
.
warning
_rank0
(
"Failed to save stats due to unknown errors."
)
self
.
state
.
global_step
+=
1
self
.
state
.
global_step
+=
1
self
.
callback_handler
.
on_step_end
(
self
.
args
,
self
.
state
,
self
.
control
)
self
.
callback_handler
.
on_step_end
(
self
.
args
,
self
.
state
,
self
.
control
)
...
@@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if
(
step
+
1
)
%
self
.
args
.
save_steps
==
0
:
# save checkpoint
if
(
step
+
1
)
%
self
.
args
.
save_steps
==
0
:
# save checkpoint
self
.
save_model
(
self
.
save_model
(
os
.
path
.
join
(
self
.
args
.
output_dir
,
"{
}-{}"
.
format
(
PREFIX_CHECKPOINT_DIR
,
self
.
state
.
global_step
)
)
os
.
path
.
join
(
self
.
args
.
output_dir
,
f
"
{
PREFIX_CHECKPOINT_DIR
}
-
{
self
.
state
.
global_step
}
"
)
)
)
self
.
callback_handler
.
on_save
(
self
.
args
,
self
.
state
,
self
.
control
)
self
.
callback_handler
.
on_save
(
self
.
args
,
self
.
state
,
self
.
control
)
...
@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if
self
.
args
.
should_save
:
if
self
.
args
.
should_save
:
self
.
_save
(
output_dir
,
state_dict
=
state_dict
)
self
.
_save
(
output_dir
,
state_dict
=
state_dict
)
except
ValueError
:
except
ValueError
:
logger
.
warning
(
logger
.
warning
_rank0
(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" use zero_to_fp32.py to recover weights"
" use zero_to_fp32.py to recover weights"
)
)
...
...
src/llamafactory/train/pt/trainer.py
View file @
2778a3d0
...
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
...
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from
transformers
import
Trainer
from
transformers
import
Trainer
from
typing_extensions
import
override
from
typing_extensions
import
override
from
...extras.
logging
import
get_logger
from
...extras.
packages
import
is_transformers_version_equal_to_4_46
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
...
@@ -30,9 +30,6 @@ if TYPE_CHECKING:
...
@@ -30,9 +30,6 @@ if TYPE_CHECKING:
from
...hparams
import
FinetuningArguments
from
...hparams
import
FinetuningArguments
logger
=
get_logger
(
__name__
)
class
CustomTrainer
(
Trainer
):
class
CustomTrainer
(
Trainer
):
r
"""
r
"""
Inherits Trainer for custom optimizer.
Inherits Trainer for custom optimizer.
...
@@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
...
@@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
self
.
add_callback
(
PissaConvertCallback
)
self
.
add_callback
(
PissaConvertCallback
)
if
finetuning_args
.
use_badam
:
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
...
@@ -68,3 +65,19 @@ class CustomTrainer(Trainer):
...
@@ -68,3 +65,19 @@ class CustomTrainer(Trainer):
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
compute_loss
(
self
,
model
,
inputs
,
return_outputs
=
False
,
**
kwargs
):
r
"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss
=
super
().
compute_loss
(
model
,
inputs
,
return_outputs
,
**
kwargs
)
if
is_transformers_version_equal_to_4_46
()
and
not
getattr
(
self
,
"model_accepts_loss_kwargs"
,
False
):
# other model should not scale the loss
if
return_outputs
:
return
(
loss
[
0
]
/
self
.
args
.
gradient_accumulation_steps
,
*
loss
[
1
:])
else
:
return
loss
/
self
.
args
.
gradient_accumulation_steps
return
loss
src/llamafactory/train/rm/trainer.py
View file @
2778a3d0
...
@@ -24,7 +24,8 @@ import torch
...
@@ -24,7 +24,8 @@ import torch
from
transformers
import
Trainer
from
transformers
import
Trainer
from
typing_extensions
import
override
from
typing_extensions
import
override
from
...extras.logging
import
get_logger
from
...extras
import
logging
from
...extras.packages
import
is_transformers_version_equal_to_4_46
from
..callbacks
import
FixValueHeadModelCallback
,
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
FixValueHeadModelCallback
,
PissaConvertCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
...
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
...
@@ -36,7 +37,7 @@ if TYPE_CHECKING:
from
...hparams
import
FinetuningArguments
from
...hparams
import
FinetuningArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
class
PairwiseTrainer
(
Trainer
):
class
PairwiseTrainer
(
Trainer
):
...
@@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
...
@@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
self
.
add_callback
(
PissaConvertCallback
)
self
.
add_callback
(
PissaConvertCallback
)
if
finetuning_args
.
use_badam
:
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
...
@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
...
@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
@
override
@
override
def
compute_loss
(
def
compute_loss
(
self
,
model
:
"PreTrainedModel"
,
inputs
:
Dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
self
,
model
:
"PreTrainedModel"
,
inputs
:
Dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
,
**
kwargs
)
->
Union
[
"torch.Tensor"
,
Tuple
[
"torch.Tensor"
,
List
[
"torch.Tensor"
]]]:
)
->
Union
[
"torch.Tensor"
,
Tuple
[
"torch.Tensor"
,
List
[
"torch.Tensor"
]]]:
r
"""
r
"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
...
@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
...
@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
chosen_scores
,
rejected_scores
=
chosen_scores
.
squeeze
(),
rejected_scores
.
squeeze
()
chosen_scores
,
rejected_scores
=
chosen_scores
.
squeeze
(),
rejected_scores
.
squeeze
()
loss
=
-
torch
.
nn
.
functional
.
logsigmoid
(
chosen_scores
.
float
()
-
rejected_scores
.
float
()).
mean
()
loss
=
-
torch
.
nn
.
functional
.
logsigmoid
(
chosen_scores
.
float
()
-
rejected_scores
.
float
()).
mean
()
if
is_transformers_version_equal_to_4_46
()
and
kwargs
.
pop
(
"num_items_in_batch"
,
False
):
loss
/=
self
.
args
.
gradient_accumulation_steps
# fixes the loss value for transformers 4.46.0
if
return_outputs
:
if
return_outputs
:
return
loss
,
(
loss
,
chosen_scores
,
rejected_scores
)
return
loss
,
(
loss
,
chosen_scores
,
rejected_scores
)
else
:
else
:
...
@@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
...
@@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
return
return
output_prediction_file
=
os
.
path
.
join
(
self
.
args
.
output_dir
,
"generated_predictions.jsonl"
)
output_prediction_file
=
os
.
path
.
join
(
self
.
args
.
output_dir
,
"generated_predictions.jsonl"
)
logger
.
info
(
f
"Saving prediction results to
{
output_prediction_file
}
"
)
logger
.
info
_rank0
(
f
"Saving prediction results to
{
output_prediction_file
}
"
)
chosen_scores
,
rejected_scores
=
predict_results
.
predictions
chosen_scores
,
rejected_scores
=
predict_results
.
predictions
with
open
(
output_prediction_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
with
open
(
output_prediction_file
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
...
...
src/llamafactory/train/sft/trainer.py
View file @
2778a3d0
...
@@ -25,8 +25,9 @@ import torch
...
@@ -25,8 +25,9 @@ import torch
from
transformers
import
Seq2SeqTrainer
from
transformers
import
Seq2SeqTrainer
from
typing_extensions
import
override
from
typing_extensions
import
override
from
...extras
import
logging
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.
logging
import
get_logger
from
...extras.
packages
import
is_transformers_version_equal_to_4_46
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
...
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
...
@@ -39,7 +40,7 @@ if TYPE_CHECKING:
from
...hparams
import
FinetuningArguments
from
...hparams
import
FinetuningArguments
logger
=
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
class
CustomSeq2SeqTrainer
(
Seq2SeqTrainer
):
class
CustomSeq2SeqTrainer
(
Seq2SeqTrainer
):
...
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self
.
add_callback
(
PissaConvertCallback
)
self
.
add_callback
(
PissaConvertCallback
)
if
finetuning_args
.
use_badam
:
if
finetuning_args
.
use_badam
:
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
from
badam
import
BAdamCallback
,
clip_grad_norm_old_version
# type: ignore
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
...
@@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
compute_loss
(
self
,
model
,
inputs
,
return_outputs
=
False
,
**
kwargs
):
r
"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss
=
super
().
compute_loss
(
model
,
inputs
,
return_outputs
,
**
kwargs
)
if
is_transformers_version_equal_to_4_46
()
and
not
getattr
(
self
,
"model_accepts_loss_kwargs"
,
False
):
# other model should not scale the loss
if
return_outputs
:
return
(
loss
[
0
]
/
self
.
args
.
gradient_accumulation_steps
,
*
loss
[
1
:])
else
:
return
loss
/
self
.
args
.
gradient_accumulation_steps
return
loss
@
override
@
override
def
prediction_step
(
def
prediction_step
(
self
,
self
,
...
@@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return
return
output_prediction_file
=
os
.
path
.
join
(
self
.
args
.
output_dir
,
"generated_predictions.jsonl"
)
output_prediction_file
=
os
.
path
.
join
(
self
.
args
.
output_dir
,
"generated_predictions.jsonl"
)
logger
.
info
(
f
"Saving prediction results to
{
output_prediction_file
}
"
)
logger
.
info
_rank0
(
f
"Saving prediction results to
{
output_prediction_file
}
"
)
labels
=
np
.
where
(
labels
=
np
.
where
(
predict_results
.
label_ids
!=
IGNORE_INDEX
,
predict_results
.
label_ids
,
self
.
tokenizer
.
pad_token_id
predict_results
.
label_ids
!=
IGNORE_INDEX
,
predict_results
.
label_ids
,
self
.
tokenizer
.
pad_token_id
...
...
Prev
1
…
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment