Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
Qwen2.5-VL-7B-LlamaFactory
Commits
b59a5620
Commit
b59a5620
authored
Feb 06, 2026
by
litzh
Browse files
Initial commit
parents
Pipeline
#3383
canceled with stages
Changes
280
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2592 additions
and
0 deletions
+2592
-0
src/llamafactory/model/patcher.py
src/llamafactory/model/patcher.py
+219
-0
src/llamafactory/third_party/__init__.py
src/llamafactory/third_party/__init__.py
+0
-0
src/llamafactory/third_party/muon/__init__.py
src/llamafactory/third_party/muon/__init__.py
+18
-0
src/llamafactory/third_party/muon/muon.py
src/llamafactory/third_party/muon/muon.py
+226
-0
src/llamafactory/train/__init__.py
src/llamafactory/train/__init__.py
+0
-0
src/llamafactory/train/callbacks.py
src/llamafactory/train/callbacks.py
+385
-0
src/llamafactory/train/dpo/__init__.py
src/llamafactory/train/dpo/__init__.py
+18
-0
src/llamafactory/train/dpo/trainer.py
src/llamafactory/train/dpo/trainer.py
+302
-0
src/llamafactory/train/dpo/workflow.py
src/llamafactory/train/dpo/workflow.py
+110
-0
src/llamafactory/train/kto/__init__.py
src/llamafactory/train/kto/__init__.py
+18
-0
src/llamafactory/train/kto/trainer.py
src/llamafactory/train/kto/trainer.py
+297
-0
src/llamafactory/train/kto/workflow.py
src/llamafactory/train/kto/workflow.py
+101
-0
src/llamafactory/train/ppo/__init__.py
src/llamafactory/train/ppo/__init__.py
+18
-0
src/llamafactory/train/ppo/ppo_utils.py
src/llamafactory/train/ppo/ppo_utils.py
+80
-0
src/llamafactory/train/ppo/trainer.py
src/llamafactory/train/ppo/trainer.py
+503
-0
src/llamafactory/train/ppo/workflow.py
src/llamafactory/train/ppo/workflow.py
+79
-0
src/llamafactory/train/pt/__init__.py
src/llamafactory/train/pt/__init__.py
+18
-0
src/llamafactory/train/pt/trainer.py
src/llamafactory/train/pt/trainer.py
+81
-0
src/llamafactory/train/pt/workflow.py
src/llamafactory/train/pt/workflow.py
+101
-0
src/llamafactory/train/rm/__init__.py
src/llamafactory/train/rm/__init__.py
+18
-0
No files found.
src/llamafactory/model/patcher.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
import
torch
from
peft
import
PeftModel
from
transformers
import
GenerationMixin
,
PreTrainedModel
,
PreTrainedTokenizerBase
from
transformers.integrations
import
is_deepspeed_zero3_enabled
from
transformers.modeling_utils
import
is_fsdp_enabled
from
..extras
import
logging
from
..extras.misc
import
infer_optim_dtype
from
..extras.packages
import
is_transformers_version_greater_than
from
.model_utils.attention
import
configure_attn_implementation
,
print_attn_implementation
from
.model_utils.checkpointing
import
prepare_model_for_training
from
.model_utils.embedding
import
resize_embedding_layer
from
.model_utils.kv_cache
import
configure_kv_cache
from
.model_utils.longlora
import
configure_longlora
from
.model_utils.moe
import
add_z3_leaf_module
,
configure_moe
from
.model_utils.packing
import
configure_packing
from
.model_utils.quantization
import
configure_quantization
from
.model_utils.rope
import
configure_rope
from
.model_utils.valuehead
import
prepare_valuehead_model
from
.model_utils.visual
import
autocast_projector_dtype
,
configure_visual_model
if
TYPE_CHECKING
:
from
transformers
import
PretrainedConfig
,
PreTrainedTokenizer
,
ProcessorMixin
from
trl
import
AutoModelForCausalLMWithValueHead
from
..hparams
import
ModelArguments
logger
=
logging
.
get_logger
(
__name__
)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
    """Normalize a tokenizer in place before training/inference.

    Restores the base `_pad` implementation when a subclass overrode it,
    enlarges `model_max_length` if the user requested a longer context, and
    registers any user-supplied regular / special tokens, flipping
    `model_args.resize_vocab` on when new tokens were actually added.
    """
    # Some tokenizer subclasses ship a custom `_pad`; fall back to the base one.
    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

    # Only ever grow the tokenizer's max length, never shrink it.
    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length

    if model_args.add_tokens is not None:
        num_added_tokens = tokenizer.add_tokens(new_tokens=model_args.add_tokens, special_tokens=False)
        logger.info_rank0("Add tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_tokens)))
        if num_added_tokens > 0 and not model_args.resize_vocab:
            # Embedding matrix must grow to cover the new token ids.
            model_args.resize_vocab = True
            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")

    if model_args.add_special_tokens is not None:
        num_added_special_tokens = tokenizer.add_tokens(
            new_tokens=model_args.add_special_tokens, special_tokens=True
        )
        logger.info_rank0(
            "Add special tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_special_tokens))
        )
        if num_added_special_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.")
def patch_processor(
    processor: "ProcessorMixin",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
) -> None:
    """Attach the tokenizer and copy multimodal preprocessing knobs onto the processor.

    Each attribute is mirrored verbatim from ``model_args`` so downstream
    image/video/audio plugins can read them straight off the processor.
    """
    setattr(processor, "tokenizer", tokenizer)
    # Mirror every multimodal setting one-to-one from the model arguments.
    for attr_name in (
        "image_max_pixels",
        "image_min_pixels",
        "image_do_pan_and_scan",
        "crop_to_patches",
        "video_max_pixels",
        "video_min_pixels",
        "video_fps",
        "video_maxlen",
        "use_audio_in_video",
        "audio_sampling_rate",
    ):
        setattr(processor, attr_name, getattr(model_args, attr_name))
def patch_config(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: dict[str, Any],
    is_trainable: bool,
) -> None:
    """Mutate the model config and the `from_pretrained` kwargs before loading.

    Resolves the compute dtype, runs all per-feature configurators
    (attention, rope, quantization, MoE, ...), applies model-specific
    workarounds, and rejects unsupported checkpoint layouts.
    """
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
        if model_args.infer_dtype != "auto" and not is_trainable:
            model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
        else:
            model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

    configure_attn_implementation(config, model_args)
    configure_rope(config, model_args)
    configure_longlora(config, model_args, is_trainable)
    configure_quantization(config, tokenizer, model_args, init_kwargs)
    configure_moe(config, model_args, is_trainable)
    configure_visual_model(config)
    configure_packing(model_args, is_trainable)
    configure_kv_cache(config, model_args, is_trainable)

    # Legacy qwen checkpoints read dtype flags from the config itself.
    if getattr(config, "model_type", None) == "qwen":
        setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
            setattr(config, dtype_name, model_args.compute_dtype == dtype)

    if getattr(config, "model_type", None) == "minicpmo":
        setattr(config, "init_audio", True)
        setattr(config, "init_tts", False)

    # replace the top-k gating method
    if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
        setattr(config.text_config, "topk_method", "greedy")

    if "InternVLChatModel" in getattr(config, "architectures", []):
        raise ValueError(
            "Please download the internvl models in a Hugging Face–compatible format "
            "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
        )

    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
        raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")

    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

    # deepspeed zero3 is not compatible with low_cpu_mem_usage
    init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())

    # do not cast data type of the model deepspeed zero3 without qlora
    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
        init_kwargs["torch_dtype"] = model_args.compute_dtype

    if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
        if "device_map" not in init_kwargs and model_args.device_map:
            init_kwargs["device_map"] = model_args.device_map  # device map requires low_cpu_mem_usage=True

    if init_kwargs.get("device_map", None) == "auto":
        init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    is_trainable: bool,
    add_valuehead: bool,
) -> None:
    """Apply in-place fixes to a freshly loaded model.

    Repairs an inconsistent generation config, restores the stock
    `generate`, wires up value-head / vocab-resize / training preparation
    as requested, and tags the model as produced by llama-factory.
    """
    gen_config = model.generation_config  # check and fix generation config
    # Sampling knobs are meaningless with do_sample=False; enable sampling
    # whenever any of them was set to a non-default value.
    if not gen_config.do_sample and (
        (gen_config.temperature is not None and gen_config.temperature != 1.0)
        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
    ):
        gen_config.do_sample = True

    # Restore the stock generate() when a remote-code model replaced it
    # (minicpm models intentionally keep their own implementation).
    if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
        model.generate.__func__
    ):
        model.generate = MethodType(GenerationMixin.generate, model)

    if add_valuehead:
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(model, tokenizer)

    if is_trainable:
        # gemma3n is incompatible with gradient checkpointing.
        if getattr(model.config, "model_type", None) == "gemma3n":
            setattr(model_args, "disable_gradient_checkpointing", True)

        prepare_model_for_training(model, model_args)
        autocast_projector_dtype(model, model_args)
        add_z3_leaf_module(model)

    if not model_args.use_unsloth:
        print_attn_implementation(model.config)

    # Best effort only: some model classes do not support tagging.
    try:
        model.add_model_tags(["llama-factory"])
    except Exception:
        logger.warning_rank0("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    """Graft the missing `PreTrainedModel`-style API onto a trl value-head wrapper.

    The wrapper does not forward several methods to its inner model; bind
    delegating versions onto the instance and exclude the (separately
    saved) backbone weights from checkpointing.
    """

    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
        # Delegate to the backbone when it is a real transformers model.
        if isinstance(self.pretrained_model, PreTrainedModel):
            self.pretrained_model.tie_weights()

    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_output_embeddings()

    def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
        # Model cards only exist for peft adapters here.
        if isinstance(self.pretrained_model, PeftModel):
            self.pretrained_model.create_or_update_model_card(output_dir)

    # Backbone weights are saved by the inner model; skip them at wrapper level.
    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
    setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
src/llamafactory/third_party/__init__.py
0 → 100644
View file @
b59a5620
src/llamafactory/third_party/muon/__init__.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.muon
import
Muon
__all__
=
[
"Muon"
]
src/llamafactory/third_party/muon/muon.py
0 → 100644
View file @
b59a5620
# Copyright 2025 Moonshot AI and the LlamaFactory team.
#
# This code is based on the MoonshotAI's Moonlight library.
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
# and the Keller Jordan's Muon library.
# https://github.com/KellerJordan/Muon/blob/master/muon.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2025 Moonshot AI
# Copyright (c) 2024 Keller Jordan
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import
math
import
torch
def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor":
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.

    We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero.
    For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing
    the slope at zero even beyond the point where the iteration no longer converges all the way to
    one everywhere on the interval. This iteration therefore does not produce UV^T but rather something
    like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)  # quintic coefficients tuned for slope at zero
    X = G.bfloat16()
    # Work on the wide orientation so X @ X.T is the smaller Gram matrix.
    if G.size(0) > G.size(1):
        X = X.T

    # Ensure spectral norm is at most 1 so the iteration is contractive.
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations: X <- a*X + (b*A + c*A^2) @ X with A = X X^T.
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    # Undo the transpose applied for tall inputs.
    if G.size(0) > G.size(1):
        X = X.T

    return X
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

    Arguments:
        muon_params: The parameters to be optimized by Muon.
        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
        momentum: The momentum used by the internal SGD. (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
            {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
    """

    def __init__(
        self,
        lr=1e-3,
        wd=0.1,
        muon_params=None,
        momentum=0.95,
        nesterov=True,
        ns_steps=5,
        adamw_params=None,
        adamw_betas=(0.9, 0.95),
        adamw_eps=1e-8,
    ):
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )
        params = list(muon_params)
        adamw_params = list(adamw_params) if adamw_params is not None else []
        params.extend(adamw_params)
        super().__init__(params, defaults)

        # Sort parameters into those for which we will use Muon, and those for which we will not
        for p in muon_params:
            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
            assert p.ndim == 2, p.ndim
            self.state[p]["use_muon"] = True

        for p in adamw_params:
            # Do not use Muon for parameters in adamw_params
            self.state[p]["use_muon"] = False

    def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float:
        """Scale the learning rate by 0.2 * sqrt(max(fan_in, fan_out)) as described in the paper."""
        rows, cols = param_shape[:2]
        scaling = 0.2 * math.sqrt(max(rows, cols))
        return lr * scaling

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # --- Muon branch: orthogonalized momentum updates ---
            muon_ps = [p for p in group["params"] if self.state[p]["use_muon"]]
            lr, wd, momentum = group["lr"], group["wd"], group["momentum"]

            for p in muon_ps:
                grad = p.grad
                if grad is None:
                    continue
                if grad.ndim > 2:
                    # Flatten conv-style params into a matrix.
                    grad = grad.view(grad.size(0), -1)
                assert grad is not None

                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(grad)

                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(grad)
                grad = grad.add(buf, alpha=momentum) if group["nesterov"] else buf

                update = zeropower_via_newtonschulz5(grad, steps=group["ns_steps"])
                scaled_lr = self.adjust_lr_for_muon(lr, p.shape)
                p.data.mul_(1 - lr * wd)  # decoupled weight decay
                p.data.add_(update, alpha=-scaled_lr)

            # --- AdamW branch: fallback for non-Muon parameters ---
            adamw_ps = [p for p in group["params"] if not self.state[p]["use_muon"]]
            lr = group["lr"]
            beta1, beta2 = group["adamw_betas"]
            eps = group["adamw_eps"]
            weight_decay = group["wd"]

            for p in adamw_ps:
                grad = p.grad
                if grad is None:
                    continue
                state = self.state[p]
                if "step" not in state:
                    state["step"] = 0
                    state["moment1"] = torch.zeros_like(grad)
                    state["moment2"] = torch.zeros_like(grad)

                state["step"] += 1
                t = state["step"]
                m1, m2 = state["moment1"], state["moment2"]
                m1.lerp_(grad, 1 - beta1)
                m2.lerp_(grad.square(), 1 - beta2)

                direction = m1 / (eps + m2.sqrt())
                bias1 = 1 - beta1**t
                bias2 = 1 - beta2**t
                scale = bias1 / bias2**0.5
                p.data.mul_(1 - lr * weight_decay)  # decoupled weight decay
                p.data.add_(direction, alpha=-lr / scale)

        return loss
