# test_feats.py
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4

# AlphaFold and its JAX/Haiku stack are optional. They are only imported when
# AlphaFold is installed; the comparison tests that use them are decorated
# with `skip_unless_alphafold_installed`.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    """Tests for ``openfold.utils.feats`` and related ``data_transforms``.

    Methods ending in ``_compare`` run AlphaFold's JAX/Haiku implementation
    on random inputs. They then check that the OpenFold implementation
    (executed on CUDA) produces matching outputs. They are skipped when
    AlphaFold is not installed.

    Methods ending in ``_shape`` only check output shapes, using CPU tensors.
    """

    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        """pseudo_beta_fn positions and masks match AlphaFold's."""

        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        # The function has no parameters, so empty params and no RNG suffice.
        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        """atom37_to_torsion_angles matches AlphaFold on batched templates."""

        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is
        # the stable spelling available across JAX versions.
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.as_tensor(np.array(x)), out_gt
        )

        out_repro = data_transforms.atom37_to_torsion_angles()(
            {
                "aatype": torch.as_tensor(aatype).cuda(),
                "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
                "all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
            },
        )
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
            < 0.01
        )
        self.assertTrue(
            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        """atom37_to_frames matches AlphaFold for every output key."""

        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        def to_cpu_tensor(t):
            return torch.tensor(np.array(t))

        def to_cuda_tensor(t):
            return torch.tensor(np.array(t)).cuda()

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)
        out_gt = {k: to_cpu_tensor(v) for k, v in out_gt.items()}

        def flat12_to_4x4(flat12):
            # The last dim holds 12 values: the first 9 are reshaped into the
            # 3x3 rotation and the last 3 are the translation. They are packed
            # into a homogeneous 4x4 matrix so they can be compared with
            # OpenFold's frame tensors.
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]

            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1

            return four_by_four

        out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        batch = tree_map(to_cuda_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, gt in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(gt - out_repro[k])) < consts.eps,
                msg=k,
            )

    def test_torsion_angles_to_frames_shape(self):
        """torsion_angles_to_frames yields 8 frames per residue."""
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))
        ts = Rigid(Rotation(rot_mats=rots), trans)

        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        """torsion_angles_to_frames matches AlphaFold as 4x4 transforms."""

        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return alphafold.model.all_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        # The same random affines are fed to both implementations, each in
        # its own rigid-transform representation.
        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = Rigid.from_tensor_4x4(
            torch.as_tensor(affines).float()
        )

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)

        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors: the nine rotation
        # components and three translation components are stacked on a new
        # last axis, then reshaped and padded with a [0, 0, 0, 1] bottom row.
        rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
        trans_gt = list(
            map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
        )
        rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
        rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
        trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_tensor_4x4().cpu()

        # Compare the maximum absolute error against eps. The previous form,
        # `torch.max(torch.abs(...) < eps)`, took the max of a boolean mask
        # and so passed if any single element was close.
        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro)) < consts.eps
        )

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        """frames_and_literature_positions_to_atom14_pos yields 14 atoms."""
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))
        ts = Rigid(Rotation(rot_mats=rots), trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        """Atom14 positions from per-group frames match AlphaFold's."""

        def run_f(aatype, affines):
            am = alphafold.model
            return am.all_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res, 8))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = Rigid.from_tensor_4x4(
            torch.as_tensor(affines).float()
        )

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        # Stack the per-coordinate outputs into a trailing xyz axis.
        out_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
        )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        # Max absolute error must be below eps. The previous form took the
        # max of a boolean mask, which was nearly vacuous.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    # Allow running this module directly, e.g. `python test_feats.py`,
    # using unittest's command-line test runner.
    unittest.main()