# test_padding.py: unit tests for the torchani.utils padding helpers
# (pad_atomic_properties, present_species, strip_redundant_padding).
import unittest
import torch
import torchani


class TestPaddings(unittest.TestCase):
    """Tests for ``torchani.utils.pad_atomic_properties`` and
    ``torchani.utils.present_species``.

    In the expected tensors below, ``-1`` is the species value used for
    padding atoms.
    """

    def _assertTensorEqual(self, t1, t2):
        """Assert that ``t1`` and ``t2`` have identical shapes and elements.

        The shape is compared explicitly so that broadcasting cannot hide a
        size mismatch, and ``torch.equal`` (unlike ``.max()``) also works on
        empty tensors.
        """
        self.assertEqual(tuple(t1.shape), tuple(t2.shape))
        self.assertTrue(torch.equal(t1, t2))

    def _assertPadsSevenByFive(self, species1):
        """Pad ``species1`` together with a fixed second batch and check it.

        ``species1`` accompanies 5 molecules of 4 atoms with all-zero
        coordinates; the second batch is 2 molecules of 5 atoms whose species
        are given as a single row. The expected result is 7 molecules padded
        to 5 atoms, with single species rows repeated for every molecule of
        their batch.
        """
        coordinates1 = torch.zeros(5, 4, 3)
        species2 = torch.tensor([[3, 2, 0, 1, 0]])
        coordinates2 = torch.zeros(2, 5, 3)
        atomic_properties = torchani.utils.pad_atomic_properties([
            {'species': species1, 'coordinates': coordinates1},
            {'species': species2, 'coordinates': coordinates2},
        ])

        expected_species = torch.tensor([
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [3, 2, 0, 1, 0],
            [3, 2, 0, 1, 0],
        ])
        self._assertTensorEqual(atomic_properties['species'], expected_species)

        # All input coordinates are zero and padded coordinates are expected
        # to be zero too, so the whole padded tensor must be zero.
        coordinates = atomic_properties['coordinates']
        self.assertEqual(tuple(coordinates.shape), (7, 5, 3))
        self.assertEqual(coordinates.abs().max().item(), 0)

    def testVectorSpecies(self):
        # NOTE(review): these inputs are identical to
        # testTensorShape1NSpecies; the name suggests a 1-D species vector
        # was once intended. Confirm whether 1-D species are still supported
        # before changing the input here.
        self._assertPadsSevenByFive(torch.tensor([[0, 2, 3, 1]]))

    def testTensorShape1NSpecies(self):
        # A (1, N) species tensor shared by all 5 molecules of the batch.
        self._assertPadsSevenByFive(torch.tensor([[0, 2, 3, 1]]))

    def testTensorSpecies(self):
        # A full (5, N) species tensor, one row per molecule.
        self._assertPadsSevenByFive(torch.tensor([[0, 2, 3, 1]] * 5))

    def testPresentSpecies(self):
        species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
        present_species = torchani.utils.present_species(species)
        # Sorted unique species, with the -1 padding value excluded.
        self._assertTensorEqual(present_species, torch.tensor([0, 1, 3, 7]))


class TestStripRedundantPadding(unittest.TestCase):
    """Tests that ``torchani.utils.strip_redundant_padding`` undoes padding
    that is larger than the selected molecules need."""

    def _assertTensorEqual(self, t1, t2):
        """Assert that ``t1`` and ``t2`` have identical shapes and elements.

        The shape is compared explicitly so a failed strip (too many atom
        columns left) is reported as a test failure rather than a broadcasting
        error, and ``torch.equal`` (unlike ``.max()``) also works on empty
        tensors.
        """
        self.assertEqual(tuple(t1.shape), tuple(t2.shape))
        self.assertTrue(torch.equal(t1, t2))

    def _assertStripsTo(self, species, coordinates,
                        expected_species, expected_coordinates):
        """Strip ``species``/``coordinates`` and compare with the expectation."""
        stripped = torchani.utils.strip_redundant_padding(
            {'species': species, 'coordinates': coordinates})
        self._assertTensorEqual(stripped['species'], expected_species)
        self._assertTensorEqual(stripped['coordinates'], expected_coordinates)

    def testStripRestore(self):
        # Random species in [0, 4) never collide with the -1 padding value.
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 5), dtype=torch.long)
        coordinates2 = torch.randn(2, 5, 3)
        species3 = torch.randint(4, (2, 10), dtype=torch.long)
        coordinates3 = torch.randn(2, 10, 3)

        batch1 = {'species': species1, 'coordinates': coordinates1}
        batch2 = {'species': species2, 'coordinates': coordinates2}
        batch3 = {'species': species3, 'coordinates': coordinates3}

        # Reference: batches 1 and 2 padded only as far as they need.
        atomic_properties12 = torchani.utils.pad_atomic_properties(
            [batch1, batch2])
        # Adding batch 3 pads every molecule out to its larger atom count.
        atomic_properties123 = torchani.utils.pad_atomic_properties(
            [batch1, batch2, batch3])
        species123 = atomic_properties123['species']
        coordinates123 = atomic_properties123['coordinates']

        # The first 5 rows are batch 1; stripping must recover it exactly.
        self._assertStripsTo(species123[:5], coordinates123[:5],
                             species1, coordinates1)
        # The first 7 rows are batches 1 and 2; stripping must match padding
        # those two batches on their own.
        self._assertStripsTo(species123[:7], coordinates123[:7],
                             atomic_properties12['species'],
                             atomic_properties12['coordinates'])


if __name__ == '__main__':
    # Running this file directly executes every TestCase defined above.
    unittest.main(module='__main__')