src/llamafactory/train/__init__.py
0 → 100644
View file @
b59a5620
src/llamafactory/train/callbacks.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
os
import
signal
import
sys
import
time
from
concurrent.futures
import
ThreadPoolExecutor
from
datetime
import
timedelta
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
transformers
from
peft
import
PeftModel
from
transformers
import
PreTrainedModel
,
ProcessorMixin
,
TrainerCallback
from
transformers.trainer_utils
import
PREFIX_CHECKPOINT_DIR
,
has_length
from
transformers.utils
import
(
SAFE_WEIGHTS_NAME
,
WEIGHTS_NAME
,
is_safetensors_available
,
)
from
typing_extensions
import
override
from
..extras
import
logging
from
..extras.constants
import
TRAINER_LOG
,
V_HEAD_SAFE_WEIGHTS_NAME
,
V_HEAD_WEIGHTS_NAME
from
..extras.misc
import
get_peak_memory
,
is_env_enabled
,
use_ray
if
is_safetensors_available
():
from
safetensors
import
safe_open
from
safetensors.torch
import
save_file
if
TYPE_CHECKING
:
from
transformers
import
TrainerControl
,
TrainerState
,
TrainingArguments
from
trl
import
AutoModelForCausalLMWithValueHead
from
..hparams
import
DataArguments
,
FinetuningArguments
,
GeneratingArguments
,
ModelArguments
logger
=
logging
.
get_logger
(
__name__
)
def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""Fix the valuehead checkpoint files.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    # Load the combined checkpoint that the trainer just wrote.
    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: dict[str, torch.Tensor] = torch.load(
            path_to_checkpoint, map_location="cpu", weights_only=True
        )

    os.remove(path_to_checkpoint)

    # Split value-head tensors from backbone tensors, stripping the zero3 prefix.
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    # Write the value head alongside the backbone in its own file.
    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")
class FixValueHeadModelCallback(TrainerCallback):
    r"""A callback for fixing the checkpoint for valuehead models."""

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Only the process responsible for saving rewrites the checkpoint.
        if args.should_save:
            checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"), output_dir=checkpoint_dir, safe_serialization=args.save_safetensors
            )
class SaveProcessorCallback(TrainerCallback):
    r"""A callback for saving the processor."""

    def __init__(self, processor: "ProcessorMixin") -> None:
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Mirror the processor into every intermediate checkpoint directory.
        if args.should_save:
            checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            self.processor.save_pretrained(checkpoint_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Final copy lands next to the final model weights.
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback):
    r"""A callback for converting the PiSSA adapter to a normal one."""

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # Snapshot the initial PiSSA adapter so it can be used for conversion later.
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                model.save_pretrained(
                    pissa_convert_dir,
                    safe_serialization=args.save_safetensors,
                    path_initial_model_for_weight_conversion=pissa_init_dir,
                )
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
class
LogCallback
(
TrainerCallback
):
r
"""A callback for logging training and evaluation status."""
def
__init__
(
self
)
->
None
:
# Progress
self
.
start_time
=
0
self
.
cur_steps
=
0
self
.
max_steps
=
0
self
.
elapsed_time
=
""
self
.
remaining_time
=
""
self
.
thread_pool
:
Optional
[
ThreadPoolExecutor
]
=
None
# Status
self
.
aborted
=
False
self
.
do_train
=
False
# Web UI
self
.
webui_mode
=
is_env_enabled
(
"LLAMABOARD_ENABLED"
)
if
self
.
webui_mode
and
not
use_ray
():
signal
.
signal
(
signal
.
SIGABRT
,
self
.
_set_abort
)
self
.
logger_handler
=
logging
.
LoggerHandler
(
os
.
getenv
(
"LLAMABOARD_WORKDIR"
))
logging
.
add_handler
(
self
.
logger_handler
)
transformers
.
logging
.
add_handler
(
self
.
logger_handler
)
def
_set_abort
(
self
,
signum
,
frame
)
->
None
:
self
.
aborted
=
True
def
_reset
(
self
,
max_steps
:
int
=
0
)
->
None
:
self
.
start_time
=
time
.
time
()
self
.
cur_steps
=
0
self
.
max_steps
=
max_steps
self
.
elapsed_time
=
""
self
.
remaining_time
=
""
def
_timing
(
self
,
cur_steps
:
int
)
->
None
:
cur_time
=
time
.
time
()
elapsed_time
=
cur_time
-
self
.
start_time
avg_time_per_step
=
elapsed_time
/
cur_steps
if
cur_steps
!=
0
else
0
remaining_time
=
(
self
.
max_steps
-
cur_steps
)
*
avg_time_per_step
self
.
cur_steps
=
cur_steps
self
.
elapsed_time
=
str
(
timedelta
(
seconds
=
int
(
elapsed_time
)))
self
.
remaining_time
=
str
(
timedelta
(
seconds
=
int
(
remaining_time
)))
def
_write_log
(
self
,
output_dir
:
str
,
logs
:
dict
[
str
,
Any
])
->
None
:
with
open
(
os
.
path
.
join
(
output_dir
,
TRAINER_LOG
),
"a"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
json
.
dumps
(
logs
)
+
"
\n
"
)
def
_create_thread_pool
(
self
,
output_dir
:
str
)
->
None
:
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
self
.
thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
)
def
_close_thread_pool
(
self
)
->
None
:
if
self
.
thread_pool
is
not
None
:
self
.
thread_pool
.
shutdown
(
wait
=
True
)
self
.
thread_pool
=
None
@
override
def
on_init_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
(
args
.
should_save
and
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
and
args
.
overwrite_output_dir
):
logger
.
warning_rank0_once
(
"Previous trainer log in this folder will be deleted."
)
os
.
remove
(
os
.
path
.
join
(
args
.
output_dir
,
TRAINER_LOG
))
@
override
def
on_train_begin
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
args
.
should_save
:
self
.
do_train
=
True
self
.
_reset
(
max_steps
=
state
.
max_steps
)
self
.
_create_thread_pool
(
output_dir
=
args
.
output_dir
)
@
override
def
on_train_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
self
.
_close_thread_pool
()
@
override
def
on_substep_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
self
.
aborted
:
control
.
should_epoch_stop
=
True
control
.
should_training_stop
=
True
@
override
def
on_step_end
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
self
.
aborted
:
control
.
should_epoch_stop
=
True
control
.
should_training_stop
=
True
@
override
def
on_evaluate
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
not
self
.
do_train
:
self
.
_close_thread_pool
()
@
override
def
on_predict
(
self
,
args
:
"TrainingArguments"
,
state
:
"TrainerState"
,
control
:
"TrainerControl"
,
**
kwargs
):
if
not
self
.
do_train
:
self
.
_close_thread_pool
()
    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """Snapshot the latest trainer log entry, echo a compact summary in web-UI mode, and persist it asynchronously."""
        if not args.should_save:  # only the saving (main) process records logs
            return

        # refresh self.cur_steps / elapsed_time / remaining_time from the current global step
        self._timing(cur_steps=state.global_step)
        # pull the newest entry from the trainer's log history; missing metrics come back as None
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss"),
            eval_loss=state.log_history[-1].get("eval_loss"),
            predict_loss=state.log_history[-1].get("predict_loss"),
            reward=state.log_history[-1].get("reward"),
            accuracy=state.log_history[-1].get("rewards/accuracies"),
            lr=state.log_history[-1].get("learning_rate"),
            epoch=state.log_history[-1].get("epoch"),
            # guard against division by zero when max_steps is unknown
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            # tokens/sec averaged over the whole run so far
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if is_env_enabled("RECORD_VRAM"):
            vram_allocated, vram_reserved = get_peak_memory()
            # convert bytes to GiB for readability
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        # drop metrics absent from this log entry
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            # compact single-line summary shown in the web UI console
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            # hand the write off to the single-worker pool so logging never blocks training
            self.thread_pool.submit(self._write_log, args.output_dir, logs)
    @override
    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """Track progress during standalone prediction/evaluation and periodically write a progress log."""
        if self.do_train:  # progress during training is handled by on_log instead
            return

        if self.aborted:
            # prediction loops have no cooperative stop signal, so exit the process
            sys.exit(0)

        if not args.should_save:  # only the saving (main) process records logs
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                # first prediction step: initialize progress tracking and the log writer
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # write a progress entry every 5 steps to limit I/O
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
class ReporterCallback(TrainerCallback):
    r"""Callback that reports the run configuration to external experiment trackers."""

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not state.is_world_process_zero:
            return

        def _run_config() -> dict:
            # built lazily so to_dict() is only invoked when a tracker is active
            return {
                "model_args": self.model_args.to_dict(),
                "data_args": self.data_args.to_dict(),
                "finetuning_args": self.finetuning_args.to_dict(),
                "generating_args": self.generating_args.to_dict(),
            }

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(_run_config())

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore

            swanlab.config.update(_run_config())
src/llamafactory/train/dpo/__init__.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.workflow
import
run_dpo
__all__
=
[
"run_dpo"
]
src/llamafactory/train/dpo/trainer.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
collections
import
defaultdict
from
contextlib
import
nullcontext
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Literal
,
Optional
,
Union
import
torch
import
torch.nn.functional
as
F
from
transformers
import
Trainer
from
trl
import
DPOTrainer
from
trl.trainer
import
disable_dropout_in_model
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.packages
import
is_transformers_version_greater_than
from
..callbacks
import
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
,
get_batch_logps
,
nested_detach
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedModel
,
ProcessorMixin
from
...hparams
import
FinetuningArguments
class CustomDPOTrainer(DPOTrainer):
    """DPO trainer adapted for LlamaFactory.

    Bypasses most of trl's ``DPOTrainer.__init__`` (calling ``Trainer.__init__``
    directly) so that preference hyperparameters come from ``FinetuningArguments``
    and the data pipeline uses LlamaFactory's collators.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ):
        # transformers >= 4.46 renamed the `tokenizer` kwarg to `processing_class`
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.f_divergence_type = "reverse_kl"
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # dpo hyperparams
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma
        self.ld_alpha = finetuning_args.ld_alpha

        # deliberately skip DPOTrainer.__init__ and initialize the plain Trainer
        Trainer.__init__(self, model=model, **kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model
        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        """Create the optimizer, preferring a LlamaFactory custom optimizer when configured."""
        if self.optimizer is None:
            # may leave self.optimizer as None, in which case super() builds the default
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        """Apply any custom scheduler setup before delegating to the base implementation."""
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        """Use a sequential sampler when shuffling is disabled; otherwise defer to the parent."""
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def get_batch_samples(self, *args, **kwargs):
        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
        return Trainer.get_batch_samples(self, *args, **kwargs)

    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
        # log(odds(chosen)) - log(odds(rejected)), where odds(p) = p / (1 - p) in log space
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        sft_loss = -chosen_logps
        odds_ratio_loss = -F.logsigmoid(log_odds)
        orpo_loss = sft_loss + self.beta * odds_ratio_loss
        return orpo_loss

    def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""Compute SimPO loss for batched log probabilities of the policy model."""
        pi_logratios = chosen_logps - rejected_logps
        # target reward margin, expressed in log-ratio units
        gamma_logratios = self.simpo_gamma / self.beta
        logits = pi_logratios - gamma_logratios
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",
        policy_rejected_logps: "torch.Tensor",
        reference_chosen_logps: Optional["torch.Tensor"],
        reference_rejected_logps: Optional["torch.Tensor"],
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute loss for preference learning.

        Returns (losses, chosen_rewards, rejected_rewards).
        """
        if not self.finetuning_args.use_ref_model:
            # reference-free objectives
            if self.loss_type == "orpo":
                losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
            elif self.loss_type == "simpo":
                losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
            else:
                raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")

            chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )

        return losses, chosen_rewards, rejected_rewards

    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

        Otherwise the average log probabilities.
        """
        if self.finetuning_args.use_ref_model:
            batch = nested_detach(batch, clone=True)  # avoid error

        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
        all_logps, valid_length = get_batch_logps(
            logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
        )
        if self.loss_type in ["ipo", "orpo", "simpo"]:
            # these objectives use length-normalized log-probs
            all_logps = all_logps / valid_length

        # the collator packs chosen examples first, rejected second
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)

        if self.loss_type in ["ipo", "orpo", "simpo"]:
            # last element is the (already averaged) chosen log-probs
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
        else:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

    @override
    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Compute log probabilities of the reference model."""
        if not self.finetuning_args.use_ref_model:
            return None, None

        if self.ref_model is None:
            # no dedicated reference model: reuse the policy with its adapter disabled
            ref_model = model
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
                ref_model, batch, is_ref_model=True
            )

        return reference_chosen_logps, reference_rejected_logps

    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: dict[str, "torch.Tensor"],
        train_eval: Literal["train", "eval"] = "train",
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_chosen_logps_avg,
        ) = self.concatenated_forward(model, batch)

        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
        losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        sft_loss = -policy_chosen_logps_avg
        if self.ftx_gamma > 1e-6:
            # mix in a supervised fine-tuning term weighted by ftx_gamma
            losses += self.ftx_gamma * sft_loss

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
        metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
        if self.loss_type == "orpo":
            metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
            # recover the pure odds-ratio term from the combined ORPO loss
            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()

        return losses.mean(), metrics

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
        r"""Subclass and override to accept extra kwargs."""
        return super().compute_loss(model, inputs, return_outputs)

    @override
    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        r"""Log `logs` on the various objects watching training, including stored metrics."""
        # logs either has "loss" or "eval_loss"
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        key_list, metric_list = [], []
        for key, metrics in self._stored_metrics[train_eval].items():
            key_list.append(key)
            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())

        del self._stored_metrics[train_eval]
        if len(metric_list) < 10:  # pad to for all reduce
            for i in range(10 - len(metric_list)):
                key_list.append(f"dummy_{i}")
                metric_list.append(0.0)

        # average each metric across processes
        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
        metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
        for key, metric in zip(key_list, metric_list):  # add remaining items
            if not key.startswith("dummy_"):
                logs[key] = metric

        return Trainer.log(self, logs, *args, **kwargs)
src/llamafactory/train/dpo/workflow.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
PairwiseDataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
from
...extras.misc
import
calculate_tps
from
...extras.ploting
import
plot_loss
from
...hparams
import
ModelArguments
from
...model
import
load_model
,
load_tokenizer
from
..trainer_utils
import
create_modelcard_and_push
,
create_ref_model
from
.trainer
import
CustomDPOTrainer
if
TYPE_CHECKING
:
from
transformers
import
Seq2SeqTrainingArguments
,
TrainerCallback
from
...hparams
import
DataArguments
,
FinetuningArguments
def run_dpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    """End-to-end DPO workflow: load data/model, build the trainer, then train and/or evaluate."""
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    # DPO reuses the pairwise ("rm") data pipeline
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        model=model,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
    )

    # Create reference model
    if finetuning_args.use_ref_model:
        if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
            ref_model = model
        else:
            ref_model = create_ref_model(model_args, finetuning_args)
    else:
        ref_model = None

    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="rm"
            )

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss", "rewards/accuracies"]
            # with multiple eval datasets, each gets its own loss curve
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if id(model) == id(ref_model):  # unable to compute rewards if reference model is the model itself
            remove_keys = [key for key in metrics.keys() if "rewards" in key]
            for key in remove_keys:
                metrics.pop(key)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
src/llamafactory/train/kto/__init__.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.workflow
import
run_kto
__all__
=
[
"run_kto"
]
src/llamafactory/train/kto/trainer.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
warnings
from
collections
import
defaultdict
from
contextlib
import
nullcontext
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Literal
,
Optional
,
Union
import
torch
from
transformers
import
Trainer
from
trl
import
KTOTrainer
from
trl.trainer
import
disable_dropout_in_model
from
typing_extensions
import
override
from
...extras.constants
import
IGNORE_INDEX
from
...extras.packages
import
is_transformers_version_greater_than
from
..callbacks
import
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
,
get_batch_logps
,
nested_detach
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedModel
,
ProcessorMixin
from
...hparams
import
FinetuningArguments
class CustomKTOTrainer(KTOTrainer):
    """KTO trainer adapted for LlamaFactory.

    Bypasses most of trl's ``KTOTrainer.__init__`` (calling ``Trainer.__init__``
    directly) so that KTO hyperparameters come from ``FinetuningArguments``.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ):
        # transformers >= 4.46 renamed the `tokenizer` kwarg to `processing_class`
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # kto hyperparams
        self.beta = finetuning_args.pref_beta
        self.desirable_weight = finetuning_args.kto_chosen_weight
        self.undesirable_weight = finetuning_args.kto_rejected_weight
        self.ftx_gamma = finetuning_args.pref_ftx

        # deliberately skip KTOTrainer.__init__ and initialize the plain Trainer
        Trainer.__init__(self, model=model, **kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model
        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)
@
override
def
create_optimizer
(
self
)
->
"torch.optim.Optimizer"
:
if
self
.
optimizer
is
None
:
self
.
optimizer
=
create_custom_optimizer
(
self
.
model
,
self
.
args
,
self
.
finetuning_args
)
return
super
().
create_optimizer
()
@
override
def
create_scheduler
(
self
,
num_training_steps
:
int
,
optimizer
:
Optional
[
"torch.optim.Optimizer"
]
=
None
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
create_custom_scheduler
(
self
.
args
,
num_training_steps
,
optimizer
)
return
super
().
create_scheduler
(
num_training_steps
,
optimizer
)
@
override
def
_get_train_sampler
(
self
,
*
args
,
**
kwargs
)
->
Optional
[
"torch.utils.data.Sampler"
]:
r
"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
if
self
.
finetuning_args
.
disable_shuffling
:
return
torch
.
utils
.
data
.
SequentialSampler
(
self
.
train_dataset
)
return
Trainer
.
_get_train_sampler
(
self
,
*
args
,
**
kwargs
)
    @override
    def get_batch_samples(self, *args, **kwargs):
        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
        # bypass trl's override; the vanilla Trainer behavior is what we want here
        return Trainer.get_batch_samples(self, *args, **kwargs)
@
override
def
forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
dict
[
str
,
"torch.Tensor"
],
prefix
:
Literal
[
""
,
"kl_"
]
=
""
)
->
tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""Run forward pass and computes the log probabilities."""
batch
=
nested_detach
(
batch
,
clone
=
True
)
# avoid error
model_inputs
=
{
"input_ids"
:
batch
[
f
"
{
prefix
}
input_ids"
],
"attention_mask"
:
batch
[
f
"
{
prefix
}
attention_mask"
],
}
if
f
"
{
prefix
}
token_type_ids"
in
batch
:
model_inputs
[
"token_type_ids"
]
=
batch
[
f
"
{
prefix
}
token_type_ids"
]
if
"pixel_values"
in
batch
:
model_inputs
[
"pixel_values"
]
=
batch
[
"pixel_values"
]
if
"image_sizes"
in
batch
:
model_inputs
[
"image_sizes"
]
=
batch
[
"image_sizes"
]
if
"image_grid_thw"
in
batch
:
model_inputs
[
"image_grid_thw"
]
=
batch
[
"image_grid_thw"
]
if
"aspect_ratio_ids"
in
batch
:
model_inputs
[
"aspect_ratio_ids"
]
=
batch
[
"aspect_ratio_ids"
]
if
"aspect_ratio_mask"
in
batch
:
model_inputs
[
"aspect_ratio_mask"
]
=
batch
[
"aspect_ratio_mask"
]
if
f
"
{
prefix
}
cross_attention_mask"
in
batch
:
model_inputs
[
"cross_attention_mask"
]
=
batch
[
f
"
{
prefix
}
cross_attention_mask"
]
logits
=
model
(
**
model_inputs
,
return_dict
=
True
,
use_cache
=
False
).
logits
.
to
(
torch
.
float32
)
logps
,
valid_length
=
get_batch_logps
(
logits
=
logits
,
labels
=
batch
[
f
"
{
prefix
}
labels"
])
return
logits
,
logps
,
logps
/
valid_length
@
override
def
concatenated_forward
(
self
,
model
:
"PreTrainedModel"
,
batch
:
dict
[
str
,
"torch.Tensor"
]
)
->
tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
target_logits
,
target_logps
,
target_logps_avg
=
self
.
forward
(
model
,
batch
)
with
torch
.
no_grad
():
_
,
kl_logps
,
_
=
self
.
forward
(
model
,
batch
,
prefix
=
"kl_"
)
if
len
(
target_logps
)
!=
len
(
batch
[
"kto_tags"
]):
raise
ValueError
(
"Mismatched shape of inputs and labels."
)
chosen_logits
=
target_logits
[
batch
[
"kto_tags"
]]
chosen_logps
=
target_logps
[
batch
[
"kto_tags"
]]
rejected_logits
=
target_logits
[
~
batch
[
"kto_tags"
]]
rejected_logps
=
target_logps
[
~
batch
[
"kto_tags"
]]
chosen_logps_avg
=
target_logps_avg
[
batch
[
"kto_tags"
]]
return
chosen_logps
,
rejected_logps
,
chosen_logits
,
rejected_logits
,
kl_logps
,
chosen_logps_avg
@
override
def
compute_reference_log_probs
(
self
,
model
:
"PreTrainedModel"
,
batch
:
dict
[
str
,
"torch.Tensor"
]
)
->
tuple
[
"torch.Tensor"
,
"torch.Tensor"
,
"torch.Tensor"
]:
r
"""Compute log probabilities of the reference model."""
if
self
.
ref_model
is
None
:
ref_model
=
model
ref_context
=
self
.
accelerator
.
unwrap_model
(
model
).
disable_adapter
()
else
:
ref_model
=
self
.
ref_model
ref_context
=
nullcontext
()
with
torch
.
no_grad
(),
ref_context
:
reference_chosen_logps
,
reference_rejected_logps
,
_
,
_
,
reference_kl_logps
,
_
=
self
.
concatenated_forward
(
ref_model
,
batch
)
return
reference_chosen_logps
,
reference_rejected_logps
,
reference_kl_logps
    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: dict[str, "torch.Tensor"],
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
        r"""Compute the KTO loss and other metrics for the given batch of inputs for train or test.

        Metrics are accumulated as per-batch sums plus counts; the `log` override
        later reduces them across processes and converts them back to means.
        """
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_kl_logps,
            policy_chosen_logps_avg,
        ) = self.concatenated_forward(model, batch)
        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
            model, batch
        )
        losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            policy_kl_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_kl_logps,
        )
        losses = losses.nanmean()
        if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0:  # remember to rescale
            # SFT term only applies to the chosen subset; rescale by the batch proportion
            sft_loss = -policy_chosen_logps_avg
            losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

        num_chosen = len(chosen_rewards)
        num_rejected = len(rejected_rewards)
        if num_chosen > 0:
            metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
            metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
            metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
            metrics["count/chosen"] = float(num_chosen)

        if num_rejected > 0:
            metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
            metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
            metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
            metrics["count/rejected"] = float(num_rejected)

        metrics["kl"] = kl.item()
        return losses, metrics
@
override
def
compute_loss
(
self
,
model
:
"PreTrainedModel"
,
inputs
:
dict
[
str
,
"torch.Tensor"
],
return_outputs
:
bool
=
False
,
**
kwargs
)
->
Union
[
"torch.Tensor"
,
tuple
[
"torch.Tensor"
,
list
[
"torch.Tensor"
]]]:
r
"""Subclass and override to accept extra kwargs."""
return
super
().
compute_loss
(
model
,
inputs
,
return_outputs
)
    @override
    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        r"""Log `logs` on the various objects watching training, including stored metrics.

        Stored per-step metric sums are all-reduced across processes, converted back
        to averages using the stored counts, and merged into `logs` before delegating
        to the base `Trainer.log`.
        """
        # logs either has "loss" or "eval_loss"
        train_eval = "train" if "loss" in logs else "eval"
        prefix = "eval_" if train_eval == "eval" else ""
        # Add averaged stored metrics to logs
        key_list, metric_list = [], []
        for key, metrics in self._stored_metrics[train_eval].items():
            key_list.append(key)
            # Each entry is a list of per-step values; reduce to a single scalar sum.
            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())

        del self._stored_metrics[train_eval]
        if len(metric_list) < 9:  # pad to for all reduce
            # All ranks must contribute a tensor of the same length; dummies are dropped later.
            for i in range(9 - len(metric_list)):
                key_list.append(f"dummy_{i}")
                metric_list.append(0.0)

        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
        metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
        metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
        for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
            if f"count/{split}" in metric_dict:
                for key in ("rewards", "logps", "logits"):
                    logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
                    del metric_dict[f"{key}/{split}_sum"]

                del metric_dict[f"count/{split}"]

        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]

        for key, metric in metric_dict.items():  # add remaining items
            if not key.startswith("dummy_"):
                logs[key] = metric

        return Trainer.log(self, logs, *args, **kwargs)
src/llamafactory/train/kto/workflow.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
KTODataCollatorWithPadding
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.constants
import
IGNORE_INDEX
from
...extras.ploting
import
plot_loss
from
...hparams
import
ModelArguments
from
...model
import
load_model
,
load_tokenizer
from
..trainer_utils
import
create_modelcard_and_push
,
create_ref_model
from
.trainer
import
CustomKTOTrainer
if
TYPE_CHECKING
:
from
transformers
import
Seq2SeqTrainingArguments
,
TrainerCallback
from
...hparams
import
DataArguments
,
FinetuningArguments
def run_kto(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    """Run the full KTO training workflow: data, model, trainer, train/eval, model card."""
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    data_collator = KTODataCollatorWithPadding(
        template=template,
        model=model,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
    )

    # Create reference model
    if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
        ref_model = model
    else:
        ref_model = create_ref_model(model_args, finetuning_args)

    # Initialize our Trainer
    trainer = CustomKTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss", "rewards/chosen"]
            # eval_dataset may be a dict of named datasets, each with its own eval loss key.
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if id(model) == id(ref_model):  # unable to compute rewards without a reference model
            remove_keys = [key for key in metrics.keys() if "rewards" in key]
            for key in remove_keys:
                metrics.pop(key)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
src/llamafactory/train/ppo/__init__.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.workflow
import
run_ppo
__all__
=
[
"run_ppo"
]
src/llamafactory/train/ppo/ppo_utils.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
,
Literal
,
Optional
import
torch
from
transformers.integrations
import
is_deepspeed_zero3_enabled
from
...extras.packages
import
is_requests_available
if
is_requests_available
():
import
requests
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedModel
from
trl
import
AutoModelForCausalLMWithValueHead
def get_rewards_from_server(server_url: str, messages: list[str]) -> "torch.Tensor":
    r"""Get reward scores from the API server.

    Args:
        server_url: URL of the reward-model scoring endpoint.
        messages: decoded query+response texts to score.

    Returns:
        A 1-D float tensor with one reward score per message.

    Raises:
        requests.HTTPError: if the server returns an error status code.
    """
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    # Fail fast on HTTP errors; previously an error body surfaced as a confusing KeyError below.
    response.raise_for_status()
    rewards = response.json()["scores"]
    return torch.Tensor(rewards)
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
    r"""Replace the default/reward modules in the model. The model is already unwrapped.

    Swaps both the active LoRA adapter and the value-head weights. When switching
    to "reward", the current (default) head is stashed on the model so it can be
    restored by a later call with target="default".
    """
    v_head_layer = model.v_head.summary
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        # Under ZeRO-3 the head params are sharded; gather them before reading/writing .data.
        params = [v_head_layer.weight, v_head_layer.bias]
        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
    with context_maybe_zero3:
        if target == "reward":  # save default head temporarily
            setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone())
            setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())

        device = v_head_layer.weight.device
        # Load the target head from buffers registered as "{target}_head_weight"/"{target}_head_bias".
        v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
        v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
    r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered).

    Every fp32 parameter is snapshotted and cast down to ``model.config.torch_dtype``;
    the returned mapping lets `restore_layernorm` undo the cast after generation.
    """
    saved_fp32_params: dict[str, torch.Tensor] = {}
    for param_name, param in model.named_parameters():
        if param.data.dtype != torch.float32:
            continue

        saved_fp32_params[param_name] = param.data.detach().clone()
        param.data = param.data.to(model.config.torch_dtype)

    return saved_fp32_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
    r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered).

    Args:
        model: model whose parameters are written back in place.
        layernorm_params: mapping of parameter name -> saved fp32 tensor, as
            produced by `dump_layernorm`. `None` or empty means nothing to restore.
    """
    # Guard the declared default: `name in None` previously raised TypeError.
    if not layernorm_params:
        return

    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]
src/llamafactory/train/ppo/trainer.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
import
sys
import
warnings
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
from
accelerate.utils
import
DistributedDataParallelKwargs
from
tqdm
import
tqdm
from
transformers
import
GenerationConfig
,
Trainer
,
TrainerControl
,
TrainerState
from
transformers.optimization
import
get_scheduler
from
transformers.trainer
import
DEFAULT_CALLBACKS
from
transformers.trainer_callback
import
CallbackHandler
from
transformers.trainer_pt_utils
import
remove_dummy_checkpoint
from
transformers.trainer_utils
import
PREFIX_CHECKPOINT_DIR
from
transformers.utils
import
SAFE_WEIGHTS_NAME
,
WEIGHTS_NAME
from
trl
import
PPOConfig
,
PPOTrainer
from
trl.core
import
PPODecorators
,
logprobs_from_logits
from
trl.models.utils
import
unwrap_model_for_generation
from
typing_extensions
import
override
from
...extras
import
logging
from
...extras.misc
import
AverageMeter
,
count_parameters
,
get_current_device
,
get_logits_processor
from
..callbacks
import
FixValueHeadModelCallback
,
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
from
.ppo_utils
import
dump_layernorm
,
get_rewards_from_server
,
replace_model
,
restore_layernorm
if
TYPE_CHECKING
:
from
datasets
import
Dataset
from
transformers
import
(
DataCollatorWithPadding
,
PreTrainedTokenizer
,
ProcessorMixin
,
Seq2SeqTrainingArguments
,
TrainerCallback
,
)
from
trl
import
AutoModelForCausalLMWithValueHead
from
...hparams
import
FinetuningArguments
,
GeneratingArguments
,
ModelArguments
logger
=
logging
.
get_logger
(
__name__
)
class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""Inherit PPOTrainer.

    Bridges TRL's PPOTrainer with HF Trainer conventions (state, control,
    callback handler) so LlamaFactory callbacks and checkpointing work.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: Optional[list["TrainerCallback"]],
        model: "AutoModelForCausalLMWithValueHead",
        reward_model: Optional["AutoModelForCausalLMWithValueHead"],
        ref_model: Optional["AutoModelForCausalLMWithValueHead"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        data_collator: "DataCollatorWithPadding",
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
    ) -> None:
        if eval_dataset is not None:
            raise NotImplementedError("PPOTrainer does not support eval dataset yet.")

        # One optimizer step consumes per-device batch * grad accumulation examples.
        backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
        ppo_config = PPOConfig(
            model_name=model_args.model_name_or_path,
            learning_rate=training_args.learning_rate,
            mini_batch_size=training_args.per_device_train_batch_size,
            batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            ppo_epochs=finetuning_args.ppo_epochs,
            max_grad_norm=training_args.max_grad_norm,
            seed=training_args.seed,
            optimize_device_cache=True,
            target=finetuning_args.ppo_target,
            use_score_scaling=finetuning_args.ppo_score_norm,
            use_score_norm=finetuning_args.ppo_score_norm,
            whiten_rewards=finetuning_args.ppo_whiten_rewards,
            accelerator_kwargs={"step_scheduler_with_optimizer": False},
            log_with=training_args.report_to[0] if training_args.report_to else None,
            project_kwargs={"logging_dir": training_args.logging_dir},
        )

        # Add deepspeed config
        if training_args.deepspeed_plugin is not None:
            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
            ]
            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
            if ppo_config.log_with is not None:
                logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                ppo_config.log_with = None

        # Create optimizer and scheduler
        if training_args.max_steps > 0:
            num_training_steps = training_args.max_steps
        else:
            total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
            num_training_steps = training_args.num_train_epochs * math.ceil(
                len(train_dataset) / total_train_batch_size
            )

        optimizer = self.create_optimizer(model, training_args, finetuning_args)
        scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)

        PPOTrainer.__init__(
            self,
            config=ppo_config,
            model=model,
            ref_model=ref_model,
            tokenizer=tokenizer,
            dataset=train_dataset,
            optimizer=optimizer,
            data_collator=data_collator,
            lr_scheduler=scheduler,
        )

        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args
        self.reward_model = reward_model
        self.current_device = get_current_device()  # patch for deepspeed training

        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            **generating_args.to_dict(),
        )

        # HF-Trainer-style bookkeeping objects so standard callbacks can be driven manually.
        self.state = TrainerState()
        self.control = TrainerControl()
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
        )
        if self.args.max_steps > 0:
            logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")

        self.amp_context = torch.autocast(self.current_device.type)
        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if finetuning_args.reward_model_type == "full":
            if self.is_deepspeed_enabled:
                if not (
                    getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                    or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.reward_model = self._prepare_deepspeed(self.reward_model)
            else:
                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

        self.add_callback(FixValueHeadModelCallback)

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            # BAdam needs the legacy clip_grad_norm_ signature on the accelerator.
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)
    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.

        Each outer step: generate responses (eval mode, right padding), score them
        with the reward model, then run the TRL PPO optimization step (train mode).
        """
        if resume_from_checkpoint is not None:
            raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

        total_train_batch_size = (
            self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps
            * self.finetuning_args.ppo_buffer_size
            * self.args.world_size
        )
        if self.args.max_steps > 0:
            num_examples = total_train_batch_size * self.args.max_steps
            num_train_epochs = sys.maxsize  # effectively infinite; max_steps terminates the loop
            max_steps = self.args.max_steps
            steps_in_epoch = self.args.max_steps
        else:
            len_dataloader = len(self.dataloader)
            num_examples = len(self.dataset)
            num_train_epochs = self.args.num_train_epochs
            max_steps = math.ceil(num_train_epochs * len_dataloader)
            steps_in_epoch = len_dataloader

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        logger.info_rank0("***** Running training *****")
        logger.info_rank0(f"  Num examples = {num_examples:,}")
        logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
        logger.info_rank0(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        logger.info_rank0(
            f"  Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
        )
        logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
        logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
        logger.info_rank0(f"  Total training steps = {max_steps:,}")
        logger.info_rank0(f"  Number of trainable parameters = {count_parameters(self.model)[0]:,}")

        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.callback_handler.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            try:
                batch = next(dataiter)
            except StopIteration:
                # Dataloader exhausted: restart it so training can span multiple epochs.
                dataiter = iter(self.dataloader)
                batch = next(dataiter)

            # Get inputs
            self.model.eval()
            self.tokenizer.padding_side = "right"  # change padding side
            queries, responses, rewards = [], [], []
            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                mini_batch = {
                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
                }
                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                queries.extend(mini_batch_queries)
                responses.extend(mini_batch_responses)
                rewards.extend(mini_batch_rewards)

            # Run PPO step
            self.model.train()
            stats = self.step(queries, responses, rewards)
            self.tokenizer.padding_side = "left"  # restore padding side
            loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

            if self.config.log_with is not None:
                try:
                    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:
                    # Best-effort logging: external-logger failures must not kill training.
                    logger.warning_rank0("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.callback_handler.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                logs = dict(
                    loss=round(loss_meter.avg, 4),
                    reward=round(reward_meter.avg, 4),
                    learning_rate=stats["ppo/learning_rate"],
                    epoch=round(step / steps_in_epoch, 2),
                )
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.callback_handler.on_log(self.args, self.state, self.control, logs)
                loss_meter.reset()
                reward_meter.reset()

            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                self.save_model(
                    os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                )
                self.callback_handler.on_save(self.args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.callback_handler.on_train_end(self.args, self.state, self.control)
@
override
def
create_optimizer
(
self
,
model
:
"AutoModelForCausalLMWithValueHead"
,
training_args
:
"Seq2SeqTrainingArguments"
,
finetuning_args
:
"FinetuningArguments"
,
)
->
"torch.optim.Optimizer"
:
optimizer
=
create_custom_optimizer
(
model
,
training_args
,
finetuning_args
)
if
optimizer
is
None
:
decay_params
,
nodecay_params
=
[],
[]
decay_param_names
=
self
.
get_decay_parameter_names
(
model
)
for
name
,
param
in
model
.
named_parameters
():
if
param
.
requires_grad
:
if
name
in
decay_param_names
:
decay_params
.
append
(
param
)
else
:
nodecay_params
.
append
(
param
)
optim_class
,
optim_kwargs
=
Trainer
.
get_optimizer_cls_and_kwargs
(
training_args
)
param_groups
=
[
dict
(
params
=
nodecay_params
),
dict
(
params
=
decay_params
,
weight_decay
=
training_args
.
weight_decay
),
]
optimizer
=
optim_class
(
param_groups
,
**
optim_kwargs
)
return
optimizer
@
override
def
create_scheduler
(
self
,
training_args
:
"Seq2SeqTrainingArguments"
,
num_training_steps
:
int
,
optimizer
:
"torch.optim.Optimizer"
)
->
"torch.optim.lr_scheduler.LRScheduler"
:
create_custom_scheduler
(
training_args
,
num_training_steps
,
optimizer
)
lr_scheduler
=
get_scheduler
(
training_args
.
lr_scheduler_type
,
optimizer
=
optimizer
,
num_warmup_steps
=
training_args
.
get_warmup_steps
(
num_training_steps
),
num_training_steps
=
num_training_steps
,
)
return
lr_scheduler
    @torch.no_grad()
    def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
        r"""Generate model's responses given queries.

        Returns two equal-length lists of CPU tensors: left-padding-stripped
        queries and right-padding-stripped responses, one pair per example.
        """
        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
            # Strip the shared left padding so generation does not see leading pad tokens.
            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
            for k, v in batch.items():
                batch[k] = v[:, start_index:]

        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
            if self.model_args.upcast_layernorm:
                # Temporarily downcast fp32 (layernorm) params for generation; restored below.
                layernorm_params = dump_layernorm(unwrapped_model)

            generate_output: torch.Tensor = unwrapped_model.generate(
                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
            )
            if self.model_args.upcast_layernorm:
                restore_layernorm(unwrapped_model, layernorm_params)

        query = batch["input_ids"].detach().cpu()
        response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
        queries, responses = [], []
        for i in range(len(query)):
            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
            response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
            if len(response_indexes) == 0:  # allow empty response
                response_length = 1
            elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id:  # include eos token
                # When eos == pad, the eos right after the last non-pad token must be kept.
                response_length = response_indexes[-1].item() + 2
            else:
                response_length = response_indexes[-1].item() + 1

            queries.append(query[i, query_start_index:])  # remove padding from left
            responses.append(response[i, :response_length])  # remove padding from right

        return queries, responses
    @torch.no_grad()
    def get_rewards(
        self,
        queries: list["torch.Tensor"],
        responses: list["torch.Tensor"],
    ) -> list["torch.Tensor"]:
        r"""Compute scores using given reward model.

        Both inputs and outputs are put on CPU.

        Three reward sources depending on `reward_model_type`: "api" (remote server),
        "lora" (the policy model with its reward adapter/head swapped in), or a
        separate full reward model.
        """
        if self.finetuning_args.reward_model_type == "api":
            # self.reward_model holds the server URL in the "api" case.
            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
            return get_rewards_from_server(self.reward_model, messages)

        batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
        if self.finetuning_args.reward_model_type == "lora":
            # Swap in the reward adapter + value head; swapped back after the forward pass.
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
            reward_model = self.reward_model

        with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
            values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="default")

        # Take the value-head output at the last real (non-padded) token of each sequence.
        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return rewards.float().detach()  # use fp32 type
    @override
    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: "torch.Tensor",
        responses: "torch.Tensor",
        model_inputs: dict[str, Any],
        return_logits: bool = False,
        response_masks: Optional["torch.Tensor"] = None,
    ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
        r"""Calculate model outputs in multiple batches.

        Subclass and override to inject custom behavior.

        Returns (logprobs, logits-or-None, values, masks), each truncated to
        sequence length - 1 since logprobs are defined for next-token prediction.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []
        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]

            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]
            with self.amp_context:  # support bf16
                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

            # Log-prob of each actual next token: logits at t predict token t+1.
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                # Mask everything except the response region of example j.
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0].item()

                end = start + len(response_batch[j])
                if response_masks is not None:
                    # Align the response mask with the shifted (next-token) positions.
                    response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits  # free memory early when logits are not requested

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )
    @override
    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""Save model checkpoint.

        Subclass and override to inject custom behavior.

        Under DeepSpeed/FSDP the state dict must be gathered collectively on all
        ranks even though only `should_save` ranks write to disk.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        if self.is_fsdp_enabled or self.is_deepspeed_enabled:
            try:
                state_dict = self.accelerator.get_state_dict(self.model)  # must be called at all ranks
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
                # ZeRO-3 without 16-bit weight gathering: fall back to DeepSpeed's own checkpoint format.
                logger.warning_rank0(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                    " use zero_to_fp32.py to recover weights"
                )
                if self.args.should_save:
                    self._save(output_dir, state_dict={})
                # remove the dummy state_dict
                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                self.model.save_checkpoint(output_dir)
        elif self.args.should_save:
            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
            self._save(output_dir, state_dict=unwrapped_model.state_dict())
src/llamafactory/train/ppo/workflow.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Optional
from
...data
import
MultiModalDataCollatorForSeq2Seq
,
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
from
...model
import
load_model
,
load_tokenizer
from
..callbacks
import
fix_valuehead_checkpoint
from
..trainer_utils
import
create_ref_model
,
create_reward_model
from
.trainer
import
CustomPPOTrainer
if
TYPE_CHECKING
:
from
transformers
import
Seq2SeqTrainingArguments
,
TrainerCallback
from
...hparams
import
DataArguments
,
FinetuningArguments
,
GeneratingArguments
,
ModelArguments
def run_ppo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    """Run the full PPO training workflow: data, policy/ref/reward models, trainer, training."""
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
    # add_valuehead: PPO needs a value head on the policy model.
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)

    # Create reference model and reward model
    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
    reward_model = create_reward_model(model, model_args, finetuning_args)

    # Initialize our Trainer
    ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
        model_args=model_args,
        training_args=training_args,
        finetuning_args=finetuning_args,
        generating_args=generating_args,
        callbacks=callbacks,
        model=model,
        reward_model=reward_model,
        ref_model=ref_model,
        data_collator=data_collator,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        ppo_trainer.save_model()
        if training_args.should_save:
            # Split the value-head weights out of the saved checkpoint.
            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)

        ppo_trainer.save_state()  # must be called after save_model to have a folder
        if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "reward"])
src/llamafactory/train/pt/__init__.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_pt

# Public API of the pre-training (pt) stage sub-package.
__all__ = ["run_pt"]
src/llamafactory/train/pt/trainer.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
types
import
MethodType
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
from
transformers
import
Trainer
from
typing_extensions
import
override
from
...extras.packages
import
is_transformers_version_greater_than
from
..callbacks
import
SaveProcessorCallback
from
..trainer_utils
import
create_custom_optimizer
,
create_custom_scheduler
if
TYPE_CHECKING
:
from
transformers
import
ProcessorMixin
from
...hparams
import
FinetuningArguments
class CustomTrainer(Trainer):
    r"""Inherit Trainer for custom optimizer."""

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        # transformers >= 4.46 renamed the ``tokenizer`` kwarg to ``processing_class``.
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args

        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            # BAdam requires the legacy gradient-clipping implementation on the accelerator.
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        """Create the project-specific optimizer first; fall back to the base Trainer's."""
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        """Register the custom LR schedule, then delegate scheduler creation to the base class."""
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        """Return a sequential sampler when shuffling is disabled, else the default sampler."""
        if not self.finetuning_args.disable_shuffling:
            return super()._get_train_sampler(*args, **kwargs)

        return torch.utils.data.SequentialSampler(self.train_dataset)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        # Pure pass-through kept as an explicit extension point for the pt stage.
        return super().compute_loss(model, inputs, *args, **kwargs)
src/llamafactory/train/pt/workflow.py
0 → 100644
View file @
b59a5620
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
from
typing
import
TYPE_CHECKING
,
Optional
from
transformers
import
DataCollatorForLanguageModeling
from
...data
import
get_dataset
,
get_template_and_fix_tokenizer
from
...extras.ploting
import
plot_loss
from
...model
import
load_model
,
load_tokenizer
from
..trainer_utils
import
create_modelcard_and_push
from
.trainer
import
CustomTrainer
if
TYPE_CHECKING
:
from
transformers
import
Seq2SeqTrainingArguments
,
TrainerCallback
from
...hparams
import
DataArguments
,
FinetuningArguments
,
ModelArguments
def run_pt(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    r"""Run the pre-training (causal LM) stage: train, evaluate, and report perplexity."""

    def _safe_perplexity(loss: float) -> float:
        # exp(loss) overflows for large losses; report infinity instead of crashing.
        try:
            return math.exp(loss)
        except OverflowError:
            return float("inf")

    tok_module = load_tokenizer(model_args)
    tok = tok_module["tokenizer"]
    chat_template = get_template_and_fix_tokenizer(tok, data_args)
    data_module = get_dataset(chat_template, model_args, data_args, training_args, stage="pt", **tok_module)
    lm_model = load_model(tok, model_args, finetuning_args, training_args.do_train)
    collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

    # Assemble the trainer.
    trainer = CustomTrainer(
        model=lm_model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=collator,
        callbacks=callbacks,
        **data_module,
        **tok_module,
    )

    # Training: run, save weights/metrics/state, optionally plot loss curves.
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(data_module.get("eval_dataset"), dict):
                # One eval-loss curve per named eval split.
                keys += [f"eval_{name}_loss" for name in data_module["eval_dataset"]]
            else:
                keys += ["eval_loss"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation: derive perplexity from each eval loss before logging.
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if isinstance(data_module.get("eval_dataset"), dict):
            for name in data_module["eval_dataset"]:
                metrics[f"eval_{name}_perplexity"] = _safe_perplexity(metrics[f"eval_{name}_loss"])
        else:
            metrics["eval_perplexity"] = _safe_perplexity(metrics["eval_loss"])

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
src/llamafactory/train/rm/__init__.py
0 → 100644
View file @
b59a5620
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_rm

# Public API of the reward-model (rm) stage sub-package.
__all__ = ["run_rm"]
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment