test_feats.py 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import openfold.data.data_transforms as data_transforms
20
21
22
23
24
25
26
27
28
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
29
30
    tree_map,
    tensor_tree_map,
31
32
33
34
35
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
36
if compare_utils.alphafold_is_installed():
    # The reference AlphaFold implementation and the JAX/Haiku stack it runs
    # on are optional. They are only imported when available; the comparison
    # tests below are decorated with `skip_unless_alphafold_installed`, so
    # they never touch these names when the import did not happen.
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    """Tests for ``openfold.utils.feats`` and related data transforms.

    The ``*_compare`` tests run OpenFold's implementation and the reference
    AlphaFold (JAX/Haiku) implementation on the same random inputs and check
    that the outputs agree. They are skipped unless AlphaFold is installed,
    and they move the OpenFold inputs to a CUDA device with ``.cuda()``.
    The ``*_shape`` tests only check output shapes and run on CPU tensors.
    """

    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        """``feats.pseudo_beta_fn`` matches AlphaFold's ``pseudo_beta_fn``."""

        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        """``data_transforms.atom37_to_torsion_angles`` matches AlphaFold."""

        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

        out_repro = data_transforms.atom37_to_torsion_angles()(
            {
                "aatype": torch.as_tensor(aatype).cuda(),
                "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
                "all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
            },
        )
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
            < 0.01
        )
        self.assertTrue(
            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        """``data_transforms.atom37_to_frames`` matches AlphaFold."""

        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def to_cpu_tensor(t):
            return torch.tensor(np.array(t))

        def to_cuda_tensor(t):
            return torch.tensor(np.array(t)).cuda()

        out_gt = f.apply({}, None, **batch)
        out_gt = {k: to_cpu_tensor(v) for k, v in out_gt.items()}

        def flat12_to_4x4(flat12):
            # Unpack the trailing 12 values (9 rotation entries followed by
            # 3 translation entries) into a homogeneous 4x4 matrix.
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]

            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1

            return four_by_four

        out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        batch = tree_map(to_cuda_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(v - out_repro[k])) < consts.eps, msg=k
            )

    def test_torsion_angles_to_frames_shape(self):
        """``feats.torsion_angles_to_frames`` yields 8 frames per residue."""
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))
        ts = T(rots, trans)

        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        """``feats.torsion_angles_to_frames`` matches AlphaFold."""

        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return alphafold.model.all_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)

        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
        trans_gt = list(
            map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
        )
        rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
        rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
        trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_4x4().cpu()

        # Compare the *largest* absolute difference against eps. Taking the
        # max of a boolean "< eps" tensor instead would pass as soon as any
        # single element matched.
        max_diff = torch.max(torch.abs(transforms_gt - transforms_repro))
        self.assertTrue(max_diff < consts.eps)

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        """Atom14 positions come out with shape (batch, n_res, 14, 3)."""
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))
        ts = T(rots, trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        """Atom14 positions from OpenFold match AlphaFold's."""

        def run_f(aatype, affines):
            am = alphafold.model
            return am.all_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res, 8))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
        )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        # Compare the *largest* absolute difference against eps (not the max
        # of a boolean mask, which would pass if any one element matched).
        max_diff = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_diff < consts.eps)


if __name__ == "__main__":
    # When this file is executed directly rather than collected by a test
    # runner, load and run the TestCase classes defined in this module.
    unittest.main(module="__main__")