# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
16
import numpy as np
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
17
18
19
import torch
import unittest

20
21
22
23
24
25
from openfold.utils.rigid_utils import (
    Rotation,
    Rigid, 
    quat_to_rot,
    rot_to_quat,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
26
from openfold.utils.chunk_utils import chunk_layer, _chunk_slice
27
28
29
30
31
32
33
import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
34
35


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Rotation by +90 degrees about the x-axis: maps y onto z and z onto -y.
# Built from integer literals, so the tensor has an integer dtype.
X_90_ROT = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])

# Rotation by -90 degrees about the x-axis (the transpose of X_90_ROT):
# maps y onto -z and z onto y.
X_NEG_90_ROT = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
51
52


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
53
class TestUtils(unittest.TestCase):
54
    def test_rigid_from_3_points_shape(self):
        """Check the shapes produced by ``Rigid.from_3_points``.

        Three random point sets of shape [batch, n_res, 3] are turned into
        frames. The rotation matrices must have shape [batch, n_res, 3, 3]
        and the translation must equal the second point set.
        """
        pts_shape = (2, 5, 3)  # (batch_size, n_res, xyz)
        pts_a, pts_b, pts_c = (torch.rand(pts_shape) for _ in range(3))

        frames = Rigid.from_3_points(pts_a, pts_b, pts_c)
        rot_mats = frames.get_rots().get_rot_mats()
        translations = frames.get_trans()

        self.assertTrue(rot_mats.shape == pts_shape[:-1] + (3, 3))
        self.assertTrue(torch.all(translations == pts_b))

69
    def test_rigid_from_4x4(self):
        """Check that ``Rigid.from_tensor_4x4`` splits a 4x4 transform.

        A batch of identical 4x4 matrices is converted to a ``Rigid``. The
        upper-left 3x3 block must come back as the rotation matrix and the
        first three entries of the last column as the translation.
        """
        n_copies = 2
        mat_4x4 = torch.tensor(
            [
                [1, 0, 0, 1],
                [0, 0, -1, 2],
                [0, 1, 0, 3],
                [0, 0, 0, 1],
            ]
        )
        expected_rot = mat_4x4[:3, :3]
        expected_trans = mat_4x4[:3, 3]

        batched = torch.stack([mat_4x4] * n_copies, dim=0)
        rigid = Rigid.from_tensor_4x4(batched)

        got_rot = rigid.get_rots().get_rot_mats()
        got_trans = rigid.get_trans()

        # A leading axis is added so the expected values broadcast over
        # the batch.
        self.assertTrue(torch.all(got_rot == expected_rot[None]))
        self.assertTrue(torch.all(got_trans == expected_trans[None]))

91
    def test_rigid_shape(self):
        """Check that ``Rigid.shape`` equals the leading batch dimensions."""
        batch_dims = (2, 5)
        rot_mats = torch.rand(batch_dims + (3, 3))
        translations = torch.rand(batch_dims + (3,))

        rigid = Rigid(Rotation(rot_mats=rot_mats), translations)

        self.assertTrue(rigid.shape == batch_dims)

101
    def test_rigid_cat(self):
        """Check ``Rigid.cat`` along each of the two batch dimensions."""
        n_batch, n_items = 2, 5
        base = Rigid(
            Rotation(rot_mats=torch.rand((n_batch, n_items, 3, 3))),
            torch.rand((n_batch, n_items, 3)),
        )
        base_rots = base.get_rots().get_rot_mats()

        # Concatenating along dim 0 doubles the first batch dimension.
        stacked = Rigid.cat([base, base], dim=0)
        stacked_rots = stacked.get_rots().get_rot_mats()
        self.assertTrue(
            stacked_rots.shape == (2 * n_batch, n_items, 3, 3)
        )

        # Concatenating along dim 1 doubles the second batch dimension.
        widened = Rigid.cat([base, base], dim=1)
        widened_rots = widened.get_rots().get_rot_mats()
        self.assertTrue(
            widened_rots.shape == (n_batch, 2 * n_items, 3, 3)
        )

        # The first half of the dim-1 result must reproduce the input.
        self.assertTrue(torch.all(widened_rots[:, :n_items] == base_rots))
        self.assertTrue(
            torch.all(widened.get_trans()[:, :n_items] == base.get_trans())
        )

    def test_rigid_compose(self):
        """Check that composing two inverse transforms gives the identity.

        ``t1`` rotates +90 degrees about x and translates by (0, 1, 0);
        ``t2`` rotates -90 degrees about x and translates by (0, 0, 1).
        The test asserts that ``t1.compose(t2)`` has an identity rotation
        and a zero translation.
        """
        trans_1 = [0, 1, 0]
        trans_2 = [0, 0, 1]

        t1 = Rigid(
            Rotation(rot_mats=X_90_ROT),
            torch.tensor(trans_1),
        )
        t2 = Rigid(
            Rotation(rot_mats=X_NEG_90_ROT),
            torch.tensor(trans_2),
        )

        t3 = t1.compose(t2)

        self.assertTrue(
            torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
        )
        self.assertTrue(
            torch.all(t3.get_trans() == 0)
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
150

151
    def test_rigid_apply(self):
        """Check ``Rigid.apply`` for two x-axis rotations plus a unit shift.

        Batch element 0 uses the +90 degree rotation and element 1 the
        -90 degree rotation; both add 1 to every coordinate afterwards.
        """
        rot_mats = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
        shift = torch.tensor([[1, 1, 1], [1, 1, 1]])
        rigid = Rigid(Rotation(rot_mats=rot_mats), shift)

        # Two identical copies of the points (0, 1, 2), ..., (27, 28, 29).
        single = torch.arange(30).view(-1, 3)
        points = torch.stack([single, single], dim=0)  # [2, 10, 3]

        # The extra trailing axis lets each transform broadcast over the
        # ten points of its batch element.
        moved = rigid[..., None].apply(points)

        x_in, y_in, z_in = points.unbind(dim=-1)
        x_out, y_out, z_out = moved.unbind(dim=-1)

        # All simple consequences of the two x-axis rotations
        self.assertTrue(torch.all(x_out == x_in + 1))
        self.assertTrue(torch.all(y_out[0] == z_in[0] * -1 + 1))
        self.assertTrue(torch.all(y_out[1] == z_in[1] + 1))
        self.assertTrue(torch.all(z_out[0] == y_in[0] + 1))
        self.assertTrue(torch.all(z_out[1] == y_in[1] * -1 + 1))

    def test_quat_to_rot(self):
        """Check ``quat_to_rot`` on the quaternion (cos 45, sin 45, 0, 0).

        The result must match ``X_90_ROT`` element-wise to within 1e-07.
        """
        half_angle = math.pi / 4  # 45 degrees, in radians
        quat = torch.tensor([math.cos(half_angle), math.sin(half_angle), 0, 0])

        max_err = torch.max(torch.abs(quat_to_rot(quat) - X_90_ROT))

        self.assertTrue(max_err < 1e-07)

178
179
180
181
182
183
    def test_rot_to_quat(self):
        """Check ``rot_to_quat`` on ``X_90_ROT``.

        The result must match (sqrt(0.5), sqrt(0.5), 0, 0) element-wise to
        within 1e-07.
        """
        root_half = math.sqrt(0.5)
        expected = torch.tensor([root_half, root_half, 0., 0.])

        max_err = torch.max(torch.abs(rot_to_quat(X_90_ROT) - expected))

        self.assertTrue(max_err < 1e-07)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
    def test_chunk_layer_tensor(self):
        """Check that ``chunk_layer`` reproduces a direct linear-layer call.

        The layer returns a plain tensor; the chunked and direct outputs
        must be exactly equal.
        """
        inputs = torch.rand(2, 4, 5, 15)
        linear = torch.nn.Linear(15, 30)

        direct = linear(inputs)
        chunked = chunk_layer(
            linear, {"input": inputs}, chunk_size=4, no_batch_dims=3
        )

        self.assertTrue((chunked == direct).all())

    def test_chunk_layer_dict(self):
        """Check ``chunk_layer`` on a layer that returns nested dicts.

        Both the top-level and the nested tensor of the chunked output
        must exactly equal those of a direct call.
        """

        class LinearDictLayer(torch.nn.Linear):
            # The parameter must be called ``input`` because the test
            # passes the tensor under the key "input".
            def forward(self, input):
                projected = super().forward(input)
                return {"out": projected, "inner": {"out": projected + 1}}

        inputs = torch.rand(2, 4, 5, 15)
        layer = LinearDictLayer(15, 30)

        chunked = chunk_layer(
            layer, {"input": inputs}, chunk_size=4, no_batch_dims=3
        )
        direct = layer(inputs)

        self.assertTrue((chunked["out"] == direct["out"]).all())
        self.assertTrue(
            (chunked["inner"]["out"] == direct["inner"]["out"]).all()
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222

    def test_chunk_slice_dict(self):
        """Check ``_chunk_slice`` against slicing the flattened tensor.

        Every non-empty range [start, end) over the flattened leading
        dimensions of a [3, 4, 3, 5] tensor is compared with the same
        slice of the tensor viewed as [-1, 5].
        """
        tensor = torch.rand(3, 4, 3, 5)
        flat = tensor.view(-1, 5)

        no_batch_dims = tensor.dim() - 1
        # Length of the flattened view, i.e. the product of the leading
        # dimensions.
        total = flat.shape[0]

        for start in range(total):
            for end in range(start + 1, total + 1):
                sliced = _chunk_slice(tensor, start, end, no_batch_dims)
                expected = flat[start:end]

                self.assertTrue(torch.all(sliced == expected))
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252

    @compare_utils.skip_unless_alphafold_installed()
    def test_pre_compose_compare(self):
        """Compare ``compose_q_update_vec`` with AlphaFold's ``pre_compose``.

        The same random quaternions, translations and update vectors go
        through AlphaFold's ``QuatAffine.pre_compose`` and through
        ``Rigid.compose_q_update_vec``. The resulting quaternions and
        translations must agree to within ``consts.eps``.
        """
        n_frames = 20
        quat_np = np.random.rand(n_frames, 4)
        trans_np = [np.random.rand(n_frames) for _ in range(3)]

        # Reference result from the AlphaFold implementation.
        affine = alphafold.model.quat_affine.QuatAffine(
            quat_np, translation=trans_np
        )
        update_np = np.random.rand(n_frames, 6)
        reference = affine.pre_compose(update_np)

        # Same computation through the Rigid class. The three translation
        # components are stacked into a trailing axis.
        rigid = Rigid(
            Rotation(quats=torch.tensor(quat_np)),
            torch.stack([torch.tensor(comp) for comp in trans_np], dim=-1),
        )
        reproduced = rigid.compose_q_update_vec(torch.tensor(update_np))

        ref_quat = torch.tensor(np.array(reference.quaternion))
        ref_trans = torch.stack(
            [torch.tensor(np.array(comp)) for comp in reference.translation],
            dim=-1,
        )
        got_quat = reproduced.get_rots().get_quats()
        got_trans = reproduced.get_trans()

        self.assertTrue((ref_quat - got_quat).abs().max() < consts.eps)
        self.assertTrue((ref_trans - got_trans).abs().max() < consts.eps)