analyze-gradients.py 656 Bytes
Newer Older
Xiang Gao's avatar
Xiang Gao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pickle
import torch

# (chunk size, batch chunks) configurations whose data files are read below,
# laid out one row per chunk size.
# NOTE(review): (64, 4) was commented out of this list; presumably its data
# file is missing or unwanted — confirm before re-enabling it.
hyperparams = [  # (chunk size, batch chunks)
    (64, 8), (64, 16), (64, 32),
    (128, 2), (128, 4), (128, 8), (128, 16),
    (256, 1), (256, 2), (256, 4), (256, 8),
    (512, 1), (512, 2), (512, 4),
    (1024, 1), (1024, 2),
    (2048, 1),
]

def gradient_stddev(ag, agsqr):
    """Return sqrt(sum(agsqr) - sum(ag ** 2)) as a Python float.

    ``ag`` and ``agsqr`` are presumably the element-wise mean gradient and
    mean squared gradient (the data files are named ``avg-*``) — TODO
    confirm against the code that writes them. Under that assumption the
    difference of sums is the total gradient variance and the result is its
    square root.

    Sums are accumulated in float64 to limit cancellation between the two
    large, nearly equal totals. The variance is clamped at zero because
    rounding can still make it marginally negative, which would otherwise
    turn ``torch.sqrt`` into ``nan``.
    """
    ag = torch.as_tensor(ag, dtype=torch.float64)
    agsqr = torch.as_tensor(agsqr, dtype=torch.float64)
    variance = torch.sum(agsqr) - torch.sum(ag ** 2)
    return torch.sqrt(torch.clamp(variance, min=0.0)).item()


def main():
    """Print ``chunk_size batch_chunks stddev`` for every configuration."""
    for chunk_size, batch_chunks in hyperparams:
        path = 'data/avg-{}-{}.dat'.format(chunk_size, batch_chunks)
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # locally produced, trusted data files.
        with open(path, 'rb') as f:
            ag, agsqr = pickle.load(f)
        print(chunk_size, batch_chunks, gradient_stddev(ag, agsqr))


if __name__ == '__main__':
    main()