OpenDAS / LLaMA-Factory

Commit 27a7ad86, authored Oct 14, 2024 by luopl
update to v0.9.1
Parent: 731cf9b8 · Changes: 120

Showing 20 changed files with 289 additions and 114 deletions.
src/llamafactory/model/model_utils/misc.py      +10  −5
src/llamafactory/model/model_utils/moe.py       +11  −9
src/llamafactory/model/model_utils/packing.py   +1   −1
src/llamafactory/model/model_utils/visual.py    +104 −8
src/llamafactory/model/patcher.py               +28  −5
src/llamafactory/train/callbacks.py             +51  −52
src/llamafactory/train/dpo/trainer.py           +6   −1
src/llamafactory/train/dpo/workflow.py          +6   −4
src/llamafactory/train/kto/trainer.py           +13  −3
src/llamafactory/train/kto/workflow.py          +6   −4
src/llamafactory/train/ppo/ppo_utils.py         +3   −3
src/llamafactory/train/ppo/trainer.py           +7   −2
src/llamafactory/train/ppo/workflow.py          +4   −5
src/llamafactory/train/pt/trainer.py            +3   −0
src/llamafactory/train/pt/workflow.py           +3   −2
src/llamafactory/train/rm/metric.py             +4   −0
src/llamafactory/train/rm/trainer.py            +6   −2
src/llamafactory/train/rm/workflow.py           +5   −4
src/llamafactory/train/sft/metric.py            +9   −0
src/llamafactory/train/sft/trainer.py           +9   −4
src/llamafactory/model/model_utils/misc.py

@@ -28,17 +28,22 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     r"""
     Finds all available modules to apply lora or galore.
     """
+    model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
-    if model.config.model_type == "chatglm":
+    if model_type == "chatglm":
         forbidden_modules.add("output_layer")
-    elif model.config.model_type == "internlm2":
+    elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model.config.model_type in ["llava", "paligemma"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
+    elif model_type == "qwen2_vl":
+        forbidden_modules.add("merger")

     if freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
+        if model_type == "qwen2_vl":
+            forbidden_modules.add("visual")
+        else:
+            forbidden_modules.add("vision_tower")

     module_names = set()
     for name, module in model.named_modules():
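A minimal sketch (with hypothetical module names) of how the forbidden_modules set is consumed further down in find_all_linear_modules: any named module whose path contains a forbidden component is skipped when collecting LoRA/GaLore target layers, so Qwen2-VL's merger and vision stack stay untouched.

# hypothetical module paths; the real ones come from model.named_modules()
forbidden_modules = {"lm_head", "merger", "visual"}
named_modules = [
    "model.layers.0.self_attn.q_proj",  # language-model linear -> kept
    "visual.merger.mlp.0",              # vision merger         -> skipped
    "lm_head",                          # output head           -> skipped
]
targets = [
    name.split(".")[-1]
    for name in named_modules
    if not any(forbidden in name for forbidden in forbidden_modules)
]
print(targets)  # ['q_proj']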
src/llamafactory/model/model_utils/moe.py

@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return

-    if getattr(model.config, "model_type", None) == "dbrx":
+    model_type = getattr(model.config, "model_type", None)
+    if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN

         _set_z3_leaf_modules(model, [DbrxFFN])

-    if getattr(model.config, "model_type", None) == "jamba":
+    if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

         _set_z3_leaf_modules(model, [JambaSparseMoeBlock])

-    if getattr(model.config, "model_type", None) == "jetmoe":
+    if model_type == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

-    if getattr(model.config, "model_type", None) == "mixtral":
+    if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

-    if getattr(model.config, "model_type", None) == "qwen2moe":
+    if model_type == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)

-        elif getattr(config, "model_type", None) == "deepseek":
+        elif model_type == "deepseek":
             setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

-        elif getattr(config, "model_type", None) == "jetmoe":
+        elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)

-    if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
+    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
         setattr(config, "output_router_logits", is_trainable)
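A small sketch of configure_moe's effect on a Mixtral-style config, using a plain namespace in place of a real PretrainedConfig; the attribute names are the ones the diff above sets.

from types import SimpleNamespace

config = SimpleNamespace(model_type="mixtral")  # stand-in for a PretrainedConfig
moe_aux_loss_coef, is_trainable = 0.01, True

model_type = getattr(config, "model_type", None)  # fetched once, as in the refactor
if moe_aux_loss_coef is not None and model_type in ["jamba", "mixtral", "qwen2_moe"]:
    setattr(config, "router_aux_loss_coef", moe_aux_loss_coef)

if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
    setattr(config, "output_router_logits", is_trainable)

print(config.router_aux_loss_coef, config.output_router_logits)  # 0.01 True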
src/llamafactory/model/model_utils/packing.py

@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils
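The widened pin is enforced at runtime. A sketch of the check, assuming the helper is the one shipped in transformers.utils.versions (which the repo imports):

from transformers.utils.versions import require_version

# raises with the hint message if the installed transformers is outside the range
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")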
src/llamafactory/model/model_utils/visual.py

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
 import transformers.models

@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)

@@ -80,24 +80,120 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
         self.act = ACT2FN[projector_hidden_act]


-def autocast_projector_dtype(
-    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
-) -> None:
+def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
+    r"""
+    Casts projector output to half precision for fine-tuning quantized VLMs.
+    """
+
     def _mm_projector_forward_post_hook(
         module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)

-    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
+    if getattr(model, "quantization_method", None):
+        model_type = getattr(model.config, "model_type", None)
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
+        elif model_type == "qwen2_vl":
+            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+        else:
+            return
+
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
-        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


 def configure_visual_model(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
+    r"""
+    Patches VLMs before loading them.
+    """
+    model_type = getattr(config, "model_type", None)
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
+    r"""
+    Freezes vision tower and language model for VLM full/freeze tuning.
+    """
+    model_type = getattr(config, "model_type", None)
+    forbidden_modules = set()
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
+        if finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
+    elif model_type == "qwen2_vl":
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("visual")
+
+        if finetuning_args.train_mm_proj_only:
+            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
+
+    return forbidden_modules
+
+
+def get_image_seqlen(config: "PretrainedConfig") -> int:
+    r"""
+    Computes the number of special tokens per image.
+    """
+    model_type = getattr(config, "model_type", None)
+    if model_type == "llava":
+        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
+        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
+            image_seqlen += 1
+    elif model_type == "paligemma":
+        image_seqlen = config.vision_config.num_image_tokens
+    else:
+        image_seqlen = -1
+
+    return image_seqlen
+
+
+def get_patch_size(config: "PretrainedConfig") -> int:
+    r"""
+    Computes the patch size of the vit.
+    """
+    patch_size = getattr(config.vision_config, "patch_size", -1)
+    return patch_size
+
+
+def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+    r"""
+    Get the vision_feature_select_strategy.
+    """
+    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    return vision_feature_select_strategy
+
+
+def patch_target_modules(
+    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> Union[str, List[str]]:
+    r"""
+    Freezes vision tower for VLM LoRA tuning.
+    """
+    model_type = getattr(config, "model_type", None)
+    if finetuning_args.freeze_vision_tower:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+            return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "qwen2_vl":
+            return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+        else:
+            return target_modules
+    else:
+        if model_type == "qwen2_vl":
+            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        else:
+            return target_modules
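A self-contained demo of the negative-lookahead patterns returned by patch_target_modules: PEFT matches a string target_modules against module paths as a regex (via re.fullmatch), so paths containing "visual" never become LoRA targets even when they contain a target keyword.

import re

pattern = "^(?!.*visual).*(?:{}).*".format("|".join(["q_proj", "v_proj"]))
for name in ["model.layers.0.self_attn.q_proj", "visual.blocks.0.attn.q_proj"]:
    print(name, "->", bool(re.fullmatch(pattern, name)))
# model.layers.0.self_attn.q_proj -> True   (trainable)
# visual.blocks.0.attn.q_proj -> False      (frozen vision tower)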
src/llamafactory/model/patcher.py

@@ -33,11 +33,17 @@ from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import autocast_projector_dtype, configure_visual_model
+from .model_utils.visual import (
+    autocast_projector_dtype,
+    configure_visual_model,
+    get_image_seqlen,
+    get_patch_size,
+    get_vision_feature_select_strategy,
+)


 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
     from trl import AutoModelForCausalLMWithValueHead

     from ..hparams import ModelArguments

@@ -51,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)


+def patch_processor(
+    processor: "ProcessorMixin",
+    config: "PretrainedConfig",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+) -> None:
+    setattr(processor, "tokenizer", tokenizer)
+    setattr(processor, "image_seqlen", get_image_seqlen(config))
+    setattr(processor, "image_resolution", model_args.image_resolution)
+    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "video_resolution", model_args.video_resolution)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+
+
 def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",

@@ -88,6 +110,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

+    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
+        raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

@@ -129,11 +154,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)

-    if model_args.visual_inputs:
-        autocast_projector_dtype(model, model_args)
-
     if is_trainable:
         prepare_model_for_training(model, model_args)
+        autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)

     if not model_args.use_unsloth:
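The autocast_projector_dtype call that moves into the is_trainable branch installs a forward hook on the projector. A minimal, runnable demo of that hook pattern, with a tiny nn.Linear standing in for the multimodal projector of a quantized VLM:

import torch

projector = torch.nn.Linear(4, 4)  # stand-in for the mm projector

def cast_output_hook(module: torch.nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    # same idea as casting to model_args.compute_dtype in the real hook
    return output.to(torch.float16)

projector.register_forward_hook(cast_output_hook)
print(projector(torch.randn(2, 4)).dtype)  # torch.float16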
src/llamafactory/train/callbacks.py

@@ -32,9 +32,11 @@ from transformers.utils import (
     WEIGHTS_NAME,
     is_safetensors_available,
 )
+from typing_extensions import override

 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
+from ..extras.misc import get_peak_memory


 if is_safetensors_available():

@@ -73,8 +75,8 @@ def fix_valuehead_checkpoint(
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
         state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

-    decoder_state_dict = {}
-    v_head_state_dict = {}
+    os.remove(path_to_checkpoint)
+    decoder_state_dict, v_head_state_dict = {}, {}
     for name, param in state_dict.items():
         if name.startswith("v_head."):
             v_head_state_dict[name] = param

@@ -90,43 +92,52 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

-    os.remove(path_to_checkpoint)
     logger.info("Value head model saved at: {}".format(output_dir))


 class FixValueHeadModelCallback(TrainerCallback):
+    r"""
+    A callback for fixing the checkpoint for valuehead models.
+    """
+
+    @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a checkpoint save.
-        """
         if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
             fix_valuehead_checkpoint(
-                model=kwargs.pop("model"),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
-                safe_serialization=args.save_safetensors,
+                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
             )


 class SaveProcessorCallback(TrainerCallback):
+    r"""
+    A callback for saving the processor.
+    """
+
     def __init__(self, processor: "ProcessorMixin") -> None:
-        r"""
-        Initializes a callback for saving the processor.
-        """
         self.processor = processor

+    @override
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         if args.should_save:
             getattr(self.processor, "image_processor").save_pretrained(args.output_dir)


 class PissaConvertCallback(TrainerCallback):
     r"""
-    Initializes a callback for converting the PiSSA adapter to a normal one.
+    A callback for converting the PiSSA adapter to a normal one.
     """

+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.

@@ -141,10 +152,8 @@ class PissaConvertCallback(TrainerCallback):
                 model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")

@@ -172,21 +181,22 @@ class PissaConvertCallback(TrainerCallback):
 class LogCallback(TrainerCallback):
+    r"""
+    A callback for logging training and evaluation status.
+    """
+
     def __init__(self) -> None:
-        r"""
-        Initializes a callback for logging training and evaluation status.
-        """
-        """ Progress """
+        # Progress
         self.start_time = 0
         self.cur_steps = 0
         self.max_steps = 0
         self.elapsed_time = ""
         self.remaining_time = ""
         self.thread_pool: Optional["ThreadPoolExecutor"] = None
-        """ Status """
+        # Status
         self.aborted = False
         self.do_train = False
-        """ Web UI """
+        # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)

@@ -226,10 +236,8 @@ class LogCallback(TrainerCallback):
             self.thread_pool.shutdown(wait=True)
             self.thread_pool = None

+    @override
     def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of the initialization of the `Trainer`.
-        """
         if (
             args.should_save
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))

@@ -238,55 +246,41 @@ class LogCallback(TrainerCallback):
             logger.warning("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             self.do_train = True
             self._reset(max_steps=state.max_steps)
             self._create_thread_pool(output_dir=args.output_dir)

+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         self._close_thread_pool()

+    @override
     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of an substep during gradient accumulation.
-        """
         if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True

+    @override
     def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of a training step.
-        """
         if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True

+    @override
     def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after an evaluation phase.
-        """
         if not self.do_train:
             self._close_thread_pool()

+    @override
     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a successful prediction.
-        """
         if not self.do_train:
             self._close_thread_pool()

+    @override
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after logging the last logs.
-        """
         if not args.should_save:
             return

@@ -304,26 +298,31 @@ class LogCallback(TrainerCallback):
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
-            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
-            total_tokens=state.num_input_tokens_seen,
         )
+        if state.num_input_tokens_seen:
+            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
+            logs["total_tokens"] = state.num_input_tokens_seen
+
+        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
+            vram_allocated, vram_reserved = get_peak_memory()
+            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
+            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+
         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
+                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
                 )
            )

         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)

+    @override
     def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a prediction step.
-        """
         if self.do_train:
             return
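A sketch of the two new optional log fields, assuming a CUDA build of torch; the repo's get_peak_memory helper is emulated here with torch.cuda's peak-memory counters, and the token counter is a hypothetical TrainerState value.

import os
import time

import torch

start_time = time.time() - 10.0       # pretend training ran for 10 s
num_input_tokens_seen = 1_048_576     # hypothetical counter from TrainerState
logs = {}
if num_input_tokens_seen:
    logs["throughput"] = round(num_input_tokens_seen / (time.time() - start_time), 2)
    logs["total_tokens"] = num_input_tokens_seen

if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"] and torch.cuda.is_available():
    vram_allocated = torch.cuda.max_memory_allocated()
    vram_reserved = torch.cuda.max_memory_reserved()
    logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)  # GiB
    logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)

print(logs)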
src/llamafactory/train/dpo/trainer.py

@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback

@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":

@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
         return losses, chosen_rewards, rejected_rewards

+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:

@@ -176,7 +180,6 @@ class CustomDPOTrainer(DPOTrainer):
             batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error

         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length

@@ -187,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_length, _ = valid_length.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:

@@ -208,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
         return reference_chosen_logps, reference_rejected_logps

+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
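A self-contained demo of typing_extensions.override, which this commit adds to every overridden trainer method across the training stages: it is a no-op at runtime but lets static checkers flag methods that silently stop overriding anything after a trl or transformers upgrade.

from typing_extensions import override

class Base:
    def create_optimizer(self) -> str:
        return "base"

class Custom(Base):
    @override
    def create_optimizer(self) -> str:  # verified against Base by type checkers
        return "custom"

print(Custom().create_optimizer())  # custom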
src/llamafactory/train/dpo/workflow.py

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional

-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments

@@ -41,13 +41,15 @@ def run_dpo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     data_collator = PairwiseDataCollatorWithPadding(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        **tokenizer_module,
     )

     # Create reference model

@@ -60,7 +62,7 @@ def run_dpo(
         ref_model = None

     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

     # Initialize our Trainer
     trainer = CustomDPOTrainer(
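The same template-first refactor recurs in every workflow below (kto, ppo, pt, rm): the chat template is resolved once with get_template_and_fix_tokenizer and then passed to both get_dataset and the data collator, which is what lets the collators handle multimodal inputs in addition to pairwise text data.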
src/llamafactory/train/kto/trainer.py

@@ -25,6 +25,7 @@ import torch
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback

@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
         return Trainer._get_train_sampler(self)

+    @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:

@@ -127,17 +132,20 @@ class CustomKTOTrainer(KTOTrainer):
             "input_ids": batch["{}input_ids".format(prefix)],
             "attention_mask": batch["{}attention_mask".format(prefix)],
         }
-        if "{}token_type_ids".format(prefix) in batch:
-            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
-
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]

+        if "{}token_type_ids".format(prefix) in batch:
+            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+
+        if "image_grid_thw" in batch:
+            model_inputs["image_grid_thw"] = batch["image_grid_thw"]
+
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)

         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length

+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:

@@ -153,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
         return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg

+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:

@@ -173,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps

+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
src/llamafactory/train/kto/workflow.py

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional

-from ...data import KTODataCollatorWithPadding, get_dataset
+from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments

@@ -41,13 +41,15 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     data_collator = KTODataCollatorWithPadding(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        **tokenizer_module,
     )

     # Create reference model

@@ -57,7 +59,7 @@ def run_kto(
         ref_model = create_ref_model(model_args, finetuning_args)

     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

     # Initialize our Trainer
     trainer = CustomKTOTrainer(
src/llamafactory/train/ppo/ppo_utils.py

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead


-def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
     r"""
     Gets reward scores from the API server.
     """

@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
         v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)


-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     r"""
     Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """

@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     return layer_norm_params


-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
     r"""
     Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """
src/llamafactory/train/ppo/trainer.py

@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor

@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.callback_handler.on_train_end(self.args, self.state, self.control)

+    @override
     def create_optimizer(
         self,
         model: "AutoModelForCausalLMWithValueHead",

@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         return optimizer

+    @override
     def create_scheduler(
         self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
     ) -> "torch.optim.lr_scheduler.LRScheduler":

@@ -389,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         """
         if self.finetuning_args.reward_model_type == "api":
             token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
-            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
             return get_rewards_from_server(self.reward_model, messages)

         batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)

@@ -402,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model

         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)
+            values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]

         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")

@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type

+    @override
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,

@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             torch.cat(all_masks)[:, :-1],
         )

+    @override
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
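Two behavioral changes are worth noting here beyond the @override additions: batch_decode now keeps special tokens (skip_special_tokens=False), presumably so the remote reward server receives the full chat template rather than stripped text, and the reward forward pass now takes the value-head output via [-1] instead of three-way tuple unpacking, which tolerates valuehead models that return a different number of outputs.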
src/llamafactory/train/ppo/workflow.py

@@ -17,9 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional

-from transformers import DataCollatorWithPadding
-
-from ...data import get_dataset
+from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint

@@ -43,11 +41,12 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)

     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
src/llamafactory/train/pt/trainer.py

@@ -16,6 +16,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Optional

 from transformers import Trainer
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback

@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
src/llamafactory/train/pt/workflow.py

@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
 from transformers import DataCollatorForLanguageModeling

-from ...data import get_dataset
+from ...data import get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push

@@ -42,7 +42,8 @@ def run_pt(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
src/llamafactory/train/rm/metric.py

@@ -26,6 +26,10 @@ if TYPE_CHECKING:

 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes reward accuracy and supports `batch_eval_metrics`.
+    """
+
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
src/llamafactory/train/rm/trainer.py

@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 import torch
 from transformers import Trainer
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback

@@ -63,20 +64,23 @@ class PairwiseTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
src/llamafactory/train/rm/workflow.py

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, List, Optional

-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint

@@ -41,12 +41,13 @@ def run_rm(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)

     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

     # Initialize our Trainer
     trainer = PairwiseTrainer(
src/llamafactory/train/sft/metric.py

@@ -45,6 +45,9 @@ if is_rouge_available():

 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Computes the token with the largest likelihood to reduce memory footprint.
+    """
     if isinstance(logits, (list, tuple)):
         if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
             logits = logits[0]

@@ -59,6 +62,10 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes accuracy and supports `batch_eval_metrics`.
+    """
+
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):

@@ -84,6 +91,8 @@ class ComputeAccuracy:
 @dataclass
 class ComputeSimilarity:
     r"""
+    Computes text similarity scores and supports `batch_eval_metrics`.
+
     Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
     """
src/llamafactory/train/sft/trainer.py

@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from transformers import Seq2SeqTrainer
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger

@@ -64,32 +65,36 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def prediction_step(
         self,
         model: "torch.nn.Module",
-        inputs: Dict[str, Union[torch.Tensor, Any]],
+        inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.

         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"].detach().clone() if "labels" in inputs else None  # backup labels
+        labels = inputs["labels"] if "labels" in inputs else None
         if self.args.predict_with_generate:
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            labels = labels.detach().clone() if labels is not None else None  # backup labels
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
                 inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])

@@ -105,7 +110,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
         r"""
         Pads the tensor to the same length as the target tensor.
         """
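Note the labels backup in prediction_step also moved: detach().clone() now runs only under predict_with_generate, so the plain (non-generative) evaluation path no longer copies the labels tensor on every step.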