Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
LLaMA-Factory
Commits
27a7ad86
Commit
27a7ad86
authored
Oct 14, 2024
by
luopl
Browse files
update to v0.9.1
parent
731cf9b8
Changes
120
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
289 additions
and
114 deletions
+289
-114
src/llamafactory/model/model_utils/misc.py
src/llamafactory/model/model_utils/misc.py
+10
-5
src/llamafactory/model/model_utils/moe.py
src/llamafactory/model/model_utils/moe.py
+11
-9
src/llamafactory/model/model_utils/packing.py
src/llamafactory/model/model_utils/packing.py
+1
-1
src/llamafactory/model/model_utils/visual.py
src/llamafactory/model/model_utils/visual.py
+104
-8
src/llamafactory/model/patcher.py
src/llamafactory/model/patcher.py
+28
-5
src/llamafactory/train/callbacks.py
src/llamafactory/train/callbacks.py
+51
-52
src/llamafactory/train/dpo/trainer.py
src/llamafactory/train/dpo/trainer.py
+6
-1
src/llamafactory/train/dpo/workflow.py
src/llamafactory/train/dpo/workflow.py
+6
-4
src/llamafactory/train/kto/trainer.py
src/llamafactory/train/kto/trainer.py
+13
-3
src/llamafactory/train/kto/workflow.py
src/llamafactory/train/kto/workflow.py
+6
-4
src/llamafactory/train/ppo/ppo_utils.py
src/llamafactory/train/ppo/ppo_utils.py
+3
-3
src/llamafactory/train/ppo/trainer.py
src/llamafactory/train/ppo/trainer.py
+7
-2
src/llamafactory/train/ppo/workflow.py
src/llamafactory/train/ppo/workflow.py
+4
-5
src/llamafactory/train/pt/trainer.py
src/llamafactory/train/pt/trainer.py
+3
-0
src/llamafactory/train/pt/workflow.py
src/llamafactory/train/pt/workflow.py
+3
-2
src/llamafactory/train/rm/metric.py
src/llamafactory/train/rm/metric.py
+4
-0
src/llamafactory/train/rm/trainer.py
src/llamafactory/train/rm/trainer.py
+6
-2
src/llamafactory/train/rm/workflow.py
src/llamafactory/train/rm/workflow.py
+5
-4
src/llamafactory/train/sft/metric.py
src/llamafactory/train/sft/metric.py
+9
-0
src/llamafactory/train/sft/trainer.py
src/llamafactory/train/sft/trainer.py
+9
-4
No files found.
src/llamafactory/model/model_utils/misc.py
View file @
27a7ad86
...
@@ -28,17 +28,22 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
...
@@ -28,17 +28,22 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
r
"""
r
"""
Finds all available modules to apply lora or galore.
Finds all available modules to apply lora or galore.
"""
"""
model_type
=
getattr
(
model
.
config
,
"model_type"
,
None
)
forbidden_modules
=
{
"lm_head"
}
forbidden_modules
=
{
"lm_head"
}
if
model_type
==
"chatglm"
:
if
model
.
config
.
model_type
==
"chatglm"
:
forbidden_modules
.
add
(
"output_layer"
)
forbidden_modules
.
add
(
"output_layer"
)
elif
model
.
config
.
model_type
==
"internlm2"
:
elif
model_type
==
"internlm2"
:
forbidden_modules
.
add
(
"output"
)
forbidden_modules
.
add
(
"output"
)
elif
model
.
config
.
model_type
in
[
"llava"
,
"
paligemm
a"
]:
elif
model_type
in
[
"llava"
,
"
llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llav
a"
]:
forbidden_modules
.
add
(
"multi_modal_projector"
)
forbidden_modules
.
add
(
"multi_modal_projector"
)
elif
model_type
==
"qwen2_vl"
:
forbidden_modules
.
add
(
"merger"
)
if
freeze_vision_tower
:
if
freeze_vision_tower
:
forbidden_modules
.
add
(
"vision_tower"
)
if
model_type
==
"qwen2_vl"
:
forbidden_modules
.
add
(
"visual"
)
else
:
forbidden_modules
.
add
(
"vision_tower"
)
module_names
=
set
()
module_names
=
set
()
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
...
...
src/llamafactory/model/model_utils/moe.py
View file @
27a7ad86
...
@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
...
@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
if
not
is_deepspeed_zero3_enabled
():
if
not
is_deepspeed_zero3_enabled
():
return
return
if
getattr
(
model
.
config
,
"model_type"
,
None
)
==
"dbrx"
:
model_type
=
getattr
(
model
.
config
,
"model_type"
,
None
)
if
model_type
==
"dbrx"
:
from
transformers.models.dbrx.modeling_dbrx
import
DbrxFFN
from
transformers.models.dbrx.modeling_dbrx
import
DbrxFFN
_set_z3_leaf_modules
(
model
,
[
DbrxFFN
])
_set_z3_leaf_modules
(
model
,
[
DbrxFFN
])
if
getattr
(
model
.
config
,
"
model_type
"
,
None
)
==
"jamba"
:
if
model_type
==
"jamba"
:
from
transformers.models.jamba.modeling_jamba
import
JambaSparseMoeBlock
from
transformers.models.jamba.modeling_jamba
import
JambaSparseMoeBlock
_set_z3_leaf_modules
(
model
,
[
JambaSparseMoeBlock
])
_set_z3_leaf_modules
(
model
,
[
JambaSparseMoeBlock
])
if
getattr
(
model
.
config
,
"
model_type
"
,
None
)
==
"jetmoe"
:
if
model_type
==
"jetmoe"
:
from
transformers.models.jetmoe.modeling_jetmoe
import
JetMoeMoA
,
JetMoeMoE
from
transformers.models.jetmoe.modeling_jetmoe
import
JetMoeMoA
,
JetMoeMoE
_set_z3_leaf_modules
(
model
,
[
JetMoeMoA
,
JetMoeMoE
])
_set_z3_leaf_modules
(
model
,
[
JetMoeMoA
,
JetMoeMoE
])
if
getattr
(
model
.
config
,
"
model_type
"
,
None
)
==
"mixtral"
:
if
model_type
==
"mixtral"
:
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
_set_z3_leaf_modules
(
model
,
[
MixtralSparseMoeBlock
])
_set_z3_leaf_modules
(
model
,
[
MixtralSparseMoeBlock
])
if
getattr
(
model
.
config
,
"
model_type
"
,
None
)
==
"qwen2moe"
:
if
model_type
==
"qwen2moe"
:
from
transformers.models.qwen2_moe.modeling_qwen2_moe
import
Qwen2MoeSparseMoeBlock
from
transformers.models.qwen2_moe.modeling_qwen2_moe
import
Qwen2MoeSparseMoeBlock
_set_z3_leaf_modules
(
model
,
[
Qwen2MoeSparseMoeBlock
])
_set_z3_leaf_modules
(
model
,
[
Qwen2MoeSparseMoeBlock
])
def
configure_moe
(
config
:
"PretrainedConfig"
,
model_args
:
"ModelArguments"
,
is_trainable
:
bool
)
->
None
:
def
configure_moe
(
config
:
"PretrainedConfig"
,
model_args
:
"ModelArguments"
,
is_trainable
:
bool
)
->
None
:
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
model_args
.
moe_aux_loss_coef
is
not
None
:
if
model_args
.
moe_aux_loss_coef
is
not
None
:
if
getattr
(
config
,
"
model_type
"
,
None
)
in
[
"jamba"
,
"mixtral"
,
"qwen2_moe"
]:
if
model_type
in
[
"jamba"
,
"mixtral"
,
"qwen2_moe"
]:
setattr
(
config
,
"router_aux_loss_coef"
,
model_args
.
moe_aux_loss_coef
)
setattr
(
config
,
"router_aux_loss_coef"
,
model_args
.
moe_aux_loss_coef
)
elif
getattr
(
config
,
"
model_type
"
,
None
)
==
"deepseek"
:
elif
model_type
==
"deepseek"
:
setattr
(
config
,
"aux_loss_alpha"
,
model_args
.
moe_aux_loss_coef
)
setattr
(
config
,
"aux_loss_alpha"
,
model_args
.
moe_aux_loss_coef
)
elif
getattr
(
config
,
"
model_type
"
,
None
)
==
"jetmoe"
:
elif
model_type
==
"jetmoe"
:
setattr
(
config
,
"aux_loss_coef"
,
model_args
.
moe_aux_loss_coef
)
setattr
(
config
,
"aux_loss_coef"
,
model_args
.
moe_aux_loss_coef
)
if
getattr
(
config
,
"
model_type
"
,
None
)
in
[
"dbrx"
,
"jamba"
,
"jetmoe"
,
"mixtral"
,
"qwen2_moe"
]:
if
model_type
in
[
"dbrx"
,
"jamba"
,
"jetmoe"
,
"mixtral"
,
"qwen2_moe"
]:
setattr
(
config
,
"output_router_logits"
,
is_trainable
)
setattr
(
config
,
"output_router_logits"
,
is_trainable
)
src/llamafactory/model/model_utils/packing.py
View file @
27a7ad86
...
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
...
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def
_patch_for_block_diag_attn
(
model_type
:
str
)
->
None
:
def
_patch_for_block_diag_attn
(
model_type
:
str
)
->
None
:
require_version
(
"transformers>=4.41.2,<=4.4
3.4
"
,
"To fix: pip install transformers>=4.41.2,<=4.4
3.4
"
)
require_version
(
"transformers>=4.41.2,<=4.4
5.2
"
,
"To fix: pip install transformers>=4.41.2,<=4.4
5.2
"
)
if
is_transformers_version_greater_than_4_43
():
if
is_transformers_version_greater_than_4_43
():
import
transformers.modeling_flash_attention_utils
import
transformers.modeling_flash_attention_utils
...
...
src/llamafactory/model/model_utils/visual.py
View file @
27a7ad86
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Tuple
from
typing
import
TYPE_CHECKING
,
List
,
Sequence
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
transformers.models
import
transformers.models
...
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
...
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
LlavaConfig
,
PretrainedConfig
,
PreTrainedModel
from
transformers
import
LlavaConfig
,
PretrainedConfig
,
PreTrainedModel
from
...hparams
import
ModelArguments
from
...hparams
import
FinetuningArguments
,
ModelArguments
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -80,24 +80,120 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
...
@@ -80,24 +80,120 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
self
.
act
=
ACT2FN
[
projector_hidden_act
]
self
.
act
=
ACT2FN
[
projector_hidden_act
]
def
autocast_projector_dtype
(
def
autocast_projector_dtype
(
model
:
"PreTrainedModel"
,
model_args
:
"ModelArguments"
)
->
None
:
model
:
"PreTrainedModel"
,
model_args
:
"ModelArguments"
,
mm_projector_name
:
str
=
"multi_modal_projector"
r
"""
)
->
None
:
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
def
_mm_projector_forward_post_hook
(
def
_mm_projector_forward_post_hook
(
module
:
"torch.nn.Module"
,
args
:
Tuple
[
"torch.Tensor"
],
output
:
"torch.Tensor"
module
:
"torch.nn.Module"
,
args
:
Tuple
[
"torch.Tensor"
],
output
:
"torch.Tensor"
)
->
"torch.Tensor"
:
)
->
"torch.Tensor"
:
return
output
.
to
(
model_args
.
compute_dtype
)
return
output
.
to
(
model_args
.
compute_dtype
)
if
hasattr
(
model
,
mm_projector_name
)
and
getattr
(
model
,
"quantization_method"
,
None
):
if
getattr
(
model
,
"quantization_method"
,
None
):
model_type
=
getattr
(
model
.
config
,
"model_type"
,
None
)
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
mm_projector
:
"torch.nn.Module"
=
getattr
(
model
,
"multi_modal_projector"
)
elif
model_type
==
"qwen2_vl"
:
mm_projector
:
"torch.nn.Module"
=
getattr
(
getattr
(
model
,
"visual"
),
"merger"
)
else
:
return
logger
.
info
(
"Casting multimodal projector outputs in {}."
.
format
(
model_args
.
compute_dtype
))
logger
.
info
(
"Casting multimodal projector outputs in {}."
.
format
(
model_args
.
compute_dtype
))
mm_projector
:
"torch.nn.Module"
=
getattr
(
model
,
mm_projector_name
)
mm_projector
.
register_forward_hook
(
_mm_projector_forward_post_hook
)
mm_projector
.
register_forward_hook
(
_mm_projector_forward_post_hook
)
def
configure_visual_model
(
config
:
"PretrainedConfig"
)
->
None
:
def
configure_visual_model
(
config
:
"PretrainedConfig"
)
->
None
:
if
getattr
(
config
,
"model_type"
,
None
)
==
"llava"
:
# required for ds zero3 and valuehead models
r
"""
Patches VLMs before loading them.
"""
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
,
]:
# required for ds zero3 and valuehead models
setattr
(
config
,
"hidden_size"
,
getattr
(
config
.
text_config
,
"hidden_size"
,
None
))
setattr
(
config
,
"hidden_size"
,
getattr
(
config
.
text_config
,
"hidden_size"
,
None
))
if
getattr
(
config
,
"is_yi_vl_derived_model"
,
None
):
if
getattr
(
config
,
"is_yi_vl_derived_model"
,
None
):
logger
.
info
(
"Detected Yi-VL model, applying projector patch."
)
logger
.
info
(
"Detected Yi-VL model, applying projector patch."
)
transformers
.
models
.
llava
.
modeling_llava
.
LlavaMultiModalProjector
=
LlavaMultiModalProjectorForYiVL
transformers
.
models
.
llava
.
modeling_llava
.
LlavaMultiModalProjector
=
LlavaMultiModalProjectorForYiVL
def
get_forbidden_modules
(
config
:
"PretrainedConfig"
,
finetuning_args
:
"FinetuningArguments"
)
->
Set
[
str
]:
r
"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
model_type
=
getattr
(
config
,
"model_type"
,
None
)
forbidden_modules
=
set
()
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
if
finetuning_args
.
freeze_vision_tower
:
forbidden_modules
.
add
(
"vision_tower"
)
if
finetuning_args
.
train_mm_proj_only
:
forbidden_modules
.
add
(
"language_model"
)
elif
model_type
==
"qwen2_vl"
:
if
finetuning_args
.
freeze_vision_tower
:
forbidden_modules
.
add
(
"visual"
)
if
finetuning_args
.
train_mm_proj_only
:
raise
ValueError
(
"Qwen2-VL models do not support `train_mm_proj_only`."
)
return
forbidden_modules
def
get_image_seqlen
(
config
:
"PretrainedConfig"
)
->
int
:
r
"""
Computes the number of special tokens per image.
"""
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
model_type
==
"llava"
:
image_seqlen
=
(
config
.
vision_config
.
image_size
//
config
.
vision_config
.
patch_size
)
**
2
if
getattr
(
config
,
"vision_feature_select_strategy"
,
"default"
)
==
"full"
:
# add [CLS] token
image_seqlen
+=
1
elif
model_type
==
"paligemma"
:
image_seqlen
=
config
.
vision_config
.
num_image_tokens
else
:
image_seqlen
=
-
1
return
image_seqlen
def
get_patch_size
(
config
:
"PretrainedConfig"
)
->
int
:
r
"""
Computes the patch size of the vit.
"""
patch_size
=
getattr
(
config
.
vision_config
,
"patch_size"
,
-
1
)
return
patch_size
def
get_vision_feature_select_strategy
(
config
:
"PretrainedConfig"
)
->
int
:
r
"""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy
=
getattr
(
config
,
"vision_feature_select_strategy"
,
"default"
)
return
vision_feature_select_strategy
def
patch_target_modules
(
config
:
"PretrainedConfig"
,
finetuning_args
:
"FinetuningArguments"
,
target_modules
:
Sequence
[
str
]
)
->
Union
[
str
,
List
[
str
]]:
r
"""
Freezes vision tower for VLM LoRA tuning.
"""
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
finetuning_args
.
freeze_vision_tower
:
if
model_type
in
[
"llava"
,
"llava_next"
,
"llava_next_video"
,
"paligemma"
,
"video_llava"
]:
return
"^(?!.*vision_tower).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
elif
model_type
==
"qwen2_vl"
:
return
"^(?!.*visual).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
else
:
return
target_modules
else
:
if
model_type
==
"qwen2_vl"
:
return
"^(?!.*patch_embed).*(?:{}).*"
.
format
(
"|"
.
join
(
target_modules
))
else
:
return
target_modules
src/llamafactory/model/patcher.py
View file @
27a7ad86
...
@@ -33,11 +33,17 @@ from .model_utils.packing import configure_packing
...
@@ -33,11 +33,17 @@ from .model_utils.packing import configure_packing
from
.model_utils.quantization
import
configure_quantization
from
.model_utils.quantization
import
configure_quantization
from
.model_utils.rope
import
configure_rope
from
.model_utils.rope
import
configure_rope
from
.model_utils.valuehead
import
prepare_valuehead_model
from
.model_utils.valuehead
import
prepare_valuehead_model
from
.model_utils.visual
import
autocast_projector_dtype
,
configure_visual_model
from
.model_utils.visual
import
(
autocast_projector_dtype
,
configure_visual_model
,
get_image_seqlen
,
get_patch_size
,
get_vision_feature_select_strategy
,
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
PretrainedConfig
,
PreTrainedTokenizer
from
transformers
import
PretrainedConfig
,
PreTrainedTokenizer
,
ProcessorMixin
from
trl
import
AutoModelForCausalLMWithValueHead
from
trl
import
AutoModelForCausalLMWithValueHead
from
..hparams
import
ModelArguments
from
..hparams
import
ModelArguments
...
@@ -51,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
...
@@ -51,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
tokenizer
.
_pad
=
MethodType
(
PreTrainedTokenizerBase
.
_pad
,
tokenizer
)
tokenizer
.
_pad
=
MethodType
(
PreTrainedTokenizerBase
.
_pad
,
tokenizer
)
def
patch_processor
(
processor
:
"ProcessorMixin"
,
config
:
"PretrainedConfig"
,
tokenizer
:
"PreTrainedTokenizer"
,
model_args
:
"ModelArguments"
,
)
->
None
:
setattr
(
processor
,
"tokenizer"
,
tokenizer
)
setattr
(
processor
,
"image_seqlen"
,
get_image_seqlen
(
config
))
setattr
(
processor
,
"image_resolution"
,
model_args
.
image_resolution
)
setattr
(
processor
,
"patch_size"
,
get_patch_size
(
config
))
setattr
(
processor
,
"video_resolution"
,
model_args
.
video_resolution
)
setattr
(
processor
,
"video_fps"
,
model_args
.
video_fps
)
setattr
(
processor
,
"video_maxlen"
,
model_args
.
video_maxlen
)
setattr
(
processor
,
"vision_feature_select_strategy"
,
get_vision_feature_select_strategy
(
config
))
def
patch_config
(
def
patch_config
(
config
:
"PretrainedConfig"
,
config
:
"PretrainedConfig"
,
tokenizer
:
"PreTrainedTokenizer"
,
tokenizer
:
"PreTrainedTokenizer"
,
...
@@ -88,6 +110,9 @@ def patch_config(
...
@@ -88,6 +110,9 @@ def patch_config(
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen2"
and
is_trainable
and
model_args
.
flash_attn
==
"fa2"
:
if
getattr
(
config
,
"model_type"
,
None
)
==
"qwen2"
and
is_trainable
and
model_args
.
flash_attn
==
"fa2"
:
setattr
(
config
,
"use_cache"
,
False
)
# qwen2 does not support use_cache when using flash attn
setattr
(
config
,
"use_cache"
,
False
)
# qwen2 does not support use_cache when using flash attn
if
"LlavaLlamaForCausalLM"
in
getattr
(
config
,
"architectures"
,
[]):
raise
ValueError
(
"Please download llava models with hf-compatible format: https://huggingface.co/llava-hf"
)
# deepspeed zero3 is not compatible with low_cpu_mem_usage
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs
[
"low_cpu_mem_usage"
]
=
model_args
.
low_cpu_mem_usage
and
(
not
is_deepspeed_zero3_enabled
())
init_kwargs
[
"low_cpu_mem_usage"
]
=
model_args
.
low_cpu_mem_usage
and
(
not
is_deepspeed_zero3_enabled
())
...
@@ -129,11 +154,9 @@ def patch_model(
...
@@ -129,11 +154,9 @@ def patch_model(
if
model_args
.
resize_vocab
:
if
model_args
.
resize_vocab
:
resize_embedding_layer
(
model
,
tokenizer
)
resize_embedding_layer
(
model
,
tokenizer
)
if
model_args
.
visual_inputs
:
autocast_projector_dtype
(
model
,
model_args
)
if
is_trainable
:
if
is_trainable
:
prepare_model_for_training
(
model
,
model_args
)
prepare_model_for_training
(
model
,
model_args
)
autocast_projector_dtype
(
model
,
model_args
)
add_z3_leaf_module
(
model
)
add_z3_leaf_module
(
model
)
if
not
model_args
.
use_unsloth
:
if
not
model_args
.
use_unsloth
:
...
...
src/llamafactory/train/callbacks.py
View file @
27a7ad86
...
@@ -32,9 +32,11 @@ from transformers.utils import (
...
@@ -32,9 +32,11 @@ from transformers.utils import (
WEIGHTS_NAME
,
WEIGHTS_NAME
,
is_safetensors_available
,
is_safetensors_available
,
)
)
from
typing_extensions
import
override
from
..extras.constants
import
TRAINER_LOG
,
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..extras.constants
import
TRAINER_LOG
,
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..extras.logging
import
LoggerHandler
,
get_logger
from
..extras.logging
import
LoggerHandler
,
get_logger
from
..extras.misc
import
get_peak_memory
if
is_safetensors_available
():
if
is_safetensors_available
():
...
@@ -73,8 +75,8 @@ def fix_valuehead_checkpoint(
...
@@ -73,8 +75,8 @@ def fix_valuehead_checkpoint(
path_to_checkpoint
=
os
.
path
.
join
(
output_dir
,
WEIGHTS_NAME
)
path_to_checkpoint
=
os
.
path
.
join
(
output_dir
,
WEIGHTS_NAME
)
state_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
torch
.
load
(
path_to_checkpoint
,
map_location
=
"cpu"
)
state_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
torch
.
load
(
path_to_checkpoint
,
map_location
=
"cpu"
)
decoder_state_dict
=
{}
os
.
remove
(
path_to_checkpoint
)
v_head_state_dict
=
{}
decoder_state_dict
,
v_head_state_dict
=
{}
,
{}
for
name
,
param
in
state_dict
.
items
():
for
name
,
param
in
state_dict
.
items
():
if
name
.
startswith
(
"v_head."
):
if
name
.
startswith
(
"v_head."
):
v_head_state_dict
[
name
]
=
param
v_head_state_dict
[
name
]
=
param
...
@@ -90,43 +92,52 @@ def fix_valuehead_checkpoint(
...
@@ -90,43 +92,52 @@ def fix_valuehead_checkpoint(
else
:
else
:
torch
.
save
(
v_head_state_dict
,
os
.
path
.
join
(
output_dir
,
V_HEAD_WEIGHTS_NAME
))
torch
.
save
(
v_head_state_dict
,
os
.
path
.
join
(
output_dir
,
V_HEAD_WEIGHTS_NAME
))
os
.
remove
(
path_to_checkpoint
)
logger
.
info
(
"Value head model saved at: {}"
.
format
(
output_dir
))
logger
.
info
(
"Value head model saved at: {}"
.
format
(
output_dir
))
class
FixValueHeadModelCallback
(
TrainerCallback
):
class
FixValueHeadModelCallback
(
TrainerCallback
):
r
"""
A callback for fixing the checkpoint for valuehead models.
"""
@
override
def
on_save
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_save
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
r
"""
Event called after a checkpoint save.
Event called after a checkpoint save.
"""
"""
if
args
.
should_save
:
if
args
.
should_save
:
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"{}-{}"
.
format
(
PREFIX_CHECKPOINT_DIR
,
state
.
global_step
))
fix_valuehead_checkpoint
(
fix_valuehead_checkpoint
(
model
=
kwargs
.
pop
(
"model"
),
model
=
kwargs
.
pop
(
"model"
),
output_dir
=
output_dir
,
safe_serialization
=
args
.
save_safetensors
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"{}-{}"
.
format
(
PREFIX_CHECKPOINT_DIR
,
state
.
global_step
)),
safe_serialization
=
args
.
save_safetensors
,
)
)
class
SaveProcessorCallback
(
TrainerCallback
):
class
SaveProcessorCallback
(
TrainerCallback
):
r
"""
A callback for saving the processor.
"""
def
__init__
(
self
,
processor
:
"ProcessorMixin"
)
->
None
:
def
__init__
(
self
,
processor
:
"ProcessorMixin"
)
->
None
:
r
"""
Initializes a callback for saving the processor.
"""
self
.
processor
=
processor
self
.
processor
=
processor
@
override
def
on_save
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
args
.
should_save
:
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"{}-{}"
.
format
(
PREFIX_CHECKPOINT_DIR
,
state
.
global_step
))
getattr
(
self
.
processor
,
"image_processor"
).
save_pretrained
(
output_dir
)
@
override
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the end of training.
"""
if
args
.
should_save
:
if
args
.
should_save
:
getattr
(
self
.
processor
,
"image_processor"
).
save_pretrained
(
args
.
output_dir
)
getattr
(
self
.
processor
,
"image_processor"
).
save_pretrained
(
args
.
output_dir
)
class
PissaConvertCallback
(
TrainerCallback
):
class
PissaConvertCallback
(
TrainerCallback
):
r
"""
r
"""
Initializes a
callback for converting the PiSSA adapter to a normal one.
A
callback for converting the PiSSA adapter to a normal one.
"""
"""
@
override
def
on_train_begin
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_train_begin
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
r
"""
Event called at the beginning of training.
Event called at the beginning of training.
...
@@ -141,10 +152,8 @@ class PissaConvertCallback(TrainerCallback):
...
@@ -141,10 +152,8 @@ class PissaConvertCallback(TrainerCallback):
model
.
save_pretrained
(
pissa_init_dir
,
safe_serialization
=
args
.
save_safetensors
)
model
.
save_pretrained
(
pissa_init_dir
,
safe_serialization
=
args
.
save_safetensors
)
setattr
(
model
.
peft_config
[
"default"
],
"init_lora_weights"
,
init_lora_weights
)
setattr
(
model
.
peft_config
[
"default"
],
"init_lora_weights"
,
init_lora_weights
)
@
override
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the end of training.
"""
if
args
.
should_save
:
if
args
.
should_save
:
model
=
kwargs
.
pop
(
"model"
)
model
=
kwargs
.
pop
(
"model"
)
pissa_init_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_init"
)
pissa_init_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"pissa_init"
)
...
@@ -172,21 +181,22 @@ class PissaConvertCallback(TrainerCallback):
...
@@ -172,21 +181,22 @@ class PissaConvertCallback(TrainerCallback):
class
LogCallback
(
TrainerCallback
):
class
LogCallback
(
TrainerCallback
):
r
"""
A callback for logging training and evaluation status.
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
r
"""
# Progress
Initializes a callback for logging training and evaluation status.
"""
""" Progress """
self
.
start_time
=
0
self
.
start_time
=
0
self
.
cur_steps
=
0
self
.
cur_steps
=
0
self
.
max_steps
=
0
self
.
max_steps
=
0
self
.
elapsed_time
=
""
self
.
elapsed_time
=
""
self
.
remaining_time
=
""
self
.
remaining_time
=
""
self
.
thread_pool
:
Optional
[
"ThreadPoolExecutor"
]
=
None
self
.
thread_pool
:
Optional
[
"ThreadPoolExecutor"
]
=
None
"""
Status
"""
#
Status
self
.
aborted
=
False
self
.
aborted
=
False
self
.
do_train
=
False
self
.
do_train
=
False
"""
Web UI
"""
#
Web UI
self
.
webui_mode
=
os
.
environ
.
get
(
"LLAMABOARD_ENABLED"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]
self
.
webui_mode
=
os
.
environ
.
get
(
"LLAMABOARD_ENABLED"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]
if
self
.
webui_mode
:
if
self
.
webui_mode
:
signal
.
signal
(
signal
.
SIGABRT
,
self
.
_set_abort
)
signal
.
signal
(
signal
.
SIGABRT
,
self
.
_set_abort
)
...
@@ -226,10 +236,8 @@ class LogCallback(TrainerCallback):
...
@@ -226,10 +236,8 @@ class LogCallback(TrainerCallback):
self
.
thread_pool
.
shutdown
(
wait
=
True
)
self
.
thread_pool
.
shutdown
(
wait
=
True
)
self
.
thread_pool
=
None
self
.
thread_pool
=
None
@
override
def
on_init_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_init_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the end of the initialization of the `Trainer`.
"""
if
(
if
(
args
.
should_save
args
.
should_save
and
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
and
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
...
@@ -238,55 +246,41 @@ class LogCallback(TrainerCallback):
...
@@ -238,55 +246,41 @@ class LogCallback(TrainerCallback):
logger
.
warning
(
"Previous trainer log in this folder will be deleted."
)
logger
.
warning
(
"Previous trainer log in this folder will be deleted."
)
os
.
remove
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
os
.
remove
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
@
override
def
on_train_begin
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_train_begin
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the beginning of training.
"""
if
args
.
should_save
:
if
args
.
should_save
:
self
.
do_train
=
True
self
.
do_train
=
True
self
.
_reset
(
max_steps
=
state
.
max_steps
)
self
.
_reset
(
max_steps
=
state
.
max_steps
)
self
.
_create_thread_pool
(
output_dir
=
args
.
output_dir
)
self
.
_create_thread_pool
(
output_dir
=
args
.
output_dir
)
@
override
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the end of training.
"""
self
.
_close_thread_pool
()
self
.
_close_thread_pool
()
@
override
def
on_substep_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_substep_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the end of an substep during gradient accumulation.
"""
if
self
.
aborted
:
if
self
.
aborted
:
control
.
should_epoch_stop
=
True
control
.
should_epoch_stop
=
True
control
.
should_training_stop
=
True
control
.
should_training_stop
=
True
@
override
def
on_step_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_step_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called at the end of a training step.
"""
if
self
.
aborted
:
if
self
.
aborted
:
control
.
should_epoch_stop
=
True
control
.
should_epoch_stop
=
True
control
.
should_training_stop
=
True
control
.
should_training_stop
=
True
@
override
def
on_evaluate
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_evaluate
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called after an evaluation phase.
"""
if
not
self
.
do_train
:
if
not
self
.
do_train
:
self
.
_close_thread_pool
()
self
.
_close_thread_pool
()
@
override
def
on_predict
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_predict
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called after a successful prediction.
"""
if
not
self
.
do_train
:
if
not
self
.
do_train
:
self
.
_close_thread_pool
()
self
.
_close_thread_pool
()
@
override
def
on_log
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
def
on_log
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
r
"""
Event called after logging the last logs.
"""
if
not
args
.
should_save
:
if
not
args
.
should_save
:
return
return
...
@@ -304,26 +298,31 @@ class LogCallback(TrainerCallback):
...
@@ -304,26 +298,31 @@ class LogCallback(TrainerCallback):
percentage
=
round
(
self
.
cur_steps
/
self
.
max_steps
*
100
,
2
)
if
self
.
max_steps
!=
0
else
100
,
percentage
=
round
(
self
.
cur_steps
/
self
.
max_steps
*
100
,
2
)
if
self
.
max_steps
!=
0
else
100
,
elapsed_time
=
self
.
elapsed_time
,
elapsed_time
=
self
.
elapsed_time
,
remaining_time
=
self
.
remaining_time
,
remaining_time
=
self
.
remaining_time
,
throughput
=
"{:.2f}"
.
format
(
state
.
num_input_tokens_seen
/
(
time
.
time
()
-
self
.
start_time
)),
total_tokens
=
state
.
num_input_tokens_seen
,
)
)
if
state
.
num_input_tokens_seen
:
logs
[
"throughput"
]
=
round
(
state
.
num_input_tokens_seen
/
(
time
.
time
()
-
self
.
start_time
),
2
)
logs
[
"total_tokens"
]
=
state
.
num_input_tokens_seen
if
os
.
environ
.
get
(
"RECORD_VRAM"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]:
vram_allocated
,
vram_reserved
=
get_peak_memory
()
logs
[
"vram_allocated"
]
=
round
(
vram_allocated
/
1024
/
1024
/
1024
,
2
)
logs
[
"vram_reserved"
]
=
round
(
vram_reserved
/
1024
/
1024
/
1024
,
2
)
logs
=
{
k
:
v
for
k
,
v
in
logs
.
items
()
if
v
is
not
None
}
logs
=
{
k
:
v
for
k
,
v
in
logs
.
items
()
if
v
is
not
None
}
if
self
.
webui_mode
and
all
(
key
in
logs
for
key
in
[
"loss"
,
"learning_rate"
,
"epoch"
]):
if
self
.
webui_mode
and
all
(
key
in
logs
for
key
in
[
"loss"
,
"learning_rate"
,
"epoch"
]):
logger
.
info
(
logger
.
info
(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}"
.
format
(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}"
.
format
(
logs
[
"loss"
],
logs
[
"learning_rate"
],
logs
[
"epoch"
],
logs
[
"throughput"
]
logs
[
"loss"
],
logs
[
"learning_rate"
],
logs
[
"epoch"
],
logs
.
get
(
"throughput"
,
"N/A"
)
)
)
)
)
if
self
.
thread_pool
is
not
None
:
if
self
.
thread_pool
is
not
None
:
self
.
thread_pool
.
submit
(
self
.
_write_log
,
args
.
output_dir
,
logs
)
self
.
thread_pool
.
submit
(
self
.
_write_log
,
args
.
output_dir
,
logs
)
@
override
def
on_prediction_step
(
def
on_prediction_step
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
):
r
"""
Event called after a prediction step.
"""
if
self
.
do_train
:
if
self
.
do_train
:
return
return
...
...
src/llamafactory/train/dpo/trainer.py
View file @
27a7ad86
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
...
@@ -26,6 +26,7 @@ import torch.nn.functional as F
from
transformers
import
Trainer
from
transformers
import
Trainer
from
trl
import
DPOTrainer
from
trl
import
DPOTrainer
from
trl.trainer
import
disable_dropout_in_model
from
trl.trainer
import
disable_dropout_in_model
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
...
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
@
override
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
if
self
.
optimizer
is
None
:
if
self
.
optimizer
is
None
:
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
return
super
().
create_optimizer
()
return
super
().
create_optimizer
()
@
override
def
create_scheduler
(
def
create_scheduler
(
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
...
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
return
losses
,
chosen_rewards
,
rejected_rewards
return
losses
,
chosen_rewards
,
rejected_rewards
@
override
def
concatenated_forward
(
def
concatenated_forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
...
@@ -176,7 +180,6 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -176,7 +180,6 @@ class CustomDPOTrainer(DPOTrainer):
batch
=
{
k
:
v
.
detach
().
clone
()
for
k
,
v
in
batch
.
items
()}
# avoid error
batch
=
{
k
:
v
.
detach
().
clone
()
for
k
,
v
in
batch
.
items
()}
# avoid error
all_logits
:
"torch.Tensor"
=
model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
all_logits
:
"torch.Tensor"
=
model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
all_logps
,
valid_length
=
get_batch_logps
(
logits
=
all_logits
,
labels
=
batch
[
"labels"
])
all_logps
,
valid_length
=
get_batch_logps
(
logits
=
all_logits
,
labels
=
batch
[
"labels"
])
if
self
.
loss_type
in
[
"ipo"
,
"orpo"
,
"simpo"
]:
if
self
.
loss_type
in
[
"ipo"
,
"orpo"
,
"simpo"
]:
all_logps
=
all_logps
/
valid_length
all_logps
=
all_logps
/
valid_length
...
@@ -187,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -187,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
chosen_length
,
_
=
valid_length
.
split
(
batch_size
,
dim
=
0
)
chosen_length
,
_
=
valid_length
.
split
(
batch_size
,
dim
=
0
)
return
chosen_logps
,
rejected_logps
,
chosen_logits
,
rejected_logits
,
chosen_logps
/
chosen_length
return
chosen_logps
,
rejected_logps
,
chosen_logits
,
rejected_logits
,
chosen_logps
/
chosen_length
@
override
def
compute_reference_log_probs
(
def
compute_reference_log_probs
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
)
->
Tuple
[
Optional
[
"torch.Tensor"
],
Optional
[
"torch.Tensor"
]]:
)
->
Tuple
[
Optional
[
"torch.Tensor"
],
Optional
[
"torch.Tensor"
]]:
...
@@ -208,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
...
@@ -208,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
return
reference_chosen_logps
,
reference_rejected_logps
return
reference_chosen_logps
,
reference_rejected_logps
@
override
def
get_batch_loss_metrics
(
def
get_batch_loss_metrics
(
self
,
self
,
model
:
"PreTrainedModel"
,
model
:
"PreTrainedModel"
,
...
...
src/llamafactory/train/dpo/workflow.py
View file @
27a7ad86
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
...
@@ -41,13 +41,15 @@ def run_dpo(
...
@@ -41,13 +41,15 @@ def run_dpo(
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
dataset_module
=
get_dataset
(
model_args
,
data_args
,
training_args
,
stage
=
"rm"
,
**
tokenizer_module
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
dataset_module
=
get_dataset
(
template
,
model_args
,
data_args
,
training_args
,
stage
=
"rm"
,
**
tokenizer_module
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
)
data_collator
=
PairwiseDataCollatorWithPadding
(
data_collator
=
PairwiseDataCollatorWithPadding
(
t
okenizer
=
tokenizer
,
t
emplate
=
template
,
pad_to_multiple_of
=
8
,
pad_to_multiple_of
=
8
,
label_pad_token_id
=
IGNORE_INDEX
if
data_args
.
ignore_pad_token_for_loss
else
tokenizer
.
pad_token_id
,
label_pad_token_id
=
IGNORE_INDEX
if
data_args
.
ignore_pad_token_for_loss
else
tokenizer
.
pad_token_id
,
**
tokenizer_module
,
)
)
# Create reference model
# Create reference model
...
@@ -60,7 +62,7 @@ def run_dpo(
...
@@ -60,7 +62,7 @@ def run_dpo(
ref_model
=
None
ref_model
=
None
# Update arguments
# Update arguments
training_args
.
remove_unused_columns
=
False
# important for pairwise dataset
training_args
.
remove_unused_columns
=
False
# important for
multimodal and
pairwise dataset
# Initialize our Trainer
# Initialize our Trainer
trainer
=
CustomDPOTrainer
(
trainer
=
CustomDPOTrainer
(
...
...
src/llamafactory/train/kto/trainer.py
View file @
27a7ad86
...
@@ -25,6 +25,7 @@ import torch
...
@@ -25,6 +25,7 @@ import torch
from
transformers
import
Trainer
from
transformers
import
Trainer
from
trl
import
KTOTrainer
from
trl
import
KTOTrainer
from
trl.trainer
import
disable_dropout_in_model
from
trl.trainer
import
disable_dropout_in_model
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
..callbacks
import
SaveProcessorCallback
from
..callbacks
import
SaveProcessorCallback
...
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
@
override
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
if
self
.
optimizer
is
None
:
if
self
.
optimizer
is
None
:
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
return
super
().
create_optimizer
()
return
super
().
create_optimizer
()
@
override
def
create_scheduler
(
def
create_scheduler
(
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
_get_train_sampler
(
self
)
->
Optional
[
"torch.utils.data.Sampler"
]:
def
_get_train_sampler
(
self
)
->
Optional
[
"torch.utils.data.Sampler"
]:
r
"""
r
"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
"""
return
Trainer
.
_get_train_sampler
(
self
)
return
Trainer
.
_get_train_sampler
(
self
)
@
override
def
forward
(
def
forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
],
prefix
:
Literal
[
""
,
"kl_"
]
=
""
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
],
prefix
:
Literal
[
""
,
"kl_"
]
=
""
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
]:
...
@@ -127,17 +132,20 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -127,17 +132,20 @@ class CustomKTOTrainer(KTOTrainer):
"input_ids"
:
batch
[
"{}input_ids"
.
format
(
prefix
)],
"input_ids"
:
batch
[
"{}input_ids"
.
format
(
prefix
)],
"attention_mask"
:
batch
[
"{}attention_mask"
.
format
(
prefix
)],
"attention_mask"
:
batch
[
"{}attention_mask"
.
format
(
prefix
)],
}
}
if
"{}token_type_ids"
.
format
(
prefix
)
in
batch
:
model_inputs
[
"token_type_ids"
]
=
batch
[
"{}token_type_ids"
.
format
(
prefix
)]
if
"pixel_values"
in
batch
:
if
"pixel_values"
in
batch
:
model_inputs
[
"pixel_values"
]
=
batch
[
"pixel_values"
]
model_inputs
[
"pixel_values"
]
=
batch
[
"pixel_values"
]
if
"
{}token_type_ids"
.
format
(
prefix
)
in
batch
:
if
"
image_grid_thw"
in
batch
:
model_inputs
[
"
token_type_ids"
]
=
batch
[
"{}token_type_ids"
.
format
(
prefix
)
]
model_inputs
[
"
image_grid_thw"
]
=
batch
[
"image_grid_thw"
]
logits
=
model
(
**
model_inputs
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
logits
=
model
(
**
model_inputs
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
logps
,
valid_length
=
get_batch_logps
(
logits
=
logits
,
labels
=
batch
[
"{}labels"
.
format
(
prefix
)])
logps
,
valid_length
=
get_batch_logps
(
logits
=
logits
,
labels
=
batch
[
"{}labels"
.
format
(
prefix
)])
return
logps
,
logps
/
valid_length
return
logps
,
logps
/
valid_length
@
override
def
concatenated_forward
(
def
concatenated_forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
...
@@ -153,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -153,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
chosen_logps_avg
=
target_logps_avg
[
batch
[
"kto_tags"
]]
chosen_logps_avg
=
target_logps_avg
[
batch
[
"kto_tags"
]]
return
chosen_logps
,
rejected_logps
,
kl_logps
,
chosen_logps_avg
return
chosen_logps
,
rejected_logps
,
kl_logps
,
chosen_logps_avg
@
override
def
compute_reference_log_probs
(
def
compute_reference_log_probs
(
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
self
,
model
:
"PreTrainedModel"
,
batch
:
Dict
[
str
,
"torch.Tensor"
]
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
)
->
Tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
...
@@ -173,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
...
@@ -173,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
return
reference_chosen_logps
,
reference_rejected_logps
,
reference_kl_logps
return
reference_chosen_logps
,
reference_rejected_logps
,
reference_kl_logps
@
override
def
get_batch_loss_metrics
(
def
get_batch_loss_metrics
(
self
,
self
,
model
:
"PreTrainedModel"
,
model
:
"PreTrainedModel"
,
...
...
src/llamafactory/train/kto/workflow.py
View file @
27a7ad86
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
...data
import
KTODataCollatorWithPadding
,
get_dataset
from
...data
import
KTODataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
from
...hparams
import
ModelArguments
from
...hparams
import
ModelArguments
...
@@ -41,13 +41,15 @@ def run_kto(
...
@@ -41,13 +41,15 @@ def run_kto(
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
dataset_module
=
get_dataset
(
model_args
,
data_args
,
training_args
,
stage
=
"kto"
,
**
tokenizer_module
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
dataset_module
=
get_dataset
(
template
,
model_args
,
data_args
,
training_args
,
stage
=
"kto"
,
**
tokenizer_module
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
)
data_collator
=
KTODataCollatorWithPadding
(
data_collator
=
KTODataCollatorWithPadding
(
t
okenizer
=
tokenizer
,
t
emplate
=
template
,
pad_to_multiple_of
=
8
,
pad_to_multiple_of
=
8
,
label_pad_token_id
=
IGNORE_INDEX
if
data_args
.
ignore_pad_token_for_loss
else
tokenizer
.
pad_token_id
,
label_pad_token_id
=
IGNORE_INDEX
if
data_args
.
ignore_pad_token_for_loss
else
tokenizer
.
pad_token_id
,
**
tokenizer_module
,
)
)
# Create reference model
# Create reference model
...
@@ -57,7 +59,7 @@ def run_kto(
...
@@ -57,7 +59,7 @@ def run_kto(
ref_model
=
create_ref_model
(
model_args
,
finetuning_args
)
ref_model
=
create_ref_model
(
model_args
,
finetuning_args
)
# Update arguments
# Update arguments
training_args
.
remove_unused_columns
=
False
# important for pairwise dataset
training_args
.
remove_unused_columns
=
False
# important for
multimodal and
pairwise dataset
# Initialize our Trainer
# Initialize our Trainer
trainer
=
CustomKTOTrainer
(
trainer
=
CustomKTOTrainer
(
...
...
src/llamafactory/train/ppo/ppo_utils.py
View file @
27a7ad86
...
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
...
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
from
trl
import
AutoModelForCausalLMWithValueHead
from
trl
import
AutoModelForCausalLMWithValueHead
def
get_rewards_from_server
(
server_url
:
str
,
messages
:
List
[
str
])
->
List
[
torch
.
Tensor
]:
def
get_rewards_from_server
(
server_url
:
str
,
messages
:
List
[
str
])
->
List
[
"
torch.Tensor
"
]:
r
"""
r
"""
Gets reward scores from the API server.
Gets reward scores from the API server.
"""
"""
...
@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
...
@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
v_head_layer
.
bias
.
data
=
model
.
get_buffer
(
"{}_head_bias"
.
format
(
target
)).
detach
().
clone
().
to
(
device
)
v_head_layer
.
bias
.
data
=
model
.
get_buffer
(
"{}_head_bias"
.
format
(
target
)).
detach
().
clone
().
to
(
device
)
def
dump_layernorm
(
model
:
"PreTrainedModel"
)
->
Dict
[
str
,
torch
.
Tensor
]:
def
dump_layernorm
(
model
:
"PreTrainedModel"
)
->
Dict
[
str
,
"
torch.Tensor
"
]:
r
"""
r
"""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
"""
...
@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
...
@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
return
layer_norm_params
return
layer_norm_params
def
restore_layernorm
(
model
:
"PreTrainedModel"
,
layernorm_params
:
Optional
[
Dict
[
str
,
torch
.
Tensor
]]
=
None
)
->
None
:
def
restore_layernorm
(
model
:
"PreTrainedModel"
,
layernorm_params
:
Optional
[
Dict
[
str
,
"
torch.Tensor
"
]]
=
None
)
->
None
:
r
"""
r
"""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
"""
...
...
src/llamafactory/train/ppo/trainer.py
View file @
27a7ad86
...
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
...
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from
trl
import
PPOConfig
,
PPOTrainer
from
trl
import
PPOConfig
,
PPOTrainer
from
trl.core
import
PPODecorators
,
logprobs_from_logits
from
trl.core
import
PPODecorators
,
logprobs_from_logits
from
trl.models.utils
import
unwrap_model_for_generation
from
trl.models.utils
import
unwrap_model_for_generation
from
typing_extensions
import
override
from
...extras.logging
import
get_logger
from
...extras.logging
import
get_logger
from
...extras.misc
import
AverageMeter
,
count_parameters
,
get_current_device
,
get_logits_processor
from
...extras.misc
import
AverageMeter
,
count_parameters
,
get_current_device
,
get_logits_processor
...
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self
.
callback_handler
.
on_train_end
(
self
.
args
,
self
.
state
,
self
.
control
)
self
.
callback_handler
.
on_train_end
(
self
.
args
,
self
.
state
,
self
.
control
)
@
override
def
create_optimizer
(
def
create_optimizer
(
self
,
self
,
model
:
"AutoModelForCausalLMWithValueHead"
,
model
:
"AutoModelForCausalLMWithValueHead"
,
...
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return
optimizer
return
optimizer
@
override
def
create_scheduler
(
def
create_scheduler
(
self
,
training_args
:
"Seq2SeqTrainingArguments"
,
num_training_steps
:
int
,
optimizer
:
"torch.optim.Optimizer"
self
,
training_args
:
"Seq2SeqTrainingArguments"
,
num_training_steps
:
int
,
optimizer
:
"torch.optim.Optimizer"
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
...
@@ -389,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -389,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
"""
"""
if
self
.
finetuning_args
.
reward_model_type
==
"api"
:
if
self
.
finetuning_args
.
reward_model_type
==
"api"
:
token_ids
=
[
torch
.
cat
((
q
,
r
),
dim
=-
1
).
tolist
()
for
q
,
r
in
zip
(
queries
,
responses
)]
token_ids
=
[
torch
.
cat
((
q
,
r
),
dim
=-
1
).
tolist
()
for
q
,
r
in
zip
(
queries
,
responses
)]
messages
=
self
.
tokenizer
.
batch_decode
(
token_ids
,
skip_special_tokens
=
Tru
e
)
messages
=
self
.
tokenizer
.
batch_decode
(
token_ids
,
skip_special_tokens
=
Fals
e
)
return
get_rewards_from_server
(
self
.
reward_model
,
messages
)
return
get_rewards_from_server
(
self
.
reward_model
,
messages
)
batch
:
Dict
[
str
,
"torch.Tensor"
]
=
self
.
prepare_model_inputs
(
queries
,
responses
)
batch
:
Dict
[
str
,
"torch.Tensor"
]
=
self
.
prepare_model_inputs
(
queries
,
responses
)
...
@@ -402,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -402,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model
=
self
.
reward_model
reward_model
=
self
.
reward_model
with
unwrap_model_for_generation
(
reward_model
,
self
.
accelerator
),
self
.
amp_context
:
# support bf16
with
unwrap_model_for_generation
(
reward_model
,
self
.
accelerator
),
self
.
amp_context
:
# support bf16
_
,
_
,
values
=
reward_model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
)
values
:
"torch.Tensor"
=
reward_model
(
**
batch
,
return_dict
=
True
,
use_cache
=
False
)
[
-
1
]
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
if
self
.
finetuning_args
.
reward_model_type
==
"lora"
:
replace_model
(
unwrapped_model
,
target
=
"default"
)
replace_model
(
unwrapped_model
,
target
=
"default"
)
...
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards
=
values
.
gather
(
dim
=-
1
,
index
=
(
batch
[
"attention_mask"
].
sum
(
dim
=-
1
,
keepdim
=
True
)
-
1
))
rewards
=
values
.
gather
(
dim
=-
1
,
index
=
(
batch
[
"attention_mask"
].
sum
(
dim
=-
1
,
keepdim
=
True
)
-
1
))
return
rewards
.
float
().
detach
()
# use fp32 type
return
rewards
.
float
().
detach
()
# use fp32 type
@
override
@
PPODecorators
.
empty_device_cache
()
@
PPODecorators
.
empty_device_cache
()
def
batched_forward_pass
(
def
batched_forward_pass
(
self
,
self
,
...
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
...
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
torch
.
cat
(
all_masks
)[:,
:
-
1
],
torch
.
cat
(
all_masks
)[:,
:
-
1
],
)
)
@
override
def
save_model
(
self
,
output_dir
:
Optional
[
str
]
=
None
)
->
None
:
def
save_model
(
self
,
output_dir
:
Optional
[
str
]
=
None
)
->
None
:
r
"""
r
"""
Saves model checkpoint.
Saves model checkpoint.
...
...
src/llamafactory/train/ppo/workflow.py
View file @
27a7ad86
...
@@ -17,9 +17,7 @@
...
@@ -17,9 +17,7 @@
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
transformers
import
DataCollatorWithPadding
from
...data
import
MultiModalDataCollatorForSeq2Seq
,
get_dataset
,
get_template_and_fix_tokenizer
from
...data
import
get_dataset
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
from
...model
import
load_model
,
load_tokenizer
from
...model
import
load_model
,
load_tokenizer
from
..callbacks
import
fix_valuehead_checkpoint
from
..callbacks
import
fix_valuehead_checkpoint
...
@@ -43,11 +41,12 @@ def run_ppo(
...
@@ -43,11 +41,12 @@ def run_ppo(
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
dataset_module
=
get_dataset
(
model_args
,
data_args
,
training_args
,
stage
=
"ppo"
,
**
tokenizer_module
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
dataset_module
=
get_dataset
(
template
,
model_args
,
data_args
,
training_args
,
stage
=
"ppo"
,
**
tokenizer_module
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
,
add_valuehead
=
True
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
,
add_valuehead
=
True
)
tokenizer
.
padding_side
=
"left"
# use left-padding in generation while using right-padding in training
tokenizer
.
padding_side
=
"left"
# use left-padding in generation while using right-padding in training
data_collator
=
DataCollator
WithPadding
(
tokenizer
=
tokenizer
)
data_collator
=
MultiModal
DataCollator
ForSeq2Seq
(
template
=
template
,
**
tokenizer
_module
)
# Create reference model and reward model
# Create reference model and reward model
ref_model
=
create_ref_model
(
model_args
,
finetuning_args
,
add_valuehead
=
True
)
ref_model
=
create_ref_model
(
model_args
,
finetuning_args
,
add_valuehead
=
True
)
...
...
src/llamafactory/train/pt/trainer.py
View file @
27a7ad86
...
@@ -16,6 +16,7 @@ from types import MethodType
...
@@ -16,6 +16,7 @@ from types import MethodType
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
from
transformers
import
Trainer
from
transformers
import
Trainer
from
typing_extensions
import
override
from
...extras.logging
import
get_logger
from
...extras.logging
import
get_logger
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
PissaConvertCallback
,
SaveProcessorCallback
...
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
...
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
@
override
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
if
self
.
optimizer
is
None
:
if
self
.
optimizer
is
None
:
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
return
super
().
create_optimizer
()
return
super
().
create_optimizer
()
@
override
def
create_scheduler
(
def
create_scheduler
(
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
...
...
src/llamafactory/train/pt/workflow.py
View file @
27a7ad86
...
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
...
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
from
transformers
import
DataCollatorForLanguageModeling
from
transformers
import
DataCollatorForLanguageModeling
from
...data
import
get_dataset
from
...data
import
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
from
...model
import
load_model
,
load_tokenizer
from
...model
import
load_model
,
load_tokenizer
from
..trainer_utils
import
create_modelcard_and_push
from
..trainer_utils
import
create_modelcard_and_push
...
@@ -42,7 +42,8 @@ def run_pt(
...
@@ -42,7 +42,8 @@ def run_pt(
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
dataset_module
=
get_dataset
(
model_args
,
data_args
,
training_args
,
stage
=
"pt"
,
**
tokenizer_module
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
dataset_module
=
get_dataset
(
template
,
model_args
,
data_args
,
training_args
,
stage
=
"pt"
,
**
tokenizer_module
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
)
data_collator
=
DataCollatorForLanguageModeling
(
tokenizer
=
tokenizer
,
mlm
=
False
)
data_collator
=
DataCollatorForLanguageModeling
(
tokenizer
=
tokenizer
,
mlm
=
False
)
...
...
src/llamafactory/train/rm/metric.py
View file @
27a7ad86
...
@@ -26,6 +26,10 @@ if TYPE_CHECKING:
...
@@ -26,6 +26,10 @@ if TYPE_CHECKING:
@
dataclass
@
dataclass
class
ComputeAccuracy
:
class
ComputeAccuracy
:
r
"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
def
_dump
(
self
)
->
Optional
[
Dict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
Dict
[
str
,
float
]]:
result
=
None
result
=
None
if
hasattr
(
self
,
"score_dict"
):
if
hasattr
(
self
,
"score_dict"
):
...
...
src/llamafactory/train/rm/trainer.py
View file @
27a7ad86
...
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
...
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import
torch
import
torch
from
transformers
import
Trainer
from
transformers
import
Trainer
from
typing_extensions
import
override
from
...extras.logging
import
get_logger
from
...extras.logging
import
get_logger
from
..callbacks
import
FixValueHeadModelCallback
,
PissaConvertCallback
,
SaveProcessorCallback
from
..callbacks
import
FixValueHeadModelCallback
,
PissaConvertCallback
,
SaveProcessorCallback
...
@@ -63,20 +64,23 @@ class PairwiseTrainer(Trainer):
...
@@ -63,20 +64,23 @@ class PairwiseTrainer(Trainer):
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
@
override
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
if
self
.
optimizer
is
None
:
if
self
.
optimizer
is
None
:
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
return
super
().
create_optimizer
()
return
super
().
create_optimizer
()
@
override
def
create_scheduler
(
def
create_scheduler
(
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
compute_loss
(
def
compute_loss
(
self
,
model
:
"PreTrainedModel"
,
inputs
:
Dict
[
str
,
torch
.
Tensor
],
return_outputs
:
bool
=
False
self
,
model
:
"PreTrainedModel"
,
inputs
:
Dict
[
str
,
"
torch.Tensor
"
],
return_outputs
:
bool
=
False
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]]:
)
->
Union
[
"
torch.Tensor
"
,
Tuple
[
"
torch.Tensor
"
,
List
[
"
torch.Tensor
"
]]]:
r
"""
r
"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
...
...
src/llamafactory/train/rm/workflow.py
View file @
27a7ad86
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
from
...extras.ploting
import
plot_loss
from
...model
import
load_model
,
load_tokenizer
from
...model
import
load_model
,
load_tokenizer
from
..callbacks
import
fix_valuehead_checkpoint
from
..callbacks
import
fix_valuehead_checkpoint
...
@@ -41,12 +41,13 @@ def run_rm(
...
@@ -41,12 +41,13 @@ def run_rm(
):
):
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer_module
=
load_tokenizer
(
model_args
)
tokenizer
=
tokenizer_module
[
"tokenizer"
]
tokenizer
=
tokenizer_module
[
"tokenizer"
]
dataset_module
=
get_dataset
(
model_args
,
data_args
,
training_args
,
stage
=
"rm"
,
**
tokenizer_module
)
template
=
get_template_and_fix_tokenizer
(
tokenizer
,
data_args
)
dataset_module
=
get_dataset
(
template
,
model_args
,
data_args
,
training_args
,
stage
=
"rm"
,
**
tokenizer_module
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
,
add_valuehead
=
True
)
model
=
load_model
(
tokenizer
,
model_args
,
finetuning_args
,
training_args
.
do_train
,
add_valuehead
=
True
)
data_collator
=
PairwiseDataCollatorWithPadding
(
t
okenizer
,
pad_to_multiple_of
=
8
)
data_collator
=
PairwiseDataCollatorWithPadding
(
t
emplate
=
template
,
pad_to_multiple_of
=
8
,
**
tokenizer_module
)
# Update arguments
# Update arguments
training_args
.
remove_unused_columns
=
False
# important for pairwise dataset
training_args
.
remove_unused_columns
=
False
# important for
multimodal and
pairwise dataset
# Initialize our Trainer
# Initialize our Trainer
trainer
=
PairwiseTrainer
(
trainer
=
PairwiseTrainer
(
...
...
src/llamafactory/train/sft/metric.py
View file @
27a7ad86
...
@@ -45,6 +45,9 @@ if is_rouge_available():
...
@@ -45,6 +45,9 @@ if is_rouge_available():
def
eval_logit_processor
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
)
->
"torch.Tensor"
:
def
eval_logit_processor
(
logits
:
"torch.Tensor"
,
labels
:
"torch.Tensor"
)
->
"torch.Tensor"
:
r
"""
Computes the token with the largest likelihood to reduce memory footprint.
"""
if
isinstance
(
logits
,
(
list
,
tuple
)):
if
isinstance
(
logits
,
(
list
,
tuple
)):
if
logits
[
0
].
dim
()
==
3
:
# (batch_size, seq_len, vocab_size)
if
logits
[
0
].
dim
()
==
3
:
# (batch_size, seq_len, vocab_size)
logits
=
logits
[
0
]
logits
=
logits
[
0
]
...
@@ -59,6 +62,10 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
...
@@ -59,6 +62,10 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@
dataclass
@
dataclass
class
ComputeAccuracy
:
class
ComputeAccuracy
:
r
"""
Computes accuracy and supports `batch_eval_metrics`.
"""
def
_dump
(
self
)
->
Optional
[
Dict
[
str
,
float
]]:
def
_dump
(
self
)
->
Optional
[
Dict
[
str
,
float
]]:
result
=
None
result
=
None
if
hasattr
(
self
,
"score_dict"
):
if
hasattr
(
self
,
"score_dict"
):
...
@@ -84,6 +91,8 @@ class ComputeAccuracy:
...
@@ -84,6 +91,8 @@ class ComputeAccuracy:
@
dataclass
@
dataclass
class
ComputeSimilarity
:
class
ComputeSimilarity
:
r
"""
r
"""
Computes text similarity scores and supports `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
"""
...
...
src/llamafactory/train/sft/trainer.py
View file @
27a7ad86
...
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
...
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
transformers
import
Seq2SeqTrainer
from
transformers
import
Seq2SeqTrainer
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.constants
import
IGNORE_INDEX
from
...extras.logging
import
get_logger
from
...extras.logging
import
get_logger
...
@@ -64,32 +65,36 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -64,32 +65,36 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
accelerator
.
clip_grad_norm_
=
MethodType
(
clip_grad_norm_old_version
,
self
.
accelerator
)
self
.
add_callback
(
BAdamCallback
)
self
.
add_callback
(
BAdamCallback
)
@
override
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
if
self
.
optimizer
is
None
:
if
self
.
optimizer
is
None
:
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
return
super
().
create_optimizer
()
return
super
().
create_optimizer
()
@
override
def
create_scheduler
(
def
create_scheduler
(
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
prediction_step
(
def
prediction_step
(
self
,
self
,
model
:
"torch.nn.Module"
,
model
:
"torch.nn.Module"
,
inputs
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]],
inputs
:
Dict
[
str
,
Union
[
"
torch.Tensor
"
,
Any
]],
prediction_loss_only
:
bool
,
prediction_loss_only
:
bool
,
ignore_keys
:
Optional
[
List
[
str
]]
=
None
,
ignore_keys
:
Optional
[
List
[
str
]]
=
None
,
)
->
Tuple
[
Optional
[
float
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
Optional
[
float
],
Optional
[
"
torch.Tensor
"
],
Optional
[
"
torch.Tensor
"
]]:
r
"""
r
"""
Removes the prompt part in the generated tokens.
Removes the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
Subclass and override to inject custom behavior.
"""
"""
labels
=
inputs
[
"labels"
]
.
detach
().
clone
()
if
"labels"
in
inputs
else
None
# backup labels
labels
=
inputs
[
"labels"
]
if
"labels"
in
inputs
else
None
if
self
.
args
.
predict_with_generate
:
if
self
.
args
.
predict_with_generate
:
assert
self
.
tokenizer
.
padding_side
==
"left"
,
"This method only accepts left-padded tensor."
assert
self
.
tokenizer
.
padding_side
==
"left"
,
"This method only accepts left-padded tensor."
labels
=
labels
.
detach
().
clone
()
if
labels
is
not
None
else
None
# backup labels
prompt_len
,
label_len
=
inputs
[
"input_ids"
].
size
(
-
1
),
inputs
[
"labels"
].
size
(
-
1
)
prompt_len
,
label_len
=
inputs
[
"input_ids"
].
size
(
-
1
),
inputs
[
"labels"
].
size
(
-
1
)
if
prompt_len
>
label_len
:
if
prompt_len
>
label_len
:
inputs
[
"labels"
]
=
self
.
_pad_tensors_to_target_len
(
inputs
[
"labels"
],
inputs
[
"input_ids"
])
inputs
[
"labels"
]
=
self
.
_pad_tensors_to_target_len
(
inputs
[
"labels"
],
inputs
[
"input_ids"
])
...
@@ -105,7 +110,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...
@@ -105,7 +110,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return
loss
,
generated_tokens
,
labels
return
loss
,
generated_tokens
,
labels
def
_pad_tensors_to_target_len
(
self
,
src_tensor
:
torch
.
Tensor
,
tgt_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_pad_tensors_to_target_len
(
self
,
src_tensor
:
"
torch.Tensor
"
,
tgt_tensor
:
"
torch.Tensor
"
)
->
"
torch.Tensor
"
:
r
"""
r
"""
Pads the tensor to the same length as the target tensor.
Pads the tensor to the same length as the target tensor.
"""
"""
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment