average_checkpoints.py 863 Bytes
Newer Older
Pingchuan Ma's avatar
Pingchuan Ma committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os

import torch


def average_checkpoints(last):
    """Return the element-wise mean of the ``state_dict`` of several checkpoints.

    Args:
        last: Iterable of checkpoint file paths. Each file is read with
            ``torch.load`` and must hold a mapping under the key
            ``"state_dict"``.

    Returns:
        The ``state_dict`` mapping of the first checkpoint, with every
        non-``None`` entry replaced by the mean over all checkpoints.
        Floating-point tensors use true division; all other tensors use
        floor division so their dtype is preserved.

    Raises:
        ValueError: If ``last`` contains no paths.
        KeyError: If a later checkpoint lacks a key present in the first one.
    """
    # Materialize once so one-shot iterables work and the count is reliable.
    paths = list(last)
    if not paths:
        raise ValueError("average_checkpoints() needs at least one checkpoint path")

    avg = None
    for path in paths:
        # The map_location callable returns the storage unchanged, so tensors
        # are not moved back to the device they were saved from.
        states = torch.load(path, map_location=lambda storage, loc: storage)["state_dict"]
        if avg is None:
            # The first checkpoint doubles as the running-sum accumulator.
            avg = states
            continue
        for k in avg:
            # Entries that are None cannot be summed; skip them here exactly
            # as the averaging step below does.
            if avg[k] is not None:
                avg[k] += states[k]

    # Turn the accumulated sums into means.
    count = len(paths)
    for k in avg:
        if avg[k] is None:
            continue
        if avg[k].is_floating_point():
            avg[k] /= count
        else:
            # In-place true division is not allowed on integer tensors.
            avg[k] //= count
    return avg


def ensemble(args):
    """Average the last ten epoch checkpoints of an experiment and save the result.

    Reads ``epoch={n}.ckpt`` for ``n`` in ``[args.epochs - 10, args.epochs)``
    from ``<args.exp_dir>/<args.experiment_name>`` and writes the averaged
    weights to ``model_avg_10.pth`` in the same directory, stored under the
    key ``"state_dict"``.

    Args:
        args: Object exposing ``exp_dir``, ``experiment_name`` and ``epochs``.
    """
    run_dir = os.path.join(args.exp_dir, args.experiment_name)

    # Collect the checkpoint paths of the final ten epochs, oldest first.
    last = []
    for n in range(args.epochs - 10, args.epochs):
        last.append(os.path.join(run_dir, f"epoch={n}.ckpt"))

    model_path = os.path.join(run_dir, "model_avg_10.pth")
    averaged = average_checkpoints(last)
    torch.save({"state_dict": averaged}, model_path)