average_checkpoints.py 5.21 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
11

import argparse
import collections
import torch
12
13
import os
import re
Myle Ott's avatar
Myle Ott committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.

    Raises:
      ValueError: if inputs is empty.
      KeyError: if the checkpoints do not all share the same parameter names.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    if num_models == 0:
        raise ValueError('cannot average an empty list of checkpoints')

    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings (optimizer state, args, etc.) from the
        # first checkpoint; only its 'model' entry is replaced below.
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            # Accumulate in fp32; summing many fp16 tensors loses precision.
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # Clone so the in-place accumulation below does not mutate
                # the tensors of the first loaded checkpoint.
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            # Integer tensors (e.g. step counters): true division would
            # silently change their dtype to float, so floor-divide instead.
            averaged_params[k] //= num_models
    new_state['model'] = averaged_params
    return new_state


71
def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Finds the n most recent checkpoint files in a directory.

    Args:
      paths: singleton list containing the directory to scan.
      n: number of checkpoint paths to return.
      update_based: if True, match update-based filenames
        ('checkpoint_<epoch>_<update>.pt', sorted by update count);
        otherwise match epoch-based filenames ('checkpoint<epoch>.pt').
      upper_bound: if given, skip checkpoints whose epoch/update number
        exceeds this value.

    Returns:
      A list of n full checkpoint paths, most recent first.

    Raises:
      Exception: if fewer than n matching checkpoints are found.
    """
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')

    entries = []
    for f in os.listdir(path):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))

    if len(entries) < n:
        # Bug fix: the counts were passed as extra Exception() arguments, so
        # the '{}' placeholders were never interpolated; use str.format().
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


Myle Ott's avatar
Myle Ott committed
92
93
94
def main():
    """Command-line entry point: average checkpoint weights into one file."""
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
                        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    # Work out how many trailing checkpoints to collect, and by which
    # naming scheme (update-based takes priority when both could apply).
    if args.num_update_checkpoints is not None:
        num, is_update_based = args.num_update_checkpoints, True
    elif args.num_epoch_checkpoints is not None:
        num, is_update_based = args.num_epoch_checkpoints, False
    else:
        num, is_update_based = None, False

    assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
            '--checkpoint-upper-bound requires --num-epoch-checkpoints'
    assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
            'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'

    # In discovery mode, --inputs holds a single directory to scan; replace
    # it with the resolved list of checkpoint files.
    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


# Run the averaging tool only when executed as a script, not on import.
if __name__ == '__main__':
    main()