#!/usr/bin/env python3

import argparse
import collections
import os
import re

import torch


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            params_dict[k].append(p)

    averaged_params = collections.OrderedDict()
    # v should be a list of torch Tensors.
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
    new_state['model'] = averaged_params
    return new_state
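
# Example usage (illustrative; the checkpoint paths below are hypothetical):
#
#   state = average_checkpoints(['checkpoints/checkpoint1.pt',
#                                'checkpoints/checkpoint2.pt'])
#   torch.save(state, 'checkpoints/averaged.pt')
#
# Each tensor under the saved 'model' key is the elementwise mean of the
# corresponding tensors from the input checkpoints.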


def last_n_checkpoints(paths, n, update_based):
    """Returns the n most recent checkpoint paths found in the directory given in `paths`."""
    assert len(paths) == 1, 'last_n_checkpoints expects a single directory path'
    path = paths[0]
    if update_based:
        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
    else:
        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
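    # Illustrative matches (filenames are hypothetical):
    #   update_based=True  matches e.g. checkpoint_3_10000.pt (group(1) == '10000')
    #   update_based=False matches e.g. checkpoint31.pt       (group(1) == '31')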
    files = os.listdir(path)

    entries = []
    for f in files:
        m = pt_regexp.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))
    if len(entries) < n:
        raise Exception(
            'Found {} checkpoint files but need at least {}'.format(len(entries), n)
        )
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]


def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )

    parser.add_argument(
        '--inputs',
        required=True,
        nargs='+',
        help='Input checkpoint file paths.',
    )
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help='Write the new checkpoint containing the averaged weights to this '
             'path.',
    )
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument(
        '--num-epoch-checkpoints',
        type=int,
        help='number of most recent epoch checkpoints (named like checkpoint31.pt) to find in '
             'the path given by --inputs and average.',
    )
    num_group.add_argument(
        '--num-update-checkpoints',
        type=int,
        help='number of most recent update-based checkpoints (named like checkpoint_3_10000.pt) '
             'to find in the path given by --inputs and average.',
    )
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    if num is not None:
        args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
        print('averaging checkpoints: ', args.inputs)

    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()
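
# Usage sketch (paths and values below are illustrative):
#
#   python average_checkpoints.py \
#       --inputs /path/to/checkpoint_dir \
#       --num-epoch-checkpoints 5 \
#       --output /path/to/averaged_model.pt
#
# or, with explicit checkpoint files instead of a directory:
#
#   python average_checkpoints.py \
#       --inputs checkpoint1.pt checkpoint2.pt checkpoint3.pt \
#       --output averaged_model.pt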