# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4, random_asym_ids

# The AlphaFold reference implementation and its JAX/Haiku dependencies are
# optional. `alphafold`, `jax` and `hk` are bound only when
# `compare_utils.alphafold_is_installed()` is true. The comparison tests that
# use them are decorated with `compare_utils.skip_unless_alphafold_installed()`.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    """Tests for ``openfold.utils.feats`` and related ``data_transforms``.

    The ``*_compare`` tests run the reference AlphaFold (JAX/Haiku) function
    and the OpenFold port on the same random inputs and compare the outputs.
    They are gated by ``compare_utils.skip_unless_alphafold_installed()`` and
    move the OpenFold inputs to CUDA. The ``*_shape`` tests only exercise
    OpenFold on CPU tensors and check output shapes.
    """

    @classmethod
    def setUpClass(cls):
        """Select the AlphaFold reference modules for the configured mode.

        When ``consts.is_multimer`` is set, the ``*_multimer`` AlphaFold
        modules and ``alphafold.model.geometry`` are used; otherwise the
        monomer modules and ``alphafold.model.r3``. Nothing is selected when
        AlphaFold is not installed.
        """
        # `alphafold` is only bound at module level when AlphaFold can be
        # imported. Without this guard, setUpClass raises NameError, and every
        # test in the class errors out -- including the CPU-only `*_shape`
        # tests that never touch AlphaFold.
        if not compare_utils.alphafold_is_installed():
            return

        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        """Compare ``feats.pseudo_beta_fn`` with AlphaFold's ``modules.pseudo_beta_fn``."""
        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        """Compare ``data_transforms.atom37_to_torsion_angles`` with AlphaFold's version."""
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        # `jax.tree_util.tree_map` works on old and new JAX releases;
        # the `jax.tree_map` alias is deprecated.
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.as_tensor(np.array(x)), out_gt
        )

        out_repro = data_transforms.atom37_to_torsion_angles()(
            {
                "aatype": torch.as_tensor(aatype).cuda(),
                "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
                "all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
            },
        )
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
            < 0.01
        )
        self.assertTrue(
            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        """Compare ``data_transforms.atom37_to_frames`` with AlphaFold's version."""
        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            if consts.is_multimer:
                all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
            return self.am_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)

        if consts.is_multimer:
            # Only the OpenFold side receives `asym_id`; AlphaFold was
            # already called above without it.
            batch["asym_id"] = random_asym_ids(n_res)
            to_tensor = (lambda t: torch.tensor(np.array(t))
                         if not isinstance(t, self.am_rigid.Rigid3Array)
                         else torch.tensor(np.array(t.to_array())))
        else:
            to_tensor = lambda t: torch.tensor(np.array(t))

        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

        def rigid3x4_to_4x4(rigid3arr):
            # Pad a [..., 3, 4] tensor with a [0, 0, 0, 1] bottom row.
            four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
            four_by_four[..., :3, :4] = rigid3arr
            four_by_four[..., 3, 3] = 1
            return four_by_four

        def flat12_to_4x4(flat12):
            # Last dim holds 9 rotation entries followed by 3 translation
            # entries; assemble them into a homogeneous 4x4 matrix.
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]

            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1

            return four_by_four

        convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4

        out_gt["rigidgroups_gt_frames"] = convert_func(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = convert_func(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(v - out_repro[k])) < consts.eps
            )

    def test_torsion_angles_to_frames_shape(self):
        """``feats.torsion_angles_to_frames`` yields 8 frames per residue."""
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        """Compare ``feats.torsion_angles_to_frames`` with AlphaFold's version."""
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return self.am_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res,))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)

        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
        out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()

        if consts.is_multimer:
            rots_gt = torch.as_tensor(np.array(out_gt_rot))
            trans_gt = torch.as_tensor(np.array(out_gt_trans))
        else:
            # The monomer rigid stores rotation/translation as per-component
            # arrays; stack them into [..., 3, 3] and [..., 3] tensors.
            rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
            trans_gt = list(
                map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
            )
            rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
            rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
            trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)

        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_tensor_4x4().cpu()

        # Reduce the absolute error first, then compare with eps. Comparing
        # inside torch.max would pass as soon as any element matched.
        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro)) < consts.eps
        )

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        """``frames_and_literature_positions_to_atom14_pos`` yields [B, N, 14, 3]."""
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        """Compare ``frames_and_literature_positions_to_atom14_pos`` with AlphaFold's version."""
        def run_f(aatype, affines):
            return self.am_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res, 8))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)

        if consts.is_multimer:
            out_gt = torch.as_tensor(np.array(out_gt.to_array()))
        else:
            out_gt = torch.stack(
                [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
            )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        # Reduce the absolute error first, then compare with eps. Comparing
        # inside torch.max would pass as soon as any element matched.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()