# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4, random_asym_ids

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    """Tests for openfold.utils.feats / openfold.data.data_transforms.

    Shape tests run standalone; the ``*_compare`` tests require a working
    AlphaFold (JAX/Haiku) installation and a CUDA device, and check that
    the OpenFold reimplementations numerically match the originals.
    """

    @classmethod
    def setUpClass(cls):
        # Bind the AlphaFold reference modules once per test run. Multimer
        # and monomer AlphaFold expose the same functionality under
        # different module names / rigid-transform representations.
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        """Compare feats.pseudo_beta_fn against the AlphaFold original."""
        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        # Ground truth from the (stateless) Haiku-transformed function.
        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        """Compare the atom37 -> torsion-angle transform against AlphaFold."""
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

        out_repro = data_transforms.atom37_to_torsion_angles()(
            {
                "aatype": torch.as_tensor(aatype).cuda(),
                "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
                "all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
            },
        )
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
            < 0.01
        )
        self.assertTrue(
            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        """Compare the atom37 -> rigid-group-frames transform against AlphaFold."""
        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            if consts.is_multimer:
                # Multimer AlphaFold expects positions as a Vec3Array.
                all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
            return self.am_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)

        if consts.is_multimer:
            # The OpenFold multimer transform additionally requires asym_id.
            batch["asym_id"] = random_asym_ids(n_res)
            # Rigid3Array outputs must be flattened to plain arrays first.
            to_tensor = (lambda t: torch.tensor(np.array(t))
                         if not isinstance(t, self.am_rigid.Rigid3Array)
                         else torch.tensor(np.array(t.to_array())))
        else:
            to_tensor = lambda t: torch.tensor(np.array(t))

        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

        def rigid3x4_to_4x4(rigid3arr):
            # Promote a [..., 3, 4] rigid to a homogeneous [..., 4, 4] matrix.
            four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
            four_by_four[..., :3, :4] = rigid3arr
            four_by_four[..., 3, 3] = 1
            return four_by_four

        def flat12_to_4x4(flat12):
            # Unpack monomer r3 format: 9 rotation entries + 3 translation
            # entries per frame, into a homogeneous [..., 4, 4] matrix.
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]

            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1

            return four_by_four

        convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4

        out_gt["rigidgroups_gt_frames"] = convert_func(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = convert_func(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k in out_gt:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    def test_torsion_angles_to_frames_shape(self):
        """Check the output shape of feats.torsion_angles_to_frames."""
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        # 7 torsion angles per residue, each as (sin, cos).
        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        # One Rigid per residue per rigid group (8 groups).
        self.assertTrue(frames.shape == (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        """Compare feats.torsion_angles_to_frames against AlphaFold."""
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return self.am_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res,))

        # Build the same random transforms in both frameworks' rigid types.
        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)

        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
        out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()

        if consts.is_multimer:
            rots_gt = torch.as_tensor(np.array(out_gt_rot))
            trans_gt = torch.as_tensor(np.array(out_gt_trans))
        else:
            # Monomer r3 stores rotation/translation component-wise; stack
            # the components back into [..., 3, 3] and [..., 3] tensors.
            rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
            trans_gt = list(
                map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
            )
            rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
            rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
            trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)

        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_tensor_4x4().cpu()

        # BUGFIX: the comparison must be applied to the max absolute error,
        # not element-wise inside torch.max — torch.max of a boolean tensor
        # is True whenever ANY element is within tolerance, so the old
        # assertion could never fail meaningfully.
        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro)) < consts.eps
        )

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        """Check the output shape of frames_and_literature_positions_to_atom14_pos."""
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        # One 3D coordinate per atom14 slot per residue.
        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        """Compare atom14 position reconstruction against AlphaFold."""
        def run_f(aatype, affines):
            return self.am_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res, 8))

        # Build the same random transforms in both frameworks' rigid types.
        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        if consts.is_multimer:
            out_gt = torch.as_tensor(np.array(out_gt.to_array()))
        else:
            # Monomer output is component-wise (x, y, z); stack into [..., 3].
            out_gt = torch.stack(
                [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
            )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        # BUGFIX: compare the max absolute error to eps, rather than taking
        # torch.max of an element-wise boolean tensor (which is truthy as
        # soon as any single element happens to be within tolerance).
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


# Allow running this test module directly (e.g. `python test_feats.py`).
if __name__ == "__main__":
    unittest.main()