chenpangpang/transformers · Commits · a515caa3
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ec07da65e25562040581febaf9b400a462962961"
Unverified commit a515caa3, authored May 18, 2021 by Sylvain Gugger, committed by GitHub on May 18, 2021
Fix checkpoint deletion (#11748)
parent b88e0e01
Showing 2 changed files with 49 additions and 10 deletions:

src/transformers/trainer.py   +18 -10
tests/test_trainer.py         +31 -0
src/transformers/trainer.py (view file @ a515caa3)
@@ -1523,10 +1523,6 @@ class Trainer:
         if self.is_world_process_zero():
             self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
 
-        # Maybe delete some older checkpoints.
-        if self.is_world_process_zero():
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
-
         # Save RNG state in non-distributed training
         rng_states = {
             "python": random.getstate(),
@@ -1552,6 +1548,10 @@ class Trainer:
         else:
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
 
+        # Maybe delete some older checkpoints.
+        if self.is_world_process_zero():
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         if checkpoint is None:
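
Taken together, these first two hunks move the rotation step rather than change it: _rotate_checkpoints now runs after the trainer state and RNG state have been written, so older checkpoints are deleted only once the new checkpoint is complete on disk. A condensed sketch of the resulting control flow (not the real method, which also saves the model, optimizer, and scheduler; save_rng_states is a hypothetical stand-in for the RNG-saving block):

    import os

    def save_checkpoint_sketch(trainer, output_dir, run_dir):
        # Write every artifact of the new checkpoint first.
        if trainer.is_world_process_zero():
            trainer.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
        save_rng_states(output_dir)  # hypothetical stand-in for the RNG-saving block
        # Only then rotate, deleting older checkpoints.
        if trainer.is_world_process_zero():
            trainer._rotate_checkpoints(use_mtime=True, output_dir=run_dir)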
@@ -1924,7 +1924,7 @@ class Trainer:
                 ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
             else:
                 regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
-                if regex_match and regex_match.groups():
+                if regex_match is not None and regex_match.groups() is not None:
                     ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
 
         checkpoints_sorted = sorted(ordering_and_checkpoint_path)
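
The one-line change above only makes the guard explicit; the behavior matters when a path in the output directory is not a checkpoint, since re.match then returns None and the path must be skipped. A standalone sketch with made-up paths:

    import re

    checkpoint_prefix = "checkpoint"
    for path in ["output/checkpoint-500", "output/runs"]:  # hypothetical paths
        regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
        if regex_match is not None and regex_match.groups() is not None:
            # Only "output/checkpoint-500" matches, yielding 500 as its sort key.
            print(path, "->", int(regex_match.groups()[0]))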
@@ -1932,10 +1932,8 @@ class Trainer:
         # Make sure we don't delete the best model.
         if self.state.best_model_checkpoint is not None:
             best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
-            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
-                checkpoints_sorted[-1],
-                checkpoints_sorted[best_model_index],
-            )
+            for i in range(best_model_index, len(checkpoints_sorted) - 2):
+                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
         return checkpoints_sorted
 
     def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
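
The dropped swap put the best checkpoint at the very end of the sorted list, displacing the most recent checkpoint into deletion range; the new loop instead shifts the best checkpoint into the second-to-last slot while keeping every other entry in chronological order, so rotation preserves both the best and the latest checkpoint. A standalone walk-through using the same fake checkpoints as the new test, with checkpoint-5 as the best model:

    checkpoints_sorted = ["checkpoint-5", "checkpoint-10", "checkpoint-15", "checkpoint-20", "checkpoint-25"]
    best_model_index = 0  # checkpoint-5 is the tracked best model

    # Bubble the best checkpoint rightward until it sits second-to-last.
    for i in range(best_model_index, len(checkpoints_sorted) - 2):
        checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]

    print(checkpoints_sorted)
    # ['checkpoint-10', 'checkpoint-15', 'checkpoint-20', 'checkpoint-5', 'checkpoint-25']

With save_total_limit=2, rotation deletes from the front and keeps exactly checkpoint-5 and checkpoint-25, matching the [5, 25] expectation in the new test below.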
@@ -1947,7 +1945,17 @@ class Trainer:
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return
 
-        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+        # we don't do to allow resuming.
+        save_total_limit = self.args.save_total_limit
+        if (
+            self.state.best_model_checkpoint is not None
+            and self.args.save_total_limit == 1
+            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+        ):
+            save_total_limit = 2
+
+        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
         checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
         for checkpoint in checkpoints_to_be_deleted:
             logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
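
A standalone sketch of the new limit bump, continuing from the reordered list in the previous example: honoring save_total_limit=1 here would delete checkpoint-25, the only checkpoint a run could resume from, so the effective limit is raised to 2 whenever the best checkpoint is not also the newest one.

    checkpoints_sorted = ["checkpoint-10", "checkpoint-15", "checkpoint-20", "checkpoint-5", "checkpoint-25"]
    best_model_checkpoint = "checkpoint-5"  # the tracked best model
    save_total_limit = 1

    # Best checkpoint is not the newest: keep two (best + latest) instead of one.
    if best_model_checkpoint is not None and save_total_limit == 1 and checkpoints_sorted[-1] != best_model_checkpoint:
        save_total_limit = 2

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    print(checkpoints_sorted[:number_of_checkpoints_to_delete])
    # ['checkpoint-10', 'checkpoint-15', 'checkpoint-20']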
tests/test_trainer.py (view file @ a515caa3)
@@ -21,6 +21,7 @@ import random
 import re
 import tempfile
 import unittest
+from pathlib import Path
 
 import numpy as np
@@ -45,6 +46,7 @@ from transformers.testing_utils import (
     require_torch_multi_gpu,
     slow,
 )
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils.hp_naming import TrialShortNamer
@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer.train()
             self.assertTrue(isinstance(trainer.state.total_flos, float))
 
+    def check_checkpoint_deletion(self, trainer, output_dir, expected):
+        # Make fake checkpoints
+        for n in [5, 10, 15, 20, 25]:
+            os.makedirs(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{n}"), exist_ok=True)
+        trainer._rotate_checkpoints(output_dir=output_dir)
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")]
+        values = [int(re.match(f".*{PREFIX_CHECKPOINT_DIR}-([0-9]+)", d).groups()[0]) for d in glob_checkpoints]
+        self.assertSetEqual(set(values), set(expected))
+
+    def test_checkpoint_rotation(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Without best model at end
+            trainer = get_regression_trainer(output_dir=tmp_dir, save_total_limit=2)
+            self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
+
+            # With best model at end
+            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
+
+            # Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True to be able to resume
+            # from checkpoint
+            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [25])
+
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
+
     def check_mem_metrics(self, trainer, check_func):
         metrics = trainer.train().metrics
         check_func("init_mem_cpu_alloc_delta", metrics)
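
The helper fabricates checkpoint directories directly and calls _rotate_checkpoints, so the test exercises the rotation logic without running any training. To run just the new test from a transformers checkout (assuming a development install with pytest):

    python -m pytest tests/test_trainer.py -k test_checkpoint_rotation -v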