chenych / llama-grpo
Commit c7c477c7, authored Sep 24, 2025 by chenych
add grpo
Pipeline #2942 failed with stages in 0 seconds
Showing 20 changed files with 2712 additions and 0 deletions.
src/llamafactory/third_party/muon/muon.py (+226, -0)
src/llamafactory/train/__init__.py (+0, -0)
src/llamafactory/train/callbacks.py (+385, -0)
src/llamafactory/train/dpo/__init__.py (+18, -0)
src/llamafactory/train/dpo/trainer.py (+302, -0)
src/llamafactory/train/dpo/workflow.py (+110, -0)
src/llamafactory/train/grpo/__init__.py (+18, -0)
src/llamafactory/train/grpo/data.py (+115, -0)
src/llamafactory/train/grpo/func.py (+54, -0)
src/llamafactory/train/grpo/metric.py (+134, -0)
src/llamafactory/train/grpo/workflow.py (+155, -0)
src/llamafactory/train/kto/__init__.py (+18, -0)
src/llamafactory/train/kto/trainer.py (+297, -0)
src/llamafactory/train/kto/workflow.py (+101, -0)
src/llamafactory/train/ppo/__init__.py (+18, -0)
src/llamafactory/train/ppo/ppo_utils.py (+80, -0)
src/llamafactory/train/ppo/trainer.py (+503, -0)
src/llamafactory/train/ppo/workflow.py (+79, -0)
src/llamafactory/train/pt/__init__.py (+18, -0)
src/llamafactory/train/pt/trainer.py (+81, -0)
src/llamafactory/third_party/muon/muon.py (new file, mode 100644)
# Copyright 2025 Moonshot AI and the LlamaFactory team.
#
# This code is based on the MoonshotAI's Moonlight library.
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
# and the Keller Jordan's Muon library.
# https://github.com/KellerJordan/Muon/blob/master/muon.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2025 Moonshot AI
# Copyright (c) 2024 Keller Jordan
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math

import torch


def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor":
    """Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.

    We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero.
    For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing
    the slope at zero even beyond the point where the iteration no longer converges all the way to
    one everywhere on the interval. This iteration therefore does not produce UV^T but rather something
    like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T

    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(0) > G.size(1):
        X = X.T

    return X


class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

    Arguments:
        muon_params: The parameters to be optimized by Muon.
        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
        momentum: The momentum used by the internal SGD. (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
        adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
            {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
        adamw_lr: The learning rate for the internal AdamW.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
        adamw_wd: The weight decay for the internal AdamW.
    """

    def __init__(
        self,
        lr=1e-3,
        wd=0.1,
        muon_params=None,
        momentum=0.95,
        nesterov=True,
        ns_steps=5,
        adamw_params=None,
        adamw_betas=(0.9, 0.95),
        adamw_eps=1e-8,
    ):
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
        )
        params = list(muon_params)
        adamw_params = list(adamw_params) if adamw_params is not None else []
        params.extend(adamw_params)
        super().__init__(params, defaults)

        # Sort parameters into those for which we will use Muon, and those for which we will not
        for p in muon_params:
            # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
            assert p.ndim == 2, p.ndim
            self.state[p]["use_muon"] = True

        for p in adamw_params:
            # Do not use Muon for parameters in adamw_params
            self.state[p]["use_muon"] = False

    def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float:
        A, B = param_shape[:2]
        # We adjust the learning rate and weight decay based on the size of the parameter matrix
        # as described in the paper
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
        adjusted_lr = lr * adjusted_ratio
        return adjusted_lr

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Muon loop
            params = [p for p in group["params"] if self.state[p]["use_muon"]]
            lr = group["lr"]
            wd = group["wd"]
            momentum = group["momentum"]

            # generate weight updates in distributed fashion
            for p in params:
                # sanity check
                g = p.grad
                if g is None:
                    continue
                if g.ndim > 2:
                    g = g.view(g.size(0), -1)
                assert g is not None

                # calc update
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g)
                if group["nesterov"]:
                    g = g.add(buf, alpha=momentum)
                else:
                    g = buf
                u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])

                # scale update
                adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)

                # apply weight decay
                p.data.mul_(1 - lr * wd)

                # apply update
                p.data.add_(u, alpha=-adjusted_lr)

            # Adam backup
            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
            lr = group["lr"]
            beta1, beta2 = group["adamw_betas"]
            eps = group["adamw_eps"]
            weight_decay = group["wd"]

            for p in params:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]
                if "step" not in state:
                    state["step"] = 0
                    state["moment1"] = torch.zeros_like(g)
                    state["moment2"] = torch.zeros_like(g)
                state["step"] += 1
                step = state["step"]
                buf1 = state["moment1"]
                buf2 = state["moment2"]
                buf1.lerp_(g, 1 - beta1)
                buf2.lerp_(g.square(), 1 - beta2)

                g = buf1 / (eps + buf2.sqrt())

                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step
                scale = bias_correction1 / bias_correction2**0.5
                p.data.mul_(1 - lr * weight_decay)
                p.data.add_(g, alpha=-lr / scale)

        return loss
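For reference, a minimal usage sketch of the optimizer above (not part of the commit; the model and hyperparameters are illustrative): 2D weight matrices are routed to the Muon branch, everything else to the internal AdamW branch.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
muon_params = [p for p in model.parameters() if p.ndim == 2]   # weight matrices
adamw_params = [p for p in model.parameters() if p.ndim != 2]  # biases etc.
optimizer = Muon(lr=1e-3, wd=0.1, muon_params=muon_params, adamw_params=adamw_params)

x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
nn.functional.cross_entropy(model(x), y).backward()
optimizer.step()  # orthogonalized update for matrices, AdamW-style update for the rest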
src/llamafactory/train/__init__.py (new file, mode 100644; empty file)
src/llamafactory/train/callbacks.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)
from typing_extensions import override

from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""Fix the valuehead checkpoint files.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)

    os.remove(path_to_checkpoint)
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )
    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")


class FixValueHeadModelCallback(TrainerCallback):
    r"""A callback for fixing the checkpoint for valuehead models."""

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
            )


class SaveProcessorCallback(TrainerCallback):
    r"""A callback for saving the processor."""

    def __init__(self, processor: "ProcessorMixin") -> None:
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            self.processor.save_pretrained(output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)


class PissaConvertCallback(TrainerCallback):
    r"""A callback for converting the PiSSA adapter to a normal one."""

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                model.save_pretrained(
                    pissa_convert_dir,
                    safe_serialization=args.save_safetensors,
                    path_initial_model_for_weight_conversion=pissa_init_dir,
                )
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)


class LogCallback(TrainerCallback):
    r"""A callback for logging training and evaluation status."""

    def __init__(self) -> None:
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional[ThreadPoolExecutor] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
        if self.webui_mode and not use_ray():
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss"),
            eval_loss=state.log_history[-1].get("eval_loss"),
            predict_loss=state.log_history[-1].get("predict_loss"),
            reward=state.log_history[-1].get("reward"),
            accuracy=state.log_history[-1].get("rewards/accuracies"),
            lr=state.log_history[-1].get("learning_rate"),
            epoch=state.log_history[-1].get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if is_env_enabled("RECORD_VRAM"):
            vram_allocated, vram_reserved = get_peak_memory()
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)


class ReporterCallback(TrainerCallback):
    r"""A callback for reporting training status to external logger."""

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(
                {
                    "model_args": self.model_args.to_dict(),
                    "data_args": self.data_args.to_dict(),
                    "finetuning_args": self.finetuning_args.to_dict(),
                    "generating_args": self.generating_args.to_dict(),
                }
            )

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore

            swanlab.config.update(
                {
                    "model_args": self.model_args.to_dict(),
                    "data_args": self.data_args.to_dict(),
                    "finetuning_args": self.finetuning_args.to_dict(),
                    "generating_args": self.generating_args.to_dict(),
                }
            )
src/llamafactory/train/dpo/__init__.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_dpo


__all__ = ["run_dpo"]
src/llamafactory/train/dpo/trainer.py (new file, mode 100644)
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin

    from ...hparams import FinetuningArguments


class CustomDPOTrainer(DPOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ):
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.f_divergence_type = "reverse_kl"
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False

        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # dpo hyperparams
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma
        self.ld_alpha = finetuning_args.ld_alpha

        Trainer.__init__(self, model=model, **kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def get_batch_samples(self, *args, **kwargs):
        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
        return Trainer.get_batch_samples(self, *args, **kwargs)

    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        sft_loss = -chosen_logps
        odds_ratio_loss = -F.logsigmoid(log_odds)
        orpo_loss = sft_loss + self.beta * odds_ratio_loss
        return orpo_loss

    def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""Compute SimPO loss for batched log probabilities of the policy model."""
        pi_logratios = chosen_logps - rejected_logps
        gamma_logratios = self.simpo_gamma / self.beta
        logits = pi_logratios - gamma_logratios
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",
        policy_rejected_logps: "torch.Tensor",
        reference_chosen_logps: Optional["torch.Tensor"],
        reference_rejected_logps: Optional["torch.Tensor"],
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute loss for preference learning."""
        if not self.finetuning_args.use_ref_model:
            if self.loss_type == "orpo":
                losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
            elif self.loss_type == "simpo":
                losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
            else:
                raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")

            chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )

        return losses, chosen_rewards, rejected_rewards

    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

        Otherwise the average log probabilities.
        """
        if self.finetuning_args.use_ref_model:
            batch = nested_detach(batch, clone=True)  # avoid error

        all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
        all_logps, valid_length = get_batch_logps(
            logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
        )
        if self.loss_type in ["ipo", "orpo", "simpo"]:
            all_logps = all_logps / valid_length

        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        chosen_length, _ = valid_length.split(batch_size, dim=0)

        if self.loss_type in ["ipo", "orpo", "simpo"]:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
        else:
            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

    @override
    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Compute log probabilities of the reference model."""
        if not self.finetuning_args.use_ref_model:
            return None, None

        if self.ref_model is None:
            ref_model = model
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
                ref_model, batch, is_ref_model=True
            )

        return reference_chosen_logps, reference_rejected_logps

    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: dict[str, "torch.Tensor"],
        train_eval: Literal["train", "eval"] = "train",
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_chosen_logps_avg,
        ) = self.concatenated_forward(model, batch)

        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
        losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
        sft_loss = -policy_chosen_logps_avg
        if self.ftx_gamma > 1e-6:
            losses += self.ftx_gamma * sft_loss

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
        metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
        if self.loss_type == "orpo":
            metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()

        return losses.mean(), metrics

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
        r"""Subclass and override to accept extra kwargs."""
        return super().compute_loss(model, inputs, return_outputs)

    @override
    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        r"""Log `logs` on the various objects watching training, including stored metrics."""
        # logs either has "loss" or "eval_loss"
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        key_list, metric_list = [], []
        for key, metrics in self._stored_metrics[train_eval].items():
            key_list.append(key)
            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())

        del self._stored_metrics[train_eval]
        if len(metric_list) < 10:  # pad to 10 entries for all-reduce
            for i in range(10 - len(metric_list)):
                key_list.append(f"dummy_{i}")
                metric_list.append(0.0)

        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
        metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
        for key, metric in zip(key_list, metric_list):  # add remaining items
            if not key.startswith("dummy_"):
                logs[key] = metric

        return Trainer.log(self, logs, *args, **kwargs)
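A quick numeric check of the SimPO branch above (the values are made up; `beta` and `simpo_gamma` normally come from `FinetuningArguments`):

import torch
import torch.nn.functional as F

beta, simpo_gamma = 2.0, 1.0
chosen_logps = torch.tensor([-12.5])
rejected_logps = torch.tensor([-15.0])
logits = (chosen_logps - rejected_logps) - simpo_gamma / beta  # pi_logratios - gamma_logratios
loss = -F.logsigmoid(beta * logits)  # ~tensor([0.0181]); small because chosen is clearly preferred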
src/llamafactory/train/dpo/workflow.py (new file, mode 100644)
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomDPOTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments


def run_dpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        model=model,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
    )

    # Create reference model
    if finetuning_args.use_ref_model:
        if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
            ref_model = model
        else:
            ref_model = create_ref_model(model_args, finetuning_args)
    else:
        ref_model = None

    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="rm"
            )

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss", "rewards/accuracies"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if id(model) == id(ref_model):  # unable to compute rewards if reference model is the model itself
            remove_keys = [key for key in metrics.keys() if "rewards" in key]
            for key in remove_keys:
                metrics.pop(key)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
src/llamafactory/train/grpo/__init__.py (new file, mode 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_grpo


__all__ = ["run_grpo"]
src/llamafactory/train/grpo/data.py (new file, mode 100644)
import re

from datasets import load_dataset, Dataset


SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""


def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")


def extract_deepseek_r1_answer(text) -> str | None:
    words_to_check = ["applied_math", "Advanced-Math", "GSM8K_zh", "EduChat-Math"]
    pattern = r'\b(' + '|'.join(map(re.escape, words_to_check)) + r')\b'
    has_match = bool(re.search(pattern, text['repo_name'], flags=re.IGNORECASE))
    if has_match:
        pattern = r"\\boxed\{(.*)\}"
        match = re.search(pattern, text['output'])
        if match:
            return match.group(1)
        else:
            return None
    else:
        return None


# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(dataset='openai/gsm8k', split="train") -> Dataset:
    data = load_dataset(dataset, 'main')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }, num_proc=16, remove_columns=["question"])  # type: ignore
    data = data.filter(lambda x: x['answer'] is not None, num_proc=16)
    # print("---", data[0])
    return data  # type: ignore


def get_deepseek_r1_questions(dataset='Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT', split="train") -> Dataset:
    data = load_dataset(dataset)[split]  # type: ignore
    data = data.map(
        lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['instruction']}
            ],
            'answer': extract_deepseek_r1_answer(x)
        },
        num_proc=16,  # type: ignore
        remove_columns=["instruction", "output", "repo_name", "prompt_tokens_len", "input",
                        "reasoning_content_tokens_len", "score", "content_tokens_len"],
    )
    data = data.filter(lambda x: x['answer'] is not None, num_proc=32)  # type: ignore
    print("GET {} data in Chinese-DeepSeek-R1-Distill-data-110k-SFT".format(len(data)))
    return data  # type: ignore


def get_hiyoga(dataset='hiyouga/math12k', split='train') -> Dataset:
    data = load_dataset(dataset)[split]  # type: ignore
    data = data.map(
        lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['problem']}
            ],
            'answer': x['answer']
        },
        remove_columns=["problem"],
        num_proc=16,
    )
    data = data.filter(lambda x: x['answer'] is not None, num_proc=16)
    # print(len(data))
    return data  # type: ignore


def get_unsloth_openmath(dataset="unsloth/OpenMathReasoning-mini", split='cot') -> Dataset:
    data = load_dataset(dataset)[split]
    data = data.map(
        lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['problem']}
            ],
            'answer': x['expected_answer']
        },
        remove_columns=["expected_answer", "problem_type", "problem_source", "generation_model",
                        "pass_rate_72b_tir", "generated_solution", "inference_mode", "problem",],
        num_proc=16,
    )
    data = data.filter(lambda x: x['answer'] is not None, num_proc=16)
    # print("len of unsloth", len(data))
    # print("=====", data)
    return data  # type: ignore


def get_openr1_dapo_math(dataset="open-r1/DAPO-Math-17k-Processed", split="train") -> Dataset:
    data = load_dataset(dataset, "all")[split]
    data = data.map(
        lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['prompt']}
            ],
            'answer': x['solution']
        },
        remove_columns=["solution", "data_source", "source_prompt", "ability", "reward_model", "extra_info"],
        num_proc=16,
    )
    data = data.filter(lambda x: x['answer'] is not None, num_proc=16)
    return data  # type: ignore
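Each loader above yields conversational records of the following shape (a sketch with illustrative values, not taken from the datasets):

# {
#     "prompt": [
#         {"role": "system", "content": SYSTEM_PROMPT},
#         {"role": "user", "content": "Natalia sold clips to 48 of her friends ..."},
#     ],
#     "answer": "72",
# }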
src/llamafactory/train/grpo/func.py (new file, mode 100644)
import re


def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


# Define format reward function
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    # print("matchs:", matches)
    return [0.5 if match else 0.0 for match in matches]


# Define accuracy reward function
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    q = prompts[0][-1]['content']
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # print('-'*20)
    # print(f"Question:\n{q}", f"\nLabel:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
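A small illustration of how the reward functions above score a single well-formed completion (toy input assumed for illustration, not part of the commit):

completion = [{"content": "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"}]
extract_xml_answer(completion[0]["content"])                        # "4"
soft_format_reward_func(completions=[completion])                   # [0.5]
strict_format_reward_func(completions=[completion])                 # [0.5]
int_reward_func(completions=[completion])                           # [0.5]
correctness_reward_func([[{"role": "user", "content": "2+2?"}]], [completion], ["4"])  # [2.0]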
src/llamafactory/train/grpo/metric.py (new file, mode 100644)
# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available

from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available


if TYPE_CHECKING:
    from transformers import EvalPrediction, PreTrainedTokenizer


if is_jieba_available():
    import jieba  # type: ignore


if is_nltk_available():
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore


if is_rouge_available():
    from rouge_chinese import Rouge  # type: ignore


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""Compute the token with the largest likelihood to reduce memory footprint."""
    if isinstance(logits, (list, tuple)):
        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
            logits = logits[0]
        else:  # moe models have aux loss
            logits = logits[1]

    if logits.dim() != 3:
        raise ValueError("Cannot process the logits.")

    return torch.argmax(logits, dim=-1)


@dataclass
class ComputeAccuracy:
    r"""Compute accuracy and support `batch_eval_metrics`."""

    def _dump(self) -> Optional[dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"accuracy": []}
        return result

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
        for i in range(len(preds)):
            pred, label = preds[i, :-1], labels[i, 1:]
            label_mask = label != IGNORE_INDEX
            self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))

        if compute_result:
            return self._dump()


@dataclass
class ComputeSimilarity:
    r"""Compute text similarity scores and support `batch_eval_metrics`.

    Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
    """

    tokenizer: "PreTrainedTokenizer"

    def _dump(self) -> Optional[dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
        return result

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)

        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))

            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
            else:
                rouge = Rouge()
                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
                result = scores[0]

            for k, v in result.items():
                self.score_dict[k].append(round(v["f"] * 100, 4))

            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        if compute_result:
            return self._dump()
src/llamafactory/train/grpo/workflow.py (new file, mode 100644)
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger

# from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor

from trl import GRPOConfig, GRPOTrainer
from .data import *
from .func import *


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_grpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    # dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)

    ## load datasets
    train_dataset = []
    eval_dataset = []
    print("training datasets", data_args.dataset)
    datasets_list = data_args.dataset
    for indx, dataset in enumerate(datasets_list):
        dataset = dataset.strip()
        logger.info("[{}/{}] dealing with {}".format(indx + 1, len(datasets_list), dataset))
        if "hiyouga-math12k" in dataset:
            func = get_hiyoga
            eval_dataset.extend(get_hiyoga(split="test"))
        elif "openai/gsm8k" in dataset:
            func = get_gsm8k_questions
            eval_dataset.extend(get_gsm8k_questions(split="test"))
        elif "Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT" in dataset:
            func = get_deepseek_r1_questions
        elif "OpenMathReasoning-mini" in dataset:
            func = get_unsloth_openmath
        elif "dapo_math" in dataset:
            func = get_openr1_dapo_math

        train_dataset.extend(func())

    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    grpo_training_args = GRPOConfig(
        do_train=True,
        learning_rate=training_args.learning_rate,
        per_device_train_batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        num_train_epochs=training_args.num_train_epochs,
        seed=training_args.seed,
        num_generations=8,
        lr_scheduler_type=training_args.lr_scheduler_type,
        adam_beta1=0.9,
        adam_beta2=0.99,
        adam_epsilon=1e-08,
        weight_decay=training_args.weight_decay,
        warmup_ratio=training_args.warmup_ratio,
        logging_steps=training_args.logging_steps,
        bf16=True,
        save_strategy="steps",
        save_steps=training_args.save_steps,
        output_dir=training_args.output_dir,
        max_prompt_length=1024,
        max_completion_length=2048,
        max_grad_norm=0.1,
        ddp_timeout=1800000,
        temperature=generating_args.temperature,
        top_p=generating_args.top_p,
        top_k=generating_args.top_k,
        repetition_penalty=generating_args.repetition_penalty,
        loss_type='grpo',
        use_vllm=True,
        vllm_mode="server",  # default value, can be omitted
        vllm_server_base_url=finetuning_args.vllm_server_base_url,
        report_to="none",
        deepspeed=training_args.deepspeed,
    )

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=grpo_training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=None,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        # trainer.save_model()
        trainer.save_state()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss"])

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
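Note (not part of the commit): the reward functions passed to `GRPOTrainer` above come from `.func` via a wildcard import and are not shown on this page. As a rough, hypothetical illustration of the interface such functions generally follow, each one scores a batch of completions and returns one float per completion; a made-up `length_penalty_reward_func` might look like:

# Hypothetical reward function: receives the generated completions (plus any
# extra dataset columns as keyword arguments) and returns one float score per
# completion. The length cutoff here is purely illustrative.
def length_penalty_reward_func(completions, **kwargs):
    rewards = []
    for completion in completions:
        # Conversational datasets wrap each completion as a list of messages.
        text = completion[0]["content"] if isinstance(completion, list) else completion
        rewards.append(1.0 if len(text) <= 1024 else 0.0)
    return rewards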
src/llamafactory/train/kto/__init__.py 0 → 100644
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_kto


__all__ = ["run_kto"]
src/llamafactory/train/kto/trainer.py 0 → 100644
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin

    from ...hparams import FinetuningArguments


class CustomKTOTrainer(KTOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ):
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # kto hyperparams
        self.beta = finetuning_args.pref_beta
        self.desirable_weight = finetuning_args.kto_chosen_weight
        self.undesirable_weight = finetuning_args.kto_rejected_weight
        self.ftx_gamma = finetuning_args.pref_ftx

        Trainer.__init__(self, model=model, **kwargs)
        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return Trainer._get_train_sampler(self, *args, **kwargs)

    @override
    def get_batch_samples(self, *args, **kwargs):
        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
        return Trainer.get_batch_samples(self, *args, **kwargs)

    @override
    def forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Run forward pass and computes the log probabilities."""
        batch = nested_detach(batch, clone=True)  # avoid error
        model_inputs = {
            "input_ids": batch[f"{prefix}input_ids"],
            "attention_mask": batch[f"{prefix}attention_mask"],
        }
        if f"{prefix}token_type_ids" in batch:
            model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]

        if "pixel_values" in batch:
            model_inputs["pixel_values"] = batch["pixel_values"]

        if "image_sizes" in batch:
            model_inputs["image_sizes"] = batch["image_sizes"]

        if "image_grid_thw" in batch:
            model_inputs["image_grid_thw"] = batch["image_grid_thw"]

        if "aspect_ratio_ids" in batch:
            model_inputs["aspect_ratio_ids"] = batch["aspect_ratio_ids"]

        if "aspect_ratio_mask" in batch:
            model_inputs["aspect_ratio_mask"] = batch["aspect_ratio_mask"]

        if f"{prefix}cross_attention_mask" in batch:
            model_inputs["cross_attention_mask"] = batch[f"{prefix}cross_attention_mask"]

        logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
        logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
        return logits, logps, logps / valid_length

    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        target_logits, target_logps, target_logps_avg = self.forward(model, batch)
        with torch.no_grad():
            _, kl_logps, _ = self.forward(model, batch, prefix="kl_")

        if len(target_logps) != len(batch["kto_tags"]):
            raise ValueError("Mismatched shape of inputs and labels.")

        chosen_logits = target_logits[batch["kto_tags"]]
        chosen_logps = target_logps[batch["kto_tags"]]
        rejected_logits = target_logits[~batch["kto_tags"]]
        rejected_logps = target_logps[~batch["kto_tags"]]
        chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg

    @override
    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""Compute log probabilities of the reference model."""
        if self.ref_model is None:
            ref_model = model
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
                ref_model, batch
            )

        return reference_chosen_logps, reference_rejected_logps, reference_kl_logps

    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: dict[str, "torch.Tensor"],
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
        r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_kl_logps,
            policy_chosen_logps_avg,
        ) = self.concatenated_forward(model, batch)
        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
            model, batch
        )
        losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            policy_kl_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_kl_logps,
        )
        losses = losses.nanmean()

        if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0:  # remember to rescale
            sft_loss = -policy_chosen_logps_avg
            losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

        num_chosen = len(chosen_rewards)
        num_rejected = len(rejected_rewards)
        if num_chosen > 0:
            metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
            metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
            metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
            metrics["count/chosen"] = float(num_chosen)

        if num_rejected > 0:
            metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
            metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
            metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
            metrics["count/rejected"] = float(num_rejected)

        metrics["kl"] = kl.item()
        return losses, metrics

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
        r"""Subclass and override to accept extra kwargs."""
        return super().compute_loss(model, inputs, return_outputs)

    @override
    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        r"""Log `logs` on the various objects watching training, including stored metrics."""
        # logs either has "loss" or "eval_loss"
        train_eval = "train" if "loss" in logs else "eval"
        prefix = "eval_" if train_eval == "eval" else ""
        # Add averaged stored metrics to logs
        key_list, metric_list = [], []
        for key, metrics in self._stored_metrics[train_eval].items():
            key_list.append(key)
            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())

        del self._stored_metrics[train_eval]
        if len(metric_list) < 9:  # pad to for all reduce
            for i in range(9 - len(metric_list)):
                key_list.append(f"dummy_{i}")
                metric_list.append(0.0)

        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
        metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
        metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
        for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
            if f"count/{split}" in metric_dict:
                for key in ("rewards", "logps", "logits"):
                    logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
                    del metric_dict[f"{key}/{split}_sum"]

                del metric_dict[f"count/{split}"]

        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]

        for key, metric in metric_dict.items():  # add remaining items
            if not key.startswith("dummy_"):
                logs[key] = metric

        return Trainer.log(self, logs, *args, **kwargs)
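Note (not part of the commit): the `log()` override above avoids averaging rank-local means by storing per-rank sums and counts, all-reducing them, and only then dividing. A minimal single-process sketch of that sum-and-count averaging, using a hypothetical helper and no accelerator:

def average_from_sums(metric_dict, split):
    # Divide the reduced sums by the reduced count to recover per-sample means.
    logs = {}
    count = metric_dict.get(f"count/{split}", 0.0)
    if count > 0:
        for key in ("rewards", "logps", "logits"):
            logs[f"{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / count
    return logs

example = {"rewards/chosen_sum": 12.0, "logps/chosen_sum": -30.0, "logits/chosen_sum": 4.0, "count/chosen": 4.0}
print(average_from_sums(example, "chosen"))  # {'rewards/chosen': 3.0, 'logps/chosen': -7.5, 'logits/chosen': 1.0}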
src/llamafactory/train/kto/workflow.py 0 → 100644
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomKTOTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments


def run_kto(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = KTODataCollatorWithPadding(
        template=template,
        model=model,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
    )

    # Create reference model
    if finetuning_args.ref_model is None and (not training_args.do_train):  # use the model itself
        ref_model = model
    else:
        ref_model = create_ref_model(model_args, finetuning_args)

    # Initialize our Trainer
    trainer = CustomKTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss", "rewards/chosen"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]

            plot_loss(training_args.output_dir, keys=keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if id(model) == id(ref_model):  # unable to compute rewards without a reference model
            remove_keys = [key for key in metrics.keys() if "rewards" in key]
            for key in remove_keys:
                metrics.pop(key)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
src/llamafactory/train/ppo/__init__.py 0 → 100644
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_ppo


__all__ = ["run_ppo"]
src/llamafactory/train/ppo/ppo_utils.py 0 → 100644
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Literal, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.packages import is_requests_available


if is_requests_available():
    import requests


if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead


def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
    r"""Get reward scores from the API server."""
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}
    response = requests.post(server_url, json=payload, headers=headers)
    rewards = json.loads(response.text)["scores"]
    return torch.Tensor(rewards)


def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
    r"""Replace the default/reward modules in the model. The model is already unwrapped."""
    v_head_layer = model.v_head.summary
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [v_head_layer.weight, v_head_layer.bias]
        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
    with context_maybe_zero3:
        if target == "reward":  # save default head temporarily
            setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone())
            setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())

        device = v_head_layer.weight.device
        v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
        v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)


def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
    r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
    layer_norm_params = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            layer_norm_params[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)

    return layer_norm_params


def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
    r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]
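Note (not part of the commit): `get_rewards_from_server()` above assumes a very small HTTP contract, the reward server accepts a JSON body with `model` and `messages` fields and replies with a JSON body containing a `scores` list, one float per message. A sketch of that contract with illustrative values only:

import json

# What the client sends (one decoded prompt-plus-response string per element):
example_request = {"model": "model", "messages": ["<prompt and response text>"]}

# What the server is expected to return:
example_response = '{"scores": [0.42]}'
scores = json.loads(example_response)["scores"]
print(scores)  # [0.42]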
src/llamafactory/train/ppo/trainer.py 0 → 100644
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional

import torch
from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer import DEFAULT_CALLBACKS
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override

from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm


if TYPE_CHECKING:
    from datasets import Dataset
    from transformers import (
        DataCollatorWithPadding,
        PreTrainedTokenizer,
        ProcessorMixin,
        Seq2SeqTrainingArguments,
        TrainerCallback,
    )
    from trl import AutoModelForCausalLMWithValueHead

    from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""Inherit PPOTrainer."""

    def __init__(
        self,
        model_args: "ModelArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: Optional[list["TrainerCallback"]],
        model: "AutoModelForCausalLMWithValueHead",
        reward_model: Optional["AutoModelForCausalLMWithValueHead"],
        ref_model: Optional["AutoModelForCausalLMWithValueHead"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        data_collator: "DataCollatorWithPadding",
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
    ) -> None:
        if eval_dataset is not None:
            raise NotImplementedError("PPOTrainer does not support eval dataset yet.")

        backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
        ppo_config = PPOConfig(
            model_name=model_args.model_name_or_path,
            learning_rate=training_args.learning_rate,
            mini_batch_size=training_args.per_device_train_batch_size,
            batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
            gradient_accumulation_steps=training_args.gradient_accumulation_steps,
            ppo_epochs=finetuning_args.ppo_epochs,
            max_grad_norm=training_args.max_grad_norm,
            seed=training_args.seed,
            optimize_device_cache=True,
            target=finetuning_args.ppo_target,
            use_score_scaling=finetuning_args.ppo_score_norm,
            use_score_norm=finetuning_args.ppo_score_norm,
            whiten_rewards=finetuning_args.ppo_whiten_rewards,
            accelerator_kwargs={"step_scheduler_with_optimizer": False},
            log_with=training_args.report_to[0] if training_args.report_to else None,
            project_kwargs={"logging_dir": training_args.logging_dir},
        )

        # Add deepspeed config
        if training_args.deepspeed_plugin is not None:
            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
            ]
            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
            if ppo_config.log_with is not None:
                logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                ppo_config.log_with = None

        # Create optimizer and scheduler
        if training_args.max_steps > 0:
            num_training_steps = training_args.max_steps
        else:
            total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
            num_training_steps = training_args.num_train_epochs * math.ceil(
                len(train_dataset) / total_train_batch_size
            )

        optimizer = self.create_optimizer(model, training_args, finetuning_args)
        scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)

        PPOTrainer.__init__(
            self,
            config=ppo_config,
            model=model,
            ref_model=ref_model,
            tokenizer=tokenizer,
            dataset=train_dataset,
            optimizer=optimizer,
            data_collator=data_collator,
            lr_scheduler=scheduler,
        )

        self.args = training_args
        self.model_args = model_args
        self.finetuning_args = finetuning_args
        self.reward_model = reward_model
        self.current_device = get_current_device()  # patch for deepspeed training
        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            **generating_args.to_dict(),
        )
        self.state = TrainerState()
        self.control = TrainerControl()
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
        )
        if self.args.max_steps > 0:
            logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")

        self.amp_context = torch.autocast(self.current_device.type)
        warnings.simplefilter("ignore")  # remove gc warnings on ref model

        if finetuning_args.reward_model_type == "full":
            if self.is_deepspeed_enabled:
                if not (
                    getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                    or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.reward_model = self._prepare_deepspeed(self.reward_model)
            else:
                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)

        self.add_callback(FixValueHeadModelCallback)
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
        if resume_from_checkpoint is not None:
            raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

        total_train_batch_size = (
            self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps
            * self.finetuning_args.ppo_buffer_size
            * self.args.world_size
        )
        if self.args.max_steps > 0:
            num_examples = total_train_batch_size * self.args.max_steps
            num_train_epochs = sys.maxsize
            max_steps = self.args.max_steps
            steps_in_epoch = self.args.max_steps
        else:
            len_dataloader = len(self.dataloader)
            num_examples = len(self.dataset)
            num_train_epochs = self.args.num_train_epochs
            max_steps = math.ceil(num_train_epochs * len_dataloader)
            steps_in_epoch = len_dataloader

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        logger.info_rank0("***** Running training *****")
        logger.info_rank0(f"  Num examples = {num_examples:,}")
        logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
        logger.info_rank0(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        logger.info_rank0(
            f"  Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
        )
        logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
        logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
        logger.info_rank0(f"  Total training steps = {max_steps:,}")
        logger.info_rank0(f"  Number of trainable parameters = {count_parameters(self.model)[0]:,}")

        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.callback_handler.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            try:
                batch = next(dataiter)
            except StopIteration:
                dataiter = iter(self.dataloader)
                batch = next(dataiter)

            # Get inputs
            self.model.eval()
            self.tokenizer.padding_side = "right"  # change padding side
            queries, responses, rewards = [], [], []
            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                mini_batch = {
                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
                }
                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
                queries.extend(mini_batch_queries)
                responses.extend(mini_batch_responses)
                rewards.extend(mini_batch_rewards)

            # Run PPO step
            self.model.train()
            stats = self.step(queries, responses, rewards)
            self.tokenizer.padding_side = "left"  # restore padding side
            loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

            if self.config.log_with is not None:
                try:
                    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:
                    logger.warning_rank0("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.callback_handler.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                logs = dict(
                    loss=round(loss_meter.avg, 4),
                    reward=round(reward_meter.avg, 4),
                    learning_rate=stats["ppo/learning_rate"],
                    epoch=round(step / steps_in_epoch, 2),
                )
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.callback_handler.on_log(self.args, self.state, self.control, logs)
                loss_meter.reset()
                reward_meter.reset()

            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                self.save_model(
                    os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
                )
                self.callback_handler.on_save(self.args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

        self.callback_handler.on_train_end(self.args, self.state, self.control)

    @override
    def create_optimizer(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
    ) -> "torch.optim.Optimizer":
        optimizer = create_custom_optimizer(model, training_args, finetuning_args)
        if optimizer is None:
            decay_params, nodecay_params = [], []
            decay_param_names = self.get_decay_parameter_names(model)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if name in decay_param_names:
                        decay_params.append(param)
                    else:
                        nodecay_params.append(param)

            optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
            param_groups = [
                dict(params=nodecay_params),
                dict(params=decay_params, weight_decay=training_args.weight_decay),
            ]
            optimizer = optim_class(param_groups, **optim_kwargs)

        return optimizer

    @override
    def create_scheduler(
        self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(training_args, num_training_steps, optimizer)
        lr_scheduler = get_scheduler(
            training_args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
        return lr_scheduler

    @torch.no_grad()
    def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
        r"""Generate model's responses given queries."""
        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
            for k, v in batch.items():
                batch[k] = v[:, start_index:]

        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
            if self.model_args.upcast_layernorm:
                layernorm_params = dump_layernorm(unwrapped_model)

            generate_output: torch.Tensor = unwrapped_model.generate(
                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
            )
            if self.model_args.upcast_layernorm:
                restore_layernorm(unwrapped_model, layernorm_params)

        query = batch["input_ids"].detach().cpu()
        response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
        queries, responses = [], []
        for i in range(len(query)):
            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
            response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()

            if len(response_indexes) == 0:  # allow empty response
                response_length = 1
            elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id:  # include eos token
                response_length = response_indexes[-1].item() + 2
            else:
                response_length = response_indexes[-1].item() + 1

            queries.append(query[i, query_start_index:])  # remove padding from left
            responses.append(response[i, :response_length])  # remove padding from right

        return queries, responses

    @torch.no_grad()
    def get_rewards(
        self,
        queries: list["torch.Tensor"],
        responses: list["torch.Tensor"],
    ) -> list["torch.Tensor"]:
        r"""Compute scores using given reward model.

        Both inputs and outputs are put on CPU.
        """
        if self.finetuning_args.reward_model_type == "api":
            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
            return get_rewards_from_server(self.reward_model, messages)

        batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
        unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
            reward_model = self.reward_model

        with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
            values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="default")

        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return rewards.float().detach()  # use fp32 type

    @override
    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: "torch.Tensor",
        responses: "torch.Tensor",
        model_inputs: dict[str, Any],
        return_logits: bool = False,
        response_masks: Optional["torch.Tensor"] = None,
    ) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
        r"""Calculate model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]

            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

            with self.amp_context:  # support bf16
                logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0].item()

                end = start + len(response_batch[j])
                if response_masks is not None:
                    response_masks_batch = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
            else:
                del logits

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    @override
    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""Save model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        if self.is_fsdp_enabled or self.is_deepspeed_enabled:
            try:
                state_dict = self.accelerator.get_state_dict(self.model)  # must be called at all ranks
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
                logger.warning_rank0(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                    " use zero_to_fp32.py to recover weights"
                )
                if self.args.should_save:
                    self._save(output_dir, state_dict={})
                # remove the dummy state_dict
                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                self.model.save_checkpoint(output_dir)

        elif self.args.should_save:
            unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
            self._save(output_dir, state_dict=unwrapped_model.state_dict())
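Note (not part of the commit): in `get_rewards()` above, the value head produces one score per token, and the sequence-level reward is read off at the last non-padding position. A minimal tensor sketch of that gather, with toy values and right padding assumed:

import torch

values = torch.tensor([[0.1, 0.2, 0.3, 0.0],
                       [0.5, 0.6, 0.0, 0.0]])  # one value per token, shape [batch, seq_len]
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
last_token = attention_mask.sum(dim=-1, keepdim=True) - 1   # index of the last real token per row
rewards = values.gather(dim=-1, index=last_token)           # -> [[0.3], [0.6]]
print(rewards)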
src/llamafactory/train/ppo/workflow.py 0 → 100644
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


def run_ppo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)

    # Create reference model and reward model
    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
    reward_model = create_reward_model(model, model_args, finetuning_args)

    # Initialize our Trainer
    ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
        model_args=model_args,
        training_args=training_args,
        finetuning_args=finetuning_args,
        generating_args=generating_args,
        callbacks=callbacks,
        model=model,
        reward_model=reward_model,
        ref_model=ref_model,
        data_collator=data_collator,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        ppo_trainer.save_model()

        if training_args.should_save:
            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)

        ppo_trainer.save_state()  # must be called after save_model to have a folder
        if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "reward"])
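Note (not part of the commit): the `tokenizer.padding_side` comment above is the key detail of this workflow, prompts are left-padded for generation so that every row ends with real tokens, while training batches stay right-padded. A toy illustration with pad id 0:

# Toy token ids, pad id = 0. With left padding the last column of every row is
# a real token, which is what autoregressive generation continues from; with
# right padding the real tokens start at position 0, which suits training.
right_padded = [[5, 6, 7, 0], [8, 9, 0, 0]]  # training layout
left_padded = [[0, 5, 6, 7], [0, 0, 8, 9]]   # generation layout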
src/llamafactory/train/pt/__init__.py 0 → 100644
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_pt


__all__ = ["run_pt"]
src/llamafactory/train/pt/trainer.py 0 → 100644
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType
from typing import TYPE_CHECKING, Optional

import torch
from transformers import Trainer
from typing_extensions import override

from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from ...hparams import FinetuningArguments


class CustomTrainer(Trainer):
    r"""Inherit Trainer for custom optimizer."""

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(**kwargs)
        if processor is not None:
            # avoid wrong loss under gradient accumulation
            # https://github.com/huggingface/transformers/pull/36044#issuecomment-2746657112
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        return super().compute_loss(model, inputs, *args, **kwargs)