test_batch.py 1.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys

if sys.version_info.major >= 3:
    import os
    import unittest
    import torch
    import torchani
    import torchani.data
    import itertools

    # Location of the 'dataset' directory that sits next to this test file
    # (symlinks in this file's own path are resolved first).
    path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'dataset')

    # Tensor settings passed to the AEV computer.
    device = torch.device('cpu')
    dtype = torch.float32

    # Size arguments passed to ANIDataset and dataloader respectively.
    batch_chunks = 32
    chunksize = 32

    class TestBatch(unittest.TestCase):
        """Smoke test: batched data loading followed by batched inference."""

        def testBatchLoadAndInference(self):
            """Check output shapes of batched inference on a few batches.

            Builds an ``ANIDataset`` and its loader from the module-level
            ``path``, ``chunksize`` and ``batch_chunks`` settings, wraps a
            ``NeuroChemNNP`` in a ``BatchModel``, and for the first 10 items
            yielded by the loader asserts that the squeezed model output has
            the same shape as the ``'energies'`` entry of the expected
            output.
            """
            max_batches = 10  # only a prefix of the loader is exercised

            dataset = torchani.data.ANIDataset(path, chunksize)
            batches = torchani.data.dataloader(dataset, batch_chunks)
            model = torchani.models.BatchModel(
                torchani.models.NeuroChemNNP(
                    torchani.SortedAEV(dtype=dtype, device=device)))

            for inputs, expected in itertools.islice(batches, max_batches):
                predicted = model(inputs).squeeze()
                predicted_shape = list(predicted.shape)
                expected_shape = list(expected['energies'].shape)
                self.assertListEqual(predicted_shape, expected_shape)

    # Run the tests in this module when the file is executed directly
    # as a script rather than imported.
    if __name__ == '__main__':
        unittest.main()