Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c53cc018
Commit
c53cc018
authored
Apr 23, 2020
by
Julien Chaumond
Browse files
[Trainer] Fix _rotate_checkpoints
Close #3920
parent
cbbb3c43
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
5 deletions
+3
-5
src/transformers/trainer.py
src/transformers/trainer.py
+3
-5
No files found.
src/transformers/trainer.py
View file @
c53cc018
...
@@ -434,13 +434,13 @@ class Trainer:
...
@@ -434,13 +434,13 @@ class Trainer:
def
_sorted_checkpoints
(
self
,
checkpoint_prefix
=
PREFIX_CHECKPOINT_DIR
,
use_mtime
=
False
)
->
List
[
str
]:
def
_sorted_checkpoints
(
self
,
checkpoint_prefix
=
PREFIX_CHECKPOINT_DIR
,
use_mtime
=
False
)
->
List
[
str
]:
ordering_and_checkpoint_path
=
[]
ordering_and_checkpoint_path
=
[]
glob_checkpoints
=
Path
(
self
.
args
.
output_dir
).
glob
(
f
"
{
checkpoint_prefix
}
-*"
)
glob_checkpoints
=
[
str
(
x
)
for
x
in
Path
(
self
.
args
.
output_dir
).
glob
(
f
"
{
checkpoint_prefix
}
-*"
)
]
for
path
in
glob_checkpoints
:
for
path
in
glob_checkpoints
:
if
use_mtime
:
if
use_mtime
:
ordering_and_checkpoint_path
.
append
((
os
.
path
.
getmtime
(
path
),
path
))
ordering_and_checkpoint_path
.
append
((
os
.
path
.
getmtime
(
path
),
path
))
else
:
else
:
regex_match
=
re
.
match
(
".*{
}-([0-9]+)"
.
format
(
checkpoint_prefix
)
,
path
)
regex_match
=
re
.
match
(
f
".*
{
checkpoint_prefix
}
-([0-9]+)"
,
path
)
if
regex_match
and
regex_match
.
groups
():
if
regex_match
and
regex_match
.
groups
():
ordering_and_checkpoint_path
.
append
((
int
(
regex_match
.
groups
()[
0
]),
path
))
ordering_and_checkpoint_path
.
append
((
int
(
regex_match
.
groups
()[
0
]),
path
))
...
@@ -449,9 +449,7 @@ class Trainer:
...
@@ -449,9 +449,7 @@ class Trainer:
return
checkpoints_sorted
return
checkpoints_sorted
def
_rotate_checkpoints
(
self
,
use_mtime
=
False
)
->
None
:
def
_rotate_checkpoints
(
self
,
use_mtime
=
False
)
->
None
:
if
not
self
.
args
.
save_total_limit
:
if
self
.
args
.
save_total_limit
is
None
or
self
.
args
.
save_total_limit
<=
0
:
return
if
self
.
args
.
save_total_limit
<=
0
:
return
return
# Check if we should delete older checkpoint(s)
# Check if we should delete older checkpoint(s)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment