compare.py 1.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
import torch

parser = argparse.ArgumentParser(description='Compare')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
args = parser.parse_args()

base_file =  str(args.opt_level) + "_" + str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32)
file_e = "True_" + base_file
file_p = "False_" + base_file

dict_e = torch.load(file_e)
dict_p = torch.load(file_p)

torch.set_printoptions(precision=10)

print(file_e)
print(file_p)

for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
    assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)

    loss_e = dict_e["Loss"][n]
    loss_p = dict_p["Loss"][n]
    assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)
    print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format(
          i_e,
          loss_e,
          loss_p,
          dict_e["Speed"][n],
          dict_p["Speed"][n]))