#!/usr/bin/env python3
"""Average the model parameters of several checkpoints into a new checkpoint."""

import argparse
import collections
import torch
import os
import re


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Every checkpoint must expose the same parameter names, in the same order,
    under its 'model' key. Half-precision CPU tensors are promoted to float32
    before being accumulated. All non-'model' entries of the result are taken
    from the first checkpoint.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.

    Raises:
      ValueError: If inputs yields no checkpoint paths.
      KeyError: If a checkpoint's parameter names differ from those of the
        first checkpoint.
    """
    # Running per-parameter sums. Only one tensor per key is kept alive, so
    # peak memory does not grow with the number of checkpoints.
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = 0

    for f in inputs:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only average checkpoints from trusted sources.
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k in params_dict:
                # Out-of-place add: never mutates a tensor that still belongs
                # to the first checkpoint's state.
                params_dict[k] = params_dict[k] + p
            else:
                params_dict[k] = p
        num_models += 1

    if new_state is None:
        raise ValueError('average_checkpoints needs at least one input checkpoint')

    averaged_params = collections.OrderedDict()
    for k, summed_v in params_dict.items():
        averaged_params[k] = summed_v / num_models
    new_state['model'] = averaged_params
    return new_state


def last_n_checkpoints(paths, n, update_based):
    """Returns the paths of the n most recent checkpoints in a directory.

    Checkpoint files are recognised by name: 'checkpoint<N>.pt' for
    epoch-based checkpoints, or 'checkpoint_<E>_<U>.pt' when update_based is
    true (ranked by the trailing number <U>). Higher numbers are considered
    more recent.

    Args:
      paths: A list containing exactly one directory path.
      n: Number of checkpoints to return.
      update_based: Whether to look for update-based instead of epoch-based
        checkpoint file names.

    Returns:
      A list of n file paths, ordered from highest to lowest number.

    Raises:
      Exception: If fewer than n matching files exist in the directory.
    """
    assert len(paths) == 1, 'expected exactly one directory when selecting the last n checkpoints'
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = os.listdir(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            # Store the number as an int so sorting is numeric, not lexical.
            entries.append((int(m.group(1)), m.group(0)))
    if len(entries) < n:
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    """Parses command-line arguments, averages checkpoints and saves the result."""
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )

    parser.add_argument(
        '--inputs',
        required=True,
        nargs='+',
        help='Input checkpoint file paths.',
    )
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help='Write the new checkpoint containing the averaged weights to this '
             'path.',
    )
    parser.add_argument(
        '--num',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint<N>.pt in the path specified by input, '
             'and average last num of those',
    )
    parser.add_argument(
        '--update-based-checkpoints',
        action='store_true',
        help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints'
    )
    args = parser.parse_args()
    print(args)

    if args.num is not None:
        # With --num, --inputs is treated as a single directory to scan.
        args.inputs = last_n_checkpoints(args.inputs, args.num, args.update_based_checkpoints)
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()