# compare.py 2.18 KB
# Newer Older
1
2
3
4
5
6
7
import argparse
import torch

# Command-line interface. The three string options and the --fused-adam
# switch are joined further down into the names of the result files to load;
# --use_baseline additionally pulls in a third file from "baselines/".
parser = argparse.ArgumentParser(description='Compare')

# String-valued options. Spelling out default=None is equivalent to omitting
# it, so any option not given on the command line ends up as None.
for option in ('--opt-level', '--keep-batchnorm-fp32', '--loss-scale'):
    parser.add_argument(option, type=str, default=None)

# Boolean switches: False unless present on the command line.
for switch in ('--fused-adam', '--use_baseline'):
    parser.add_argument(switch, action='store_true')

args = parser.parse_args()

12
13
14
15
16
# Shared suffix of every result file name: the four settings, stringified and
# joined with underscores (an option that was not given contributes "None").
base_file = "_".join(
    str(setting)
    for setting in (
        args.opt_level,
        args.loss_scale,
        args.keep_batchnorm_fp32,
        args.fused_adam,
    )
)

# The two runs that are always compared differ only in their "True_" /
# "False_" file-name prefix.
file_e = "True_{}".format(base_file)
file_p = "False_{}".format(base_file)

# NOTE(review): torch.load unpickles the file contents; this presumably only
# ever reads files written by a trusted local run -- confirm.
dict_e = torch.load(file_e)
dict_p = torch.load(file_p)

# The optional third run lives under "baselines/" and uses the "True_" prefix.
# file_b and dict_b are only defined when --use_baseline was given.
if args.use_baseline:
    file_b = "baselines/True_{}".format(base_file)
    dict_b = torch.load(file_b)

torch.set_printoptions(precision=10)

# Echo the names of the files being compared.
for loaded_name in (file_e, file_p):
    print(loaded_name)
if args.use_baseline:
    print(file_b)
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Compare the loaded runs iteration by iteration and print one row per
# iteration: the iteration number, then every loss, then every speed. With
# --use_baseline the baseline run contributes an extra leading column to both
# the loss group and the speed group.
#
# Any mismatch raises AssertionError with the same message as before. The
# checks are written as explicit raises rather than `assert` statements so
# that they still run under `python -O`, which strips asserts and would
# otherwise let every comparison pass silently.
#
# NOTE(review): zip() stops at the shorter of the two "Iteration" lists, so
# trailing iterations of a longer run are not compared -- confirm that is
# intended.
use_baseline = args.use_baseline

# One "{:4}" field for the iteration, then one float field per loss and per
# speed: 2 + 2 columns without the baseline, 3 + 3 with it.
row_format = "{:4}" + " {:15.10f}" * (6 if use_baseline else 4)

for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
    # `not (a == b)` mirrors the truth test the original asserts performed.
    if not (i_e == i_p):
        raise AssertionError("i_e = {}, i_p = {}".format(i_e, i_p))

    loss_e = dict_e["Loss"][n]
    loss_p = dict_p["Loss"][n]
    losses = [loss_e, loss_p]
    if use_baseline:
        loss_b = dict_b["Loss"][n]
        losses.insert(0, loss_b)  # baseline is printed first

    if not (loss_e == loss_p):
        raise AssertionError(
            "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p))
    if use_baseline and not (loss_e == loss_b):
        raise AssertionError(
            "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b))

    speeds = [dict_e["Speed"][n], dict_p["Speed"][n]]
    if use_baseline:
        speeds.insert(0, dict_b["Speed"][n])  # baseline is printed first

    print(row_format.format(i_e, *(losses + speeds)))