Commit bdec179b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add --checkpoint-upper-bound to average_checkpoints.py (#452)

Summary:
This is useful for averaging the last N checkpoints, ending at some "best" checkpoint.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/452

Differential Revision: D13695407

Pulled By: myleott

fbshipit-source-id: 5d9d2bff3706834f01501e9259834c77fb335817
parent d1dc66d9
......@@ -62,7 +62,7 @@ def average_checkpoints(inputs):
return new_state
def last_n_checkpoints(paths, n, update_based):
def last_n_checkpoints(paths, n, update_based, upper_bound=None):
assert len(paths) == 1
path = paths[0]
if update_based:
......@@ -75,7 +75,9 @@ def last_n_checkpoints(paths, n, update_based):
for f in files:
m = pt_regexp.fullmatch(f)
if m is not None:
entries.append((int(m.group(1)), m.group(0)))
sort_key = int(m.group(1))
if upper_bound is None or sort_key <= upper_bound:
entries.append((sort_key, m.group(0)))
if len(entries) < n:
raise Exception('Found {} checkpoint files but need at least {}', len(entries), n)
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
......@@ -98,6 +100,9 @@ def main():
num_group.add_argument('--num-update-checkpoints', type=int,
help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.')
parser.add_argument('--checkpoint-upper-bound', type=int,
help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
# fmt: on
args = parser.parse_args()
print(args)
......@@ -110,8 +115,15 @@ def main():
elif args.num_epoch_checkpoints is not None:
num = args.num_epoch_checkpoints
assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
'--checkpoint-upper-bound requires --num-epoch-checkpoints'
assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
if num is not None:
args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
args.inputs = last_n_checkpoints(
args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
)
print('averaging checkpoints: ', args.inputs)
new_state = average_checkpoints(args.inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment