import os
import unittest
import torch
import copy
from ignite.engine import create_supervised_trainer, \
    create_supervised_evaluator, Events
import torchani
import torchani.training

# Location of the HDF5 training data, resolved relative to this file's
# real (symlink-free) directory.
path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '../dataset/ani_gdb_s01.h5',
)
# Batch size handed to BatchedANIDataset.
batchsize = 4
# Upper bound that both the evaluated RMSE and the trainer's final output
# must fall below for the test to pass.
threshold = 1e-5


class TestIgnite(unittest.TestCase):
    """End-to-end check that a TorchANI model trains under pytorch-ignite."""

    def testIgnite(self):
        """Overfit a copy of the builtin model on one dataset item.

        Builds ``aev_computer -> nnp -> Flatten`` from ``torchani.buildins``,
        trains it with Adam for 1000 epochs through ignite's supervised
        trainer, then evaluates on the same data and asserts that both the
        RMSE metric and the trainer's last output are below ``threshold``.
        """
        aev_computer = torchani.buildins.aev_computer
        # Deep-copy so training does not mutate the shared builtin model.
        nnp = copy.deepcopy(torchani.buildins.models[0])
        shift_energy = torchani.buildins.energy_shifter
        ds = torchani.training.BatchedANIDataset(
            path, torchani.buildins.consts.species, batchsize,
            transform=[shift_energy.subtract_from_dataset])
        # Keep only index 0 (presumably a single batch) so the model can be
        # driven to a near-zero error within the epoch budget.
        ds = torch.utils.data.Subset(ds, [0])

        class Flatten(torch.nn.Module):
            """Pass ``x[0]`` through unchanged and flatten ``x[1]``."""

            def forward(self, x):
                # NOTE(review): presumably nnp yields (species, energies)
                # with extra dims on energies; flattening aligns it with the
                # 'energies' target -- confirm.
                return x[0], x[1].flatten()

        model = torch.nn.Sequential(aev_computer, nnp, Flatten())
        container = torchani.training.Container({'energies': model})
        optimizer = torch.optim.Adam(container.parameters())
        # Wrap the MSE loss in exp(x) - 1, which is monotonic and ~x near 0;
        # presumably chosen to exercise TransformedLoss.
        loss = torchani.training.TransformedLoss(
            torchani.training.MSELoss('energies'),
            lambda x: torch.exp(x) - 1)
        trainer = create_supervised_trainer(
            container, optimizer, loss)
        evaluator = create_supervised_evaluator(container, metrics={
            'RMSE': torchani.training.RMSEMetric('energies')
        })

        @trainer.on(Events.COMPLETED)
        def completes(trainer):
            # After all epochs, evaluate on the same data that was trained
            # on; the trainer's output is its last iteration's result.
            evaluator.run(ds)
            metrics = evaluator.state.metrics
            self.assertLess(metrics['RMSE'], threshold)
            self.assertLess(trainer.state.output, threshold)

        trainer.run(ds, max_epochs=1000)


if __name__ == '__main__':
    # Allow running this test file directly via unittest's command-line runner.
    unittest.main()