# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

import openfold.features.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from openfold.utils.affine_utils import T
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4

# The AlphaFold reference implementation (and its JAX/Haiku stack) is only
# imported when available. The ``*_compare`` tests that use these names are
# decorated with ``compare_utils.skip_unless_alphafold_installed()``, so the
# names ``alphafold``, ``jax`` and ``hk`` are only referenced when bound here.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    """Tests for ``openfold.utils.feats`` and ``data_transforms``.

    ``*_shape`` tests check the output shapes of the OpenFold implementations
    on random CPU inputs. ``*_compare`` tests run the reference AlphaFold
    (JAX/Haiku) function and the OpenFold port (on a CUDA device) on the same
    random inputs and require the outputs to agree. The compare tests are
    skipped unless AlphaFold is installed, which is also the only case in
    which the module-level ``alphafold``, ``jax`` and ``hk`` names are bound.
    """

    def _assert_max_abs_diff_lt(self, expected, actual, tol):
        """Assert that every element of ``|expected - actual|`` is below ``tol``.

        The absolute differences are reduced with ``torch.max`` *before* the
        comparison, so a single out-of-tolerance element fails the check.
        The failure message reports the largest difference found.
        """
        max_diff = torch.max(torch.abs(expected - actual)).item()
        self.assertLess(max_diff, tol)

    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self._assert_max_abs_diff_lt(out_gt_pos, out_repro_pos, consts.eps)
        self._assert_max_abs_diff_lt(
            out_gt_mask, out_repro_mask, consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.as_tensor(np.array(x)), out_gt
        )

        out_repro = feats.atom37_to_torsion_angles(
            torch.as_tensor(aatype).cuda(),
            torch.as_tensor(all_atom_pos).cuda(),
            torch.as_tensor(all_atom_mask).cuda(),
        )
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so the sin/cos outputs are compared by mean absolute difference
        # with a much looser tolerance than the other comparison tests.
        self.assertLess(
            torch.mean(
                torch.abs(out_gt["torsion_angles_sin_cos"] - tasc)
            ).item(),
            0.01,
        )
        self.assertLess(
            torch.mean(
                torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc)
            ).item(),
            0.01,
        )
        self._assert_max_abs_diff_lt(
            out_gt["torsion_angles_mask"], tam, consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)
        out_gt = {k: torch.tensor(np.array(v)) for k, v in out_gt.items()}

        def flat12_to_4x4(flat12):
            # The reference frames are flat 12-vectors: 9 rotation entries
            # (row-major 3x3) followed by 3 translation entries. Expand them
            # to homogeneous 4x4 matrices to match the OpenFold output.
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]

            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1

            return four_by_four

        out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        batch = tree_map(
            lambda t: torch.tensor(np.array(t)).cuda(), batch, np.ndarray
        )

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        # Only keys produced by the reference implementation are compared.
        for key, expected in out_gt.items():
            self._assert_max_abs_diff_lt(expected, out_repro[key], consts.eps)

    def test_torsion_angles_to_frames_shape(self):
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))
        ts = T(rots, trans)

        angles = torch.rand((batch_size, n, 7, 2))

        # Residue types alternate between 0 and 1 along the sequence.
        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertEqual(frames.shape, (batch_size, n, 8))

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return alphafold.model.all_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        # The same random affines, viewed as AlphaFold Rigids and as an
        # OpenFold T.
        affines = random_affines_4x4((n_res,))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)

        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors: stack the
        # rotation components into a trailing 3x3 and the translation
        # components into a trailing 3-vector, then append [0, 0, 0, 1].
        rots_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt.rot], dim=-1
        )
        rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
        trans_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt.trans], dim=-1
        )
        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_4x4().cpu()

        # Previously written as ``torch.max(torch.abs(a - b) < eps)``, which
        # passed if ANY element was within tolerance; every element must be.
        self._assert_max_abs_diff_lt(
            transforms_gt, transforms_repro, consts.eps
        )

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))
        ts = T(rots, trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertEqual(xyz.shape, (batch_size, n_res, 14, 3))

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        def run_f(aatype, affines):
            am = alphafold.model
            return am.all_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res, 8))
        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
        transformations = T.from_4x4(torch.as_tensor(affines).float())

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        # Stack the per-coordinate arrays of the reference output (assumed
        # to be x, y, z) into a trailing dimension.
        out_gt = torch.stack(
            [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
        )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        # Previously written as ``torch.max(torch.abs(a - b) < eps)``, which
        # passed if ANY element was within tolerance; every element must be.
        self._assert_max_abs_diff_lt(out_gt, out_repro, consts.eps)


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this script when it
    # is executed directly (``module="__main__"`` is unittest.main's default).
    unittest.main(module="__main__")