# test_padding.py: unit tests for the torchani.utils padding helpers
# (pad_atomic_properties, strip_redundant_padding) and present_species.
import unittest

import torch
import torchani
from torchani.testing import TestCase


# Short alias used by every test below: each per-batch property dict is passed
# through torchani.utils.broadcast_first_dim before being padded together.
b = torchani.utils.broadcast_first_dim


class TestPaddings(TestCase):
    """Tests for torchani.utils.pad_atomic_properties and present_species."""

    def _check_padding(self, species1):
        """Pad ``species1`` with a fixed second batch and check the result.

        ``species1`` is paired with 5 molecules of 4 atoms (all-zero
        coordinates). The second batch is 2 molecules of 5 atoms with species
        [3, 2, 0, 1, 0]. The padded output must be 7 x 5, with the shorter
        4-atom rows filled with -1 in the last column.
        """
        coordinates1 = torch.zeros(5, 4, 3)
        species2 = torch.tensor([[3, 2, 0, 1, 0]])
        coordinates2 = torch.zeros(2, 5, 3)
        atomic_properties = torchani.utils.pad_atomic_properties([
            b({'species': species1, 'coordinates': coordinates1}),
            b({'species': species2, 'coordinates': coordinates2}),
        ])

        species = atomic_properties['species']
        self.assertEqual(species.shape[0], 7)
        self.assertEqual(species.shape[1], 5)
        expected_species = torch.tensor([
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [0, 2, 3, 1, -1],
            [3, 2, 0, 1, 0],
            [3, 2, 0, 1, 0],
        ])
        self.assertEqual(species, expected_species)
        # Every input coordinate is zero. The test requires the padded
        # coordinates to stay all-zero as well.
        self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)

    def testVectorSpecies(self):
        # NOTE(review): these inputs are identical to testTensorShape1NSpecies.
        # The name suggests a different species layout (e.g. a 1-D vector) was
        # meant to be exercised here. Confirm the intended input.
        self._check_padding(torch.tensor([[0, 2, 3, 1]]))

    def testTensorShape1NSpecies(self):
        # A single (1, N) species row must broadcast over all 5 molecules.
        self._check_padding(torch.tensor([[0, 2, 3, 1]]))

    def testTensorSpecies(self):
        # Species given explicitly for each of the 5 molecules.
        self._check_padding(torch.tensor([[0, 2, 3, 1]] * 5))

    def testPresentSpecies(self):
        # present_species must drop the -1 padding entries and report each
        # remaining species once, in ascending order.
        species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
        present_species = torchani.utils.present_species(species)
        expected = torch.tensor([0, 1, 3, 7])
        self.assertEqual(expected, present_species)


class TestStripRedundantPadding(TestCase):
    """Tests that strip_redundant_padding removes padding added by pad_atomic_properties."""

    @staticmethod
    def _pad(*batches):
        """Broadcast each (species, coordinates) pair and pad them into one batch."""
        return torchani.utils.pad_atomic_properties([
            b({'species': species, 'coordinates': coordinates})
            for species, coordinates in batches
        ])

    @staticmethod
    def _strip_first(atomic_properties, num_molecules):
        """Run strip_redundant_padding on the first ``num_molecules`` molecules."""
        return torchani.utils.strip_redundant_padding(b({
            'species': atomic_properties['species'][:num_molecules, ...],
            'coordinates': atomic_properties['coordinates'][:num_molecules, ...],
        }))

    def testStripRestore(self):
        # randint(4, ...) only yields species in [0, 4), so random species never
        # coincide with the negative padding value.
        # Draw order matches: species1, coordinates1, species2, ... .
        species1 = torch.randint(4, (5, 4), dtype=torch.long)
        coordinates1 = torch.randn(5, 4, 3)
        species2 = torch.randint(4, (2, 5), dtype=torch.long)
        coordinates2 = torch.randn(2, 5, 3)
        padded12 = self._pad((species1, coordinates1), (species2, coordinates2))

        species3 = torch.randint(4, (2, 10), dtype=torch.long)
        coordinates3 = torch.randn(2, 10, 3)
        padded123 = self._pad(
            (species1, coordinates1),
            (species2, coordinates2),
            (species3, coordinates3),
        )

        # The first 5 molecules of the 3-batch padding came from batch 1 alone.
        # Stripping must recover the original unpadded tensors exactly.
        stripped1 = self._strip_first(padded123, 5)
        self.assertEqual(stripped1['species'], species1)
        self.assertEqual(stripped1['coordinates'], coordinates1)

        # The first 7 molecules came from batches 1 and 2. Stripping must give
        # the same result as padding just those two batches together.
        stripped12 = self._strip_first(padded123, 7)
        self.assertEqual(stripped12['species'], padded12['species'])
        self.assertEqual(stripped12['coordinates'], padded12['coordinates'])


# Allow running this file directly: `python test_padding.py`.
if __name__ == '__main__':
    unittest.main()