#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import collections
import os
import re

import torch
def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.

    Raises:
      ValueError: if ``inputs`` is empty.
      KeyError: if the checkpoints do not all share the same parameter names.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None

    num_models = len(inputs)
    if num_models == 0:
        # Fail fast with a clear message instead of crashing later on
        # `new_state['model']` with new_state still None.
        raise ValueError('Must provide at least one checkpoint to average')

    for f in inputs:
        # NOTE: torch.load unpickles arbitrary data; only use on trusted files.
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                # Accumulate fp16 params in fp32 to avoid precision loss.
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            # Integer tensors (e.g. step counters) do not support true
            # division in-place; use floor division instead.
            averaged_params[k] //= num_models
    new_state['model'] = averaged_params
    return new_state


def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Finds the last n checkpoints in a directory, newest first.

    Args:
      paths: a single-element list containing the directory to search.
      update_based: if True, match update-based names
        (``checkpoint_<epoch>_<update>.pt``) and sort by update number;
        otherwise match epoch-based names (``checkpoint<epoch>.pt``).
      upper_bound: if given, ignore checkpoints whose epoch/update number
        exceeds this value.

    Returns:
      A list of n full checkpoint paths, sorted newest first.

    Raises:
      ValueError: if ``paths`` does not contain exactly one directory.
      Exception: if fewer than n matching checkpoints are found.
    """
    # Use an explicit check rather than `assert`, which is stripped under -O.
    if len(paths) != 1:
        raise ValueError('Expected a single directory path, got: {}'.format(paths))
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = os.listdir(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))
    if len(entries) < n:
        # Bug fix: the counts were previously passed as extra Exception args
        # instead of being formatted into the message string.
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    """CLI entry point: parse args, select checkpoints, and write the average.

    Either averages the explicitly listed --inputs, or, when
    --num-epoch-checkpoints / --num-update-checkpoints is given, treats
    --inputs as a single directory and averages the last N checkpoints in it.
    """
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
                        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    # Use parser.error() instead of `assert`: asserts are stripped under
    # `python -O` and produce raw tracebacks instead of a usage message.
    # (Epoch/update exclusivity is already enforced by the mutually
    # exclusive group above.)
    if args.checkpoint_upper_bound is not None and args.num_epoch_checkpoints is None:
        parser.error('--checkpoint-upper-bound requires --num-epoch-checkpoints')

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


# Standard script guard: run only when executed directly, not on import.
if __name__ == '__main__':
    main()