Project: zhougaofeng / internlm2-math-7B
"deploy/paddle2onnx/predict_rec.py" did not exist on "c708041e13ef55e14b630edd0f22bfb74b10a6aa"
Commit 34ce9ee3 ("Upload New File"), authored Jun 11, 2024 by zhougaofeng. Parent: 9067e0a4. Pipeline #1146 canceled with stages. Changes: 1. Pipelines: 1.
Showing 1 changed file with 215 additions and 0 deletions.
src/llmfactory/extras/callbacks.py (new file, 0 → 100644)
import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import transformers
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import TRAINER_LOG
from .logging import LoggerHandler, get_logger
from .misc import fix_valuehead_checkpoint


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)
class FixValueHeadModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )
class LogCallback(TrainerCallback):
    def __init__(self, output_dir: str) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.
        """
        """ Progress """
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        """ Status """
        self.aborted = False
        self.do_train = False
        """ Web UI """
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = LoggerHandler(output_dir)
            logging.root.addHandler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)
    def _set_abort(self, signum, frame) -> None:
        # Signal handler: mark the run as aborted so the training/prediction loops can stop gracefully.
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        # Reset progress counters and timers at the start of a run.
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        # Update elapsed time and estimate remaining time from the average time per step.
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        # Append one JSON object per line to the trainer log file.
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        # Single-worker pool so log writes happen off the training loop, in order.
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        # Flush pending log writes and release the pool.
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of the initialization of the `Trainer`.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        if not self.do_train:
            self._close_thread_pool()

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.
        """
        if not self.do_train:
            self._close_thread_pool()
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            reward=state.log_history[-1].get("reward", None),
            accuracy=state.log_history[-1].get("rewards/accuracies", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"], logs["learning_rate"], logs["epoch"]
                )
            )

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)
    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a prediction step.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
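
Usage note: both callbacks are intended to be handed to a Hugging Face Trainer through its callbacks argument. FixValueHeadModelCallback repairs value-head weights whenever a checkpoint is saved, and LogCallback appends JSON-lines progress records to the file named by TRAINER_LOG under args.output_dir. The following is a minimal sketch of that wiring, assuming a model and dataset have already been prepared elsewhere; my_model and my_dataset are hypothetical placeholders, and the repository's real training entry point may differ.

# Illustrative sketch only: attaching the callbacks above to a Hugging Face Trainer.
# `my_model` and `my_dataset` are hypothetical placeholders, not names from this commit.
from transformers import Trainer, TrainingArguments

from llmfactory.extras.callbacks import FixValueHeadModelCallback, LogCallback

training_args = TrainingArguments(output_dir="outputs/demo-run", save_safetensors=True)

trainer = Trainer(
    model=my_model,            # placeholder: a prepared model (e.g. one with a value head)
    args=training_args,
    train_dataset=my_dataset,  # placeholder: a prepared dataset
    callbacks=[
        LogCallback(output_dir=training_args.output_dir),  # streams JSON-lines progress to TRAINER_LOG
        FixValueHeadModelCallback(),                       # only relevant for value-head (reward/PPO) models
    ],
)
trainer.train()

With this wiring, each on_log event appends one JSON object per line (current_steps, total_steps, loss, percentage, elapsed_time, remaining_time, and so on), which a front end such as the web UI gated by the LLAMABOARD_ENABLED environment variable can poll to display training progress.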