# test_ignite.py
import sys

# The whole suite (imports, constants, TestIgnite) is only defined on
# Python 3; on Python 2 this module defines no tests.
# NOTE(review): presumably because torchani / pytorch-ignite require
# Python 3 — confirm before removing the guard.
if sys.version_info.major >= 3:
    import os
    import unittest
    import torch
    from ignite.metrics import RootMeanSquaredError
    from ignite.engine import create_supervised_trainer, \
        create_supervised_evaluator
    import torchani
    import torchani.data

    # Path to ani_gdb_s01.h5 inside the `dataset` directory next to this
    # file. It is resolved via realpath so the test works from any CWD.
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                        'dataset', 'ani_gdb_s01.h5')
    # Chunk size passed to torchani.data.ANIDataset.
    chunksize = 32
    # Number of chunks per batch passed to torchani.data.dataloader.
    batch_chunks = 32
    # dtype/device handed to torchani.SortedAEV.
    dtype = torch.float32
    device = torch.device('cpu')

    class TestIgnite(unittest.TestCase):
        """Train and evaluate a TorchANI model through pytorch-ignite."""

        def testIgnite(self):
            """Train on the dataset's energies for 10 epochs, then evaluate RMSE.

            Energies are transformed by ``EnergyShifter.dataset_subtract_sae``
            when the dataset is built. Training uses SGD on an MSE loss over
            the ``'energies'`` key. Evaluation attaches a root-mean-squared-error
            metric under the name ``'RMSE'``.
            """
            shift_energy = torchani.EnergyShifter()
            ds = torchani.data.ANIDataset(
                path, chunksize,
                transform=[shift_energy.dataset_subtract_sae])
            loader = torchani.data.dataloader(ds, batch_chunks)
            aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
            nnp = torchani.models.NeuroChemNNP(aev_computer)

            class Flatten(torch.nn.Module):
                """Wrap a module and return its output flattened to 1-D."""

                def __init__(self, model):
                    super(Flatten, self).__init__()
                    self.model = model

                def forward(self, *args):
                    # NOTE(review): presumably flattened so predictions match
                    # the shape of the 'energies' target for MSELoss — confirm.
                    return self.model(*args).flatten()

            nnp = Flatten(nnp)
            batch_nnp = torchani.models.BatchModel(nnp)
            # Models and targets are keyed by name; 'energies' is the only key.
            container = torchani.ignite.Container({'energies': batch_nnp})

            loss = torchani.ignite.DictLoss('energies', torch.nn.MSELoss())
            optimizer = torch.optim.SGD(container.parameters(),
                                        lr=0.001, momentum=0.8)
            trainer = create_supervised_trainer(container, optimizer, loss)
            trainer.run(loader, max_epochs=10)

            metric = torchani.ignite.DictMetric('energies',
                                                RootMeanSquaredError())
            evaluator = create_supervised_evaluator(container, metrics={
                'RMSE': metric
            })
            evaluator.run(loader)
            # ignite records attached metric results in engine.state.metrics.
            # Check that evaluation actually produced the metric rather than
            # letting the test pass without checking anything.
            self.assertIn('RMSE', evaluator.state.metrics)

    if __name__ == '__main__':
        # Executed directly as a script: discover and run the TestCase
        # classes of this module (module='__main__' is unittest's default).
        unittest.main(module='__main__')