# test_padding.py: unit tests for torchani's padding utilities
# (pad_atomic_properties, present_species, strip_redundant_padding).
import unittest
import torch
import torchani


# Shorthand used by every test below: torchani's helper that broadcasts the
# tensors of a properties dict along their first dimension. Judging by the
# expected outputs in these tests, a (1, N) species tensor is expanded to the
# number of conformations in the accompanying coordinates.
b = torchani.utils.broadcast_first_dim


class TestPaddings(torchani.testing.TestCase):
    """Tests for ``torchani.utils.pad_atomic_properties`` and
    ``torchani.utils.present_species``."""

    def _pad_and_check(self, species1):
        """Pad a 4-atom batch together with a 5-atom batch and verify the result.

        ``species1`` describes the 4-atom molecule. It is paired with five
        all-zero conformations and broadcast with ``b``. The second batch is a
        fixed 5-atom molecule with two all-zero conformations. After padding,
        the species tensor must hold 5 + 2 = 7 rows of width 5. The 4-atom rows
        must be right-padded with -1, and all coordinates must still be zero.

        Not collected by unittest: the name does not start with ``test``.
        """
        coordinates1 = torch.zeros(5, 4, 3)
        species2 = torch.tensor([[3, 2, 0, 1, 0]])
        coordinates2 = torch.zeros(2, 5, 3)
        atomic_properties = torchani.utils.pad_atomic_properties([
            b({'species': species1, 'coordinates': coordinates1}),
            b({'species': species2, 'coordinates': coordinates2}),
        ])
        self.assertEqual(atomic_properties['species'].shape[0], 7)
        self.assertEqual(atomic_properties['species'].shape[1], 5)
        expected_species = torch.tensor([
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [3, 2, 0, 1, 0],
            [3, 2, 0, 1, 0],
        ])
        self.assertEqual(atomic_properties['species'], expected_species)
        self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)

    def testVectorSpecies(self):
        # NOTE(review): these inputs are identical to testTensorShape1NSpecies;
        # the test name suggests a 1-D species vector was once intended — confirm.
        self._pad_and_check(torch.tensor([[0, 2, 3, 1]]))

    def testTensorShape1NSpecies(self):
        # Species of shape (1, 4), broadcast against 5 conformations.
        self._pad_and_check(torch.tensor([[0, 2, 3, 1]]))

    def testTensorSpecies(self):
        # Species already of shape (5, 4): one identical row per conformation.
        self._pad_and_check(torch.tensor([[0, 2, 3, 1]] * 5))

    def testPresentSpecies(self):
        # present_species should return the distinct non-negative species
        # (here already sorted) and ignore the -1 padding entries.
        species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
        present_species = torchani.utils.present_species(species)
        expected = torch.tensor([0, 1, 3, 7])
        self.assertEqual(expected, present_species)


class TestStripRedundantPadding(torchani.testing.TestCase):
    """Tests for ``torchani.utils.strip_redundant_padding``."""

    def testStripRestore(self):
        """Stripping a slice of a wider padded batch restores the narrower batch.

        Batches with 4, 5 and 10 atoms are padded together to width 10. The
        test slices out the rows that came from the first batch, or from the
        first two batches, and strips redundant padding. The first slice must
        reproduce the original first batch. The second must reproduce the
        result of padding only the first two batches.
        """
        # Species are drawn from [0, 4), so no input contains the -1 padding
        # marker; any -1 seen after padding was introduced by padding.
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 5), dtype=torch.long)
        coordinates2 = torch.randn(2, 5, 3)

        # Reference: batches 1 and 2 padded together (width 5).
        atomic_properties12 = torchani.utils.pad_atomic_properties([
            b({'species': species1, 'coordinates': coordinates1}),
            b({'species': species2, 'coordinates': coordinates2}),
        ])
        species12 = atomic_properties12['species']
        coordinates12 = atomic_properties12['coordinates']

        # A third, wider batch forces everything out to width 10.
        species3 = torch.randint(4, (2, 10), dtype=torch.long)
        coordinates3 = torch.randn(2, 10, 3)
        atomic_properties123 = torchani.utils.pad_atomic_properties([
            b({'species': species1, 'coordinates': coordinates1}),
            b({'species': species2, 'coordinates': coordinates2}),
            b({'species': species3, 'coordinates': coordinates3}),
        ])
        species123 = atomic_properties123['species']
        coordinates123 = atomic_properties123['coordinates']

        # Rows 0..4 came from batch 1: stripping must give back the 4-atom input.
        species_coordinates1_ = torchani.utils.strip_redundant_padding(
            b({'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]}))
        species1_ = species_coordinates1_['species']
        coordinates1_ = species_coordinates1_['coordinates']
        self.assertEqual(species1_, species1)
        self.assertEqual(coordinates1_, coordinates1)

        # Rows 0..6 came from batches 1 and 2: stripping must match the
        # width-5 reference padding computed above.
        species_coordinates12_ = torchani.utils.strip_redundant_padding(
            b({'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]}))
        species12_ = species_coordinates12_['species']
        coordinates12_ = species_coordinates12_['coordinates']
        self.assertEqual(species12_, species12)
        self.assertEqual(coordinates12_, coordinates12)


if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()