# test_ignite.py - tests for the pytorch-ignite integration in torchani.ignite.
import os
import unittest
import torch
import copy
from ignite.engine import create_supervised_trainer, \
    create_supervised_evaluator, Events
import torchani
import torchani.ignite

# Absolute path of the HDF5 dataset file used by the test. It is resolved
# relative to the real location of this file, so the test does not depend
# on the current working directory.
path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')

# Batch size passed to the dataset loader.
batchsize = 4
# Upper bound that both the evaluator's RMSE and the trainer's final loss
# must fall under for the test to pass.
threshold = 1e-5


class TestIgnite(unittest.TestCase):
    """Check that an ANI network trained via pytorch-ignite can overfit."""

    def testIgnite(self):
        """Train on a one-item dataset subset and check convergence.

        After ``max_epochs=1000`` of training, both the evaluator's
        ``RMSE`` metric and the trainer's last loss value must be below
        the module-level ``threshold``.
        """
        ani1x = torchani.models.ANI1x()
        aev_computer = ani1x.aev_computer
        # Train a deep copy so the optimizer updates the copy's parameters
        # rather than ani1x.neural_networks[0] itself.
        nnp = copy.deepcopy(ani1x.neural_networks[0])
        shift_energy = ani1x.energy_shifter
        ds = torchani.data.load_ani_dataset(
            path, ani1x.consts.species_to_tensor, batchsize,
            transform=[shift_energy.subtract_from_dataset],
            # Place the data on the same device as the AEV computer's
            # EtaR tensor.
            device=aev_computer.EtaR.device)
        # Keep only item 0 (presumably one batch of `batchsize` molecules;
        # TODO confirm) so the model can realistically overfit it.
        ds = torch.utils.data.Subset(ds, [0])

        class Flatten(torch.nn.Module):
            # Pass x[0] through unchanged and flatten x[1]. Presumably x is
            # (species, energies) from the network; TODO confirm.
            def forward(self, x):
                return x[0], x[1].flatten()

        model = torch.nn.Sequential(aev_computer, nnp, Flatten())
        container = torchani.ignite.Container({'energies': model})
        optimizer = torch.optim.Adam(container.parameters())
        # The optimized loss is exp(MSE) - 1 rather than the raw MSE.
        loss = torchani.ignite.TransformedLoss(
            torchani.ignite.MSELoss('energies'),
            lambda x: torch.exp(x) - 1)
        trainer = create_supervised_trainer(
            container, optimizer, loss)
        evaluator = create_supervised_evaluator(container, metrics={
            'RMSE': torchani.ignite.RMSEMetric('energies')
        })

        @trainer.on(Events.COMPLETED)
        def completes(engine):
            # Runs once training has finished. Evaluate on the same
            # subset, then check both metric and final loss.
            evaluator.run(ds)
            metrics = evaluator.state.metrics
            self.assertLess(metrics['RMSE'], threshold)
            self.assertLess(engine.state.output, threshold)

        trainer.run(ds, max_epochs=1000)


if __name__ == '__main__':
    # Run this module's tests when executed directly as a script;
    # module='__main__' spells out unittest.main()'s default.
    unittest.main(module='__main__')