Commit f6a5a54e, authored by Alexei Baevski, committed by Myle Ott
Browse files

add support for averaging last n checkpoints

parent 23211c45
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import argparse import argparse
import collections import collections
import torch import torch
import os
import re
def average_checkpoints(inputs): def average_checkpoints(inputs):
...@@ -60,6 +62,22 @@ def average_checkpoints(inputs): ...@@ -60,6 +62,22 @@ def average_checkpoints(inputs):
return new_state return new_state
def last_n_checkpoints(paths, n):
    """Return the paths of the ``n`` highest-numbered checkpoints in a directory.

    The single directory given in ``paths`` is listed, and every entry whose
    whole name matches ``checkpoint<digits>.pt`` (for example
    ``checkpoint12.pt``) is considered. Names such as ``checkpoint_best.pt``
    do not match and are ignored.

    Args:
        paths: a list containing exactly one directory path.
        n: how many checkpoints to select.

    Returns:
        A list of ``n`` file paths (directory joined with file name), ordered
        from the highest checkpoint number to the lowest.

    Raises:
        AssertionError: if ``paths`` does not contain exactly one entry.
        Exception: if fewer than ``n`` matching checkpoint files are found.
    """
    assert len(paths) == 1
    path = paths[0]
    pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    entries = []
    for f in os.listdir(path):
        # fullmatch: the entire file name must be checkpoint<digits>.pt.
        m = pt_regexp.fullmatch(f)
        if m is not None:
            # Keep the number as an int so ordering is numeric (10 > 9),
            # not lexicographic.
            entries.append((int(m.group(1)), m.group(0)))
    if len(entries) < n:
        # Format the message explicitly; passing the values as extra
        # Exception arguments would leave the '{}' placeholders unfilled.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to ' description='Tool to average the params of input checkpoints to '
...@@ -79,9 +97,19 @@ def main(): ...@@ -79,9 +97,19 @@ def main():
help='Write the new checkpoint containing the averaged weights to this ' help='Write the new checkpoint containing the averaged weights to this '
'path.', 'path.',
) )
# Optional: treat the single input as a directory and average only the
# highest-numbered checkpoint files found in it.
parser.add_argument(
    '--num',
    type=int,
    help='if set, will try to find checkpoints with names checkpointNN.pt '
         '(e.g. checkpoint12.pt) in the path specified by input, '
         'and average last num of those',
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
if args.num is not None:
args.inputs = last_n_checkpoints(args.inputs, args.num)
print('averaging checkpoints: ', args.inputs)
new_state = average_checkpoints(args.inputs) new_state = average_checkpoints(args.inputs)
torch.save(new_state, args.output) torch.save(new_state, args.output)
print('Finished writing averaged checkpoint to {}.'.format(args.output)) print('Finished writing averaged checkpoint to {}.'.format(args.output))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment