# test_padding.py
import unittest
import torch
import torchani


6
class TestPaddings(unittest.TestCase):
    """Tests for the padding helpers reached through ``torchani.utils``.

    Most tests combine a 5-conformation, 4-atom molecule with a
    2-conformation, 5-atom molecule and expect a single ``(7, 5)`` species
    tensor in which the missing fifth atom of the smaller molecule is
    marked with ``-1``.
    """

    @staticmethod
    def _expectedSpecies():
        """Return the ``(7, 5)`` species tensor the padding tests expect."""
        # Five rows for the 4-atom molecule (fifth slot is the -1 padding),
        # followed by two rows for the 5-atom molecule.
        return torch.tensor([[0, 2, 3, 1, -1]] * 5 + [[3, 2, 0, 1, 0]] * 2)

    def _assertTensorEqual(self, actual, expected):
        """Assert that two tensors have the same shape and the same elements.

        The shape is checked explicitly because comparing through a
        subtraction would silently broadcast tensors of different shapes.
        """
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(
            torch.equal(actual, expected),
            'tensors differ:\n{}\nvs\n{}'.format(actual, expected))

    def _checkPadAtomicProperties(self, species1):
        """Pad ``species1`` with the fixed second molecule and verify the result.

        Args:
            species1: species tensor of the first molecule; paired here with
                all-zero coordinates of shape ``(5, 4, 3)``.
        """
        coordinates1 = torch.zeros(5, 4, 3)
        species2 = torch.tensor([[3, 2, 0, 1, 0]])
        coordinates2 = torch.zeros(2, 5, 3)
        atomic_properties = torchani.utils.pad_atomic_properties([
            {'species': species1, 'coordinates': coordinates1},
            {'species': species2, 'coordinates': coordinates2},
        ])
        padded_species = atomic_properties['species']
        self.assertEqual(padded_species.shape[0], 7)
        self.assertEqual(padded_species.shape[1], 5)
        self._assertTensorEqual(padded_species, self._expectedSpecies())
        # Every input coordinate is zero, so the padded coordinates must
        # contain nothing but zeros as well.
        self.assertEqual(
            atomic_properties['coordinates'].abs().max().item(), 0)

    def testVectorSpecies(self):
        """Pad when the first species tensor has a single row."""
        # NOTE(review): despite the name, the input is the same (1, 4)
        # tensor used by testTensorShape1NSpecies; kept exactly as it was.
        self._checkPadAtomicProperties(torch.tensor([[0, 2, 3, 1]]))

    def testTensorShape1NSpecies(self):
        """Pad when the first species tensor has shape ``(1, N)``."""
        self._checkPadAtomicProperties(torch.tensor([[0, 2, 3, 1]]))

    def testTensorSpecies(self):
        """Pad when the first species tensor has one row per conformation."""
        species1 = torch.tensor([
            [0, 2, 3, 1],
            [0, 2, 3, 1],
            [0, 2, 3, 1],
            [0, 2, 3, 1],
            [0, 2, 3, 1],
        ])
        self._checkPadAtomicProperties(species1)

    def testPadSpecies(self):
        """``torchani.utils.pad`` on species tensors alone."""
        species1 = torch.tensor([
            [0, 2, 3, 1],
            [0, 2, 3, 1],
            [0, 2, 3, 1],
            [0, 2, 3, 1],
            [0, 2, 3, 1],
        ])
        species2 = torch.tensor([[3, 2, 0, 1, 0]]).expand(2, 5)
        species = torchani.utils.pad([species1, species2])
        self.assertEqual(species.shape[0], 7)
        self.assertEqual(species.shape[1], 5)
        self._assertTensorEqual(species, self._expectedSpecies())

    def testPresentSpecies(self):
        """``present_species`` on a species vector that contains padding."""
        species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
        present_species = torchani.utils.present_species(species)
        # Expected: the distinct species in ascending order, with the -1
        # entries left out.
        expected = torch.tensor([0, 1, 3, 7])
        self._assertTensorEqual(present_species, expected)


113
114
115
116
117
118
119
120
121
122
class TestStripRedundantPadding(unittest.TestCase):
    """Check that ``strip_redundant_padding`` undoes excess padding.

    Molecules with 4, 5 and 10 atoms are padded together; slices of that
    batch are then stripped and compared against the original tensors and
    against the batch padded from the first two molecules only.
    """

    def _assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same shape and the same elements.

        The shape is checked explicitly because comparing through a
        subtraction would silently broadcast tensors of different shapes.
        """
        self.assertEqual(t1.shape, t2.shape)
        self.assertTrue(
            torch.equal(t1, t2),
            'tensors differ:\n{}\nvs\n{}'.format(t1, t2))

    def testStripRestore(self):
        """Stripping a slice of a wider batch restores the narrower padding."""
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 5), dtype=torch.long)
        coordinates2 = torch.randn(2, 5, 3)

        # Reference batch: only the first two molecules padded together.
        atomic_properties12 = torchani.utils.pad_atomic_properties([
            {'species': species1, 'coordinates': coordinates1},
            {'species': species2, 'coordinates': coordinates2},
        ])
        species12 = atomic_properties12['species']
        coordinates12 = atomic_properties12['coordinates']

        # A third, 10-atom molecule makes the combined batch wider than
        # either of the first two molecules needs.
        species3 = torch.randint(4, (2, 10), dtype=torch.long)
        coordinates3 = torch.randn(2, 10, 3)
        atomic_properties123 = torchani.utils.pad_atomic_properties([
            {'species': species1, 'coordinates': coordinates1},
            {'species': species2, 'coordinates': coordinates2},
            {'species': species3, 'coordinates': coordinates3},
        ])
        species123 = atomic_properties123['species']
        coordinates123 = atomic_properties123['coordinates']

        # Rows 0-4 belong to molecule 1: stripping must give back its
        # original, unpadded tensors.
        species_coordinates1_ = torchani.utils.strip_redundant_padding(
            {'species': species123[:5, ...],
             'coordinates': coordinates123[:5, ...]})
        species1_ = species_coordinates1_['species']
        coordinates1_ = species_coordinates1_['coordinates']
        self._assertTensorEqual(species1_, species1)
        self._assertTensorEqual(coordinates1_, coordinates1)

        # Rows 0-6 belong to molecules 1 and 2: stripping must match the
        # batch that was padded from those two molecules alone.
        species_coordinates12_ = torchani.utils.strip_redundant_padding(
            {'species': species123[:7, ...],
             'coordinates': coordinates123[:7, ...]})
        species12_ = species_coordinates12_['species']
        coordinates12_ = species_coordinates12_['coordinates']
        self._assertTensorEqual(species12_, species12)
        self._assertTensorEqual(coordinates12_, coordinates12)


152
153
# Run the test cases defined above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()