#!/usr/bin/env python3

import argparse
import collections

import torch


def average_checkpoints(inputs):
    """Load checkpoints from ``inputs`` and return a state with averaged weights.

    Every entry of the returned dict other than ``'model'`` is taken verbatim
    from the first checkpoint in ``inputs``.

    Args:
        inputs: An iterable of string paths of checkpoints to load from. Each
            checkpoint must be a dict with a ``'model'`` entry mapping string
            parameter names to tensors; all checkpoints must list the same
            parameter names in the same order.

    Returns:
        A dict of string keys mapping to various values. The 'model' key from
        the returned dict should correspond to an OrderedDict mapping string
        parameter names to torch Tensors. Every averaged tensor is float32,
        because each input tensor is converted with ``.float()`` before summing.

    Raises:
        KeyError: If a checkpoint's parameter names (or their order) differ
            from those of the first checkpoint.
        ValueError: If ``inputs`` yields no checkpoints.
    """
    # One running sum per parameter instead of a list holding every
    # checkpoint's copy, so memory no longer grows with the number of inputs.
    summed_params = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = 0

    for f in inputs:
        # NOTE: torch.load unpickles the file; only average trusted checkpoints.
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k].float()
            if k in summed_params:
                summed_params[k].add_(p)
            else:
                # .float() returns the very same tensor when it is already
                # float32; clone so the in-place add_ never mutates a loaded
                # checkpoint tensor.
                summed_params[k] = p.clone()
        num_models += 1

    if new_state is None:
        raise ValueError(
            'average_checkpoints() requires at least one input checkpoint'
        )

    # Each key was summed exactly once per checkpoint (keys are verified
    # identical above), so num_models is the divisor for every parameter.
    for v in summed_params.values():
        v.div_(num_models)
    new_state['model'] = summed_params
    return new_state


def main():
    """Parse command-line arguments, average the checkpoints, save the result."""
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
        'produce a new checkpoint',
    )

    parser.add_argument(
        '--inputs',
        required=True,
        nargs='+',
        help='Input checkpoint file paths.',
    )
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help='Write the new checkpoint containing the averaged weights to this '
        'path.',
    )
    args = parser.parse_args()
    print(args)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()