"tests/vscode:/vscode.git/clone" did not exist on "6a6597a02a3370a8b3173701739bdaafba997873"
average_checkpoints.py 5.28 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import collections
import torch
import os
import re


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: A sequence of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.

    Raises:
      ValueError: if ``inputs`` is empty.
      KeyError: if the checkpoints do not all contain the same parameter names.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    if num_models == 0:
        # Fail loudly instead of dividing by zero below.
        raise ValueError('cannot average an empty list of checkpoints')

    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint; only its
        # 'model' entry is replaced with the averaged parameters at the end.
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            # Accumulate fp16 params in fp32 to avoid overflow/precision
            # loss while summing. dtype check is more robust than the old
            # isinstance(p, torch.HalfTensor) (covers all half tensors).
            if p.dtype == torch.float16:
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case of p is a shared parameter
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            # Integer tensors (e.g. step counters) do not support in-place
            # true division; use floor division instead.
            averaged_params[k] //= num_models
    new_state['model'] = averaged_params
    return new_state


def last_n_checkpoints(paths, n, update_based, upper_bound=None):
    """Returns the full paths of the ``n`` most recent checkpoints in a directory.

    Args:
      paths: a single-element list containing the directory to search.
      n: how many checkpoint paths to return.
      update_based: if True, match update-based names
        ('checkpoint_<epoch>_<update>.pt', sorted by update count);
        otherwise match epoch-based names ('checkpoint<epoch>.pt').
      upper_bound: if set, ignore checkpoints whose epoch/update number
        exceeds this value.

    Returns:
      A list of ``n`` checkpoint paths, most recent first.

    Raises:
      ValueError: if fewer than ``n`` matching checkpoints are found.
    """
    assert len(paths) == 1
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
    files = os.listdir(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            sort_key = int(m.group(1))
            if upper_bound is None or sort_key <= upper_bound:
                entries.append((sort_key, m.group(0)))
    if len(entries) < n:
        # BUG FIX: the original passed the format args to Exception() without
        # ever calling str.format, so the message printed with bare '{}'s.
        raise ValueError(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    """CLI entry point: average the weights of several checkpoints and save
    the result as a new checkpoint."""
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    # fmt: off
    parser.add_argument('--inputs', required=True, nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output', required=True, metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument('--num-epoch-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
                           'and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
                        'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    # Validate with parser.error() instead of assert: asserts are stripped
    # under `python -O`, and parser.error() prints a proper usage message.
    # (--num-epoch-checkpoints / --num-update-checkpoints exclusivity is
    # already enforced by the mutually exclusive group above.)
    if args.checkpoint_upper_bound is not None and args.num_epoch_checkpoints is None:
        parser.error('--checkpoint-upper-bound requires --num-epoch-checkpoints')

    if num is not None:
        # Expand the single input directory into the last `num` checkpoints.
        args.inputs = last_n_checkpoints(
            args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


# Standard script guard: run only when executed directly, not on import.
if __name__ == '__main__':
    main()