Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
f6a5a54e
"docs/vscode:/vscode.git/clone" did not exist on "7a4845a99f845d4e2fc3ceb5b6cdf7fed29dc662"
Commit
f6a5a54e
authored
May 04, 2018
by
Alexei Baevski
Committed by
Myle Ott
Jun 15, 2018
Browse files
add support for averaging last n checkpoints
parent
23211c45
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
0 deletions
+28
-0
scripts/average_checkpoints.py
scripts/average_checkpoints.py
+28
-0
No files found.
scripts/average_checkpoints.py
View file @
f6a5a54e
...
...
@@ -3,6 +3,8 @@
import
argparse
import
collections
import
torch
import
os
import
re
def
average_checkpoints
(
inputs
):
...
...
@@ -60,6 +62,22 @@ def average_checkpoints(inputs):
return
new_state
def last_n_checkpoints(paths, n):
    """Return the paths of the ``n`` highest-numbered checkpoints in a directory.

    Files are selected when their whole name matches ``checkpoint<digits>.pt``
    (for example ``checkpoint12.pt``); the digits are compared as integers, so
    ``checkpoint10.pt`` ranks above ``checkpoint9.pt``.

    Args:
        paths: list containing exactly one directory path to scan.
        n: number of checkpoint files to select.

    Returns:
        A list of ``n`` paths (directory joined with file name), ordered from
        the highest checkpoint number to the lowest.

    Raises:
        AssertionError: if ``paths`` does not contain exactly one entry.
        Exception: if the directory holds fewer than ``n`` matching files.
    """
    assert len(paths) == 1, \
        'expected exactly one input directory, got {}'.format(len(paths))
    path = paths[0]
    pt_regexp = re.compile(r'checkpoint(\d+)\.pt')

    entries = []
    for f in os.listdir(path):
        # fullmatch rejects names that merely contain the pattern.
        m = pt_regexp.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))

    if len(entries) < n:
        # The counts must be interpolated with .format(); passing them as extra
        # Exception arguments leaves the '{}' placeholders unfilled.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )

    # Sort on (number, name) descending so the most recent checkpoints come first.
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Tool to average the params of input checkpoints to '
...
...
@@ -79,9 +97,19 @@ def main():
help
=
'Write the new checkpoint containing the averaged weights to this '
'path.'
,
)
parser
.
add_argument
(
'--num'
,
type
=
int
,
help
=
'if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last num of those'
,
)
args
=
parser
.
parse_args
()
print
(
args
)
if
args
.
num
is
not
None
:
args
.
inputs
=
last_n_checkpoints
(
args
.
inputs
,
args
.
num
)
print
(
'averaging checkpoints: '
,
args
.
inputs
)
new_state
=
average_checkpoints
(
args
.
inputs
)
torch
.
save
(
new_state
,
args
.
output
)
print
(
'Finished writing averaged checkpoint to {}.'
.
format
(
args
.
output
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment