# test_neurochem.py
import torchani
import torch
import os
import unittest


# Directory containing this test module; the data files below are resolved
# relative to it.
path = os.path.dirname(os.path.realpath(__file__))
# Input file handed to torchani.neurochem.Trainer in the test below.
iptpath = os.path.join(path, 'test_data/inputtrain.ipt')
# Dataset file (.h5) passed to Trainer.load_data in the test below.
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')


class TestNeuroChem(unittest.TestCase):

    def testNeuroChemTrainer(self):
        d = torch.device('cpu')
        trainer = torchani.neurochem.Trainer(iptpath, d, True, 'runs')
17
18

        # test if loader construct correct model
19
        self.assertEqual(trainer.aev_computer.aev_length, 384)
20
        m = trainer.nn
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
        H, C, N, O = m  # noqa: E741
        self.assertIsInstance(H[0], torch.nn.Linear)
        self.assertListEqual(list(H[0].weight.shape), [160, 384])
        self.assertIsInstance(H[1], torch.nn.CELU)
        self.assertIsInstance(H[2], torch.nn.Linear)
        self.assertListEqual(list(H[2].weight.shape), [128, 160])
        self.assertIsInstance(H[3], torch.nn.CELU)
        self.assertIsInstance(H[4], torch.nn.Linear)
        self.assertListEqual(list(H[4].weight.shape), [96, 128])
        self.assertIsInstance(H[5], torch.nn.CELU)
        self.assertIsInstance(H[6], torch.nn.Linear)
        self.assertListEqual(list(H[6].weight.shape), [1, 96])
        self.assertEqual(len(H), 7)

        self.assertIsInstance(C[0], torch.nn.Linear)
        self.assertListEqual(list(C[0].weight.shape), [144, 384])
        self.assertIsInstance(C[1], torch.nn.CELU)
        self.assertIsInstance(C[2], torch.nn.Linear)
        self.assertListEqual(list(C[2].weight.shape), [112, 144])
        self.assertIsInstance(C[3], torch.nn.CELU)
        self.assertIsInstance(C[4], torch.nn.Linear)
        self.assertListEqual(list(C[4].weight.shape), [96, 112])
        self.assertIsInstance(C[5], torch.nn.CELU)
        self.assertIsInstance(C[6], torch.nn.Linear)
        self.assertListEqual(list(C[6].weight.shape), [1, 96])
        self.assertEqual(len(C), 7)

        self.assertIsInstance(N[0], torch.nn.Linear)
        self.assertListEqual(list(N[0].weight.shape), [128, 384])
        self.assertIsInstance(N[1], torch.nn.CELU)
        self.assertIsInstance(N[2], torch.nn.Linear)
        self.assertListEqual(list(N[2].weight.shape), [112, 128])
        self.assertIsInstance(N[3], torch.nn.CELU)
        self.assertIsInstance(N[4], torch.nn.Linear)
        self.assertListEqual(list(N[4].weight.shape), [96, 112])
        self.assertIsInstance(N[5], torch.nn.CELU)
        self.assertIsInstance(N[6], torch.nn.Linear)
        self.assertListEqual(list(N[6].weight.shape), [1, 96])
        self.assertEqual(len(N), 7)

        self.assertIsInstance(O[0], torch.nn.Linear)
        self.assertListEqual(list(O[0].weight.shape), [128, 384])
        self.assertIsInstance(O[1], torch.nn.CELU)
        self.assertIsInstance(O[2], torch.nn.Linear)
        self.assertListEqual(list(O[2].weight.shape), [112, 128])
        self.assertIsInstance(O[3], torch.nn.CELU)
        self.assertIsInstance(O[4], torch.nn.Linear)
        self.assertListEqual(list(O[4].weight.shape), [96, 112])
        self.assertIsInstance(O[5], torch.nn.CELU)
        self.assertIsInstance(O[6], torch.nn.Linear)
        self.assertListEqual(list(O[6].weight.shape), [1, 96])
        self.assertEqual(len(O), 7)

        self.assertEqual(trainer.init_lr, 0.001)
        self.assertEqual(trainer.min_lr, 1e-5)
        self.assertEqual(trainer.max_nonimprove, 1)
        self.assertEqual(trainer.lr_decay, 0.1)

Gao, Xiang's avatar
Gao, Xiang committed
79
80
81
82
83
84
        trainer.load_data(dspath, dspath)
        trainer.run()


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()