# test_ignite.py: training/evaluation test for torchani's ignite integration.
import sys

if sys.version_info.major >= 3:
    # The ignite-based test is only defined on Python 3; on older
    # interpreters this module defines no tests.
    import os
    import unittest
    import torch
    from ignite.engine import create_supervised_trainer, \
        create_supervised_evaluator, Events
    import torchani
    import torchani.data
    # The test uses torchani.ignite.* directly; import the submodule
    # explicitly instead of relying on the package __init__ to load it.
    import torchani.ignite

    # Dataset file shipped with the repository, resolved relative to this
    # test file so the test works from any working directory.
    path = os.path.dirname(os.path.realpath(__file__))
    path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
    # Chunk size passed to torchani.data.ANIDataset.
    chunksize = 4
    # Upper bound that both the final training loss and the evaluation
    # RMSE must fall below for the test to pass.
    threshold = 1e-5

    class TestIgnite(unittest.TestCase):
        """Check that a torchani model can be trained with pytorch-ignite."""

        def testIgnite(self):
            """Overfit one dataset item and check the loss and the RMSE.

            Builds a PrepareInput -> SortedAEV -> NeuroChemNNP pipeline.
            Trains it with ignite's supervised trainer for 1000 epochs on
            the first item of the ANI dataset (energies passed through
            EnergyShifter.subtract_from_dataset). Then asserts that the
            final training loss and the evaluator's RMSE are both below
            ``threshold``.
            """
            aev_computer = torchani.SortedAEV()
            prepare = torchani.PrepareInput(aev_computer.species)
            nnp = torchani.models.NeuroChemNNP(aev_computer.species)
            shift_energy = torchani.EnergyShifter(aev_computer.species)
            ds = torchani.data.ANIDataset(
                path, chunksize,
                transform=[shift_energy.subtract_from_dataset])
            # A single item is enough to check that training drives the
            # error down, and keeps the test fast.
            ds = torch.utils.data.Subset(ds, [0])
            loader = torchani.data.dataloader(ds, 1)

            class Flatten(torch.nn.Module):
                # Passes the first element through and flattens the second.
                # NOTE(review): presumably (species, energies) from the
                # NNP; confirm against torchani.models.NeuroChemNNP.
                def forward(self, x):
                    return x[0], x[1].flatten()

            model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
            container = torchani.ignite.Container({'energies': model})
            optimizer = torch.optim.Adam(container.parameters())
            # MSE on the 'energies' output, passed through exp(x) - 1.
            loss = torchani.ignite.TransformedLoss(
                torchani.ignite.MSELoss('energies'),
                lambda x: torch.exp(x) - 1)
            trainer = create_supervised_trainer(
                container, optimizer, loss)
            evaluator = create_supervised_evaluator(container, metrics={
                'RMSE': torchani.ignite.RMSEMetric('energies')
            })

            trainer.run(loader, max_epochs=1000)

            # Evaluate and assert in the test body rather than inside an
            # ignite event callback. The checks then always execute once
            # training returns, and failures point at this method.
            evaluator.run(loader)
            metrics = evaluator.state.metrics
            self.assertLess(metrics['RMSE'], threshold)
            self.assertLess(trainer.state.output, threshold)

    if __name__ == '__main__':
        unittest.main()