test_loss.py 28.2 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
18
19
import math
import torch
import numpy as np
import unittest
20
import ml_collections as mlc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
21

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
22
from openfold.data import data_transforms
23
24
from openfold.utils.affine_utils import T, affine_vector_to_4x4
import openfold.utils.feats as feats
25
26
27
28
29
30
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
31
32
33
34
35
36
37
38
39
40
41
42
43
    compute_renamed_ground_truth,
    masked_msa_loss,
    distogram_loss,
    experimentally_resolved_loss,
    violation_loss,
    fape_loss,
    lddt_loss,
    supervised_chi_loss,
    backbone_loss,
    sidechain_loss,
    tm_loss,
)
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
44
45
    tree_map,
    tensor_tree_map,
46
    dict_multimap,
47
48
49
)
import tests.compare_utils as compare_utils
from tests.config import consts
50
from tests.data_utils import random_affines_vector, random_affines_4x4
51

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
52
# The reference AlphaFold package and the JAX/Haiku modules are imported only
# when compare_utils reports that AlphaFold is installed; otherwise the names
# `alphafold`, `jax` and `hk` are not defined in this module.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
56
57
58
59


class TestLoss(unittest.TestCase):
    def test_run_torsion_angle_loss(self):
        """Smoke test: torsion_angle_loss runs on random inputs.

        Nothing is asserted about the result; the test passes as long as the
        call does not raise.
        """
        angle_shape = (consts.batch_size, consts.n_res, 7, 2)

        # Three independent random tensors of identical shape, drawn in the
        # order they are passed to the loss.
        first, second, third = (torch.rand(angle_shape) for _ in range(3))

        torsion_angle_loss(first, second, third)

    def test_run_fape(self):
        """Smoke test: compute_fape runs on random frames and positions.

        The 3x3 rotation blocks are plain uniform random tensors, and the
        return value is discarded; only the absence of an exception is
        checked.
        """
        bs = consts.batch_size
        frame_count, atom_count = 7, 5

        # Random draws happen in this order: positions, rotations,
        # translations, and finally the two masks.
        pred_xyz = torch.rand((bs, atom_count, 3))
        true_xyz = torch.rand((bs, atom_count, 3))
        pred_rot = torch.rand((bs, frame_count, 3, 3))
        true_rot = torch.rand((bs, frame_count, 3, 3))
        pred_shift = torch.rand((bs, frame_count, 3))
        true_shift = torch.rand((bs, frame_count, 3))

        fape_kwargs = {
            "pred_frames": T(pred_rot, pred_shift),
            "target_frames": T(true_rot, true_shift),
            # Masks are random 0/1 values cast to float.
            "frames_mask": torch.randint(0, 2, (bs, frame_count)).float(),
            "pred_positions": pred_xyz,
            "target_positions": true_xyz,
            "positions_mask": torch.randint(0, 2, (bs, atom_count)).float(),
            "length_scale": 10,
        }

        compute_fape(**fape_kwargs)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
95

96
97
98
    def test_run_between_residue_bond_loss(self):
        """Smoke test: between_residue_bond_loss runs on random inputs.

        The return value is discarded; only the absence of an exception
        matters.
        """
        batch_dim, res_dim = consts.batch_size, consts.n_res

        positions = torch.rand(batch_dim, res_dim, 14, 3)
        # Integer 0/1 mask; it is not cast to float in this test.
        atom_mask = torch.randint(0, 2, (batch_dim, res_dim, 14))
        # Consecutive indices 0..n-1 with a leading axis of size 1.
        res_idx = torch.arange(res_dim).unsqueeze(0)
        residue_types = torch.randint(0, 22, (batch_dim, res_dim))

        between_residue_bond_loss(positions, atom_mask, res_idx, residue_types)

118
119
120
121
122
123
124
125
126
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_bond_loss_compare(self):
        """Compare between_residue_bond_loss with the AlphaFold reference.

        The same random NumPy inputs are fed to the reference implementation
        (through a Haiku transform) and to the local implementation (as CUDA
        tensors). Every entry of the reference output must match the local
        output to within consts.eps in maximum absolute difference.
        """

        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
            return alphafold.model.all_atom.between_residue_bond_loss(
                pred_pos,
                pred_atom_mask,
                residue_index,
                aatype,
            )

        f = hk.transform(run_brbl)

        n_res = consts.n_res
        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        residue_index = np.arange(n_res)
        aatype = np.random.randint(0, 22, (n_res,))

        # Empty parameter dict and no RNG key are passed to the transform.
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )
        # jax.tree_util.tree_map is used instead of the top-level jax.tree_map
        # alias, which newer JAX releases deprecate and remove.
        out_gt = jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.copy(x)), out_gt
        )

        out_repro = between_residue_bond_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(pred_atom_mask).cuda(),
            torch.tensor(residue_index).cuda(),
            torch.tensor(aatype).cuda(),
        )
        # Bring results back to the CPU so they can be compared with out_gt.
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

160
    def test_run_between_residue_clash_loss(self):
        """Smoke test: between_residue_clash_loss runs on random inputs.

        The result is not inspected; the test passes if the call does not
        raise.
        """
        batch_dim, res_dim = consts.batch_size, consts.n_res
        per_atom_shape = (batch_dim, res_dim, 14)

        positions = torch.rand(*per_atom_shape, 3)
        # Random 0/1 mask cast to float.
        atom_mask = torch.randint(0, 2, per_atom_shape).float()
        radii = torch.rand(*per_atom_shape)
        # Consecutive indices 0..n-1 with a leading axis of size 1.
        res_idx = torch.arange(res_dim).unsqueeze(0)

        between_residue_clash_loss(positions, atom_mask, radii, res_idx)

176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_clash_loss_compare(self):
        """Compare between_residue_clash_loss with the AlphaFold reference.

        Identical random NumPy inputs go to the reference implementation
        (through a Haiku transform) and to the local implementation (as CUDA
        tensors). Every entry of the reference output must match the local
        output to within consts.eps in maximum absolute difference.
        """

        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
            return alphafold.model.all_atom.between_residue_clash_loss(
                pred_pos,
                atom_exists,
                atom_radius,
                res_ind,
            )

        f = hk.transform(run_brcl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        atom_radius = np.random.rand(n_res, 14).astype(np.float32)
        res_ind = np.arange(n_res)

        # Empty parameter dict and no RNG key are passed to the transform.
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            atom_exists,
            atom_radius,
            res_ind,
        )
        # jax.tree_util.tree_map is used instead of the top-level jax.tree_map
        # alias, which newer JAX releases deprecate and remove.
        out_gt = jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.copy(x)), out_gt
        )

        out_repro = between_residue_clash_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(atom_exists).cuda(),
            torch.tensor(atom_radius).cuda(),
            torch.tensor(res_ind).cuda(),
        )
        # Bring results back to the CPU so they can be compared with out_gt.
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
221
    def test_find_structural_violations(self):
        """Smoke test: find_structural_violations runs on random inputs.

        Inputs carry no batch dimension. The return value is discarded, so
        only the absence of an exception is checked.
        """
        res_dim = consts.n_res

        # Entries are created in the same order as the random draws they need.
        features = {}
        features["atom14_atom_exists"] = torch.randint(0, 2, (res_dim, 14))
        features["residue_index"] = torch.arange(res_dim)
        features["aatype"] = torch.randint(0, 20, (res_dim,))
        features["residx_atom14_to_atom37"] = torch.randint(
            0, 37, (res_dim, 14)
        ).long()

        coords = torch.rand(res_dim, 14, 3)

        find_structural_violations(
            features,
            coords,
            clash_overlap_tolerance=1.5,
            violation_tolerance_factor=12.0,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
239

240
241
242
243
244
245
246
247
248
249
250
251
252
253
    @compare_utils.skip_unless_alphafold_installed()
    def test_find_structural_violations_compare(self):
        """Compare find_structural_violations with the AlphaFold reference.

        The reference implementation is run through a Haiku transform on
        random NumPy inputs; the local implementation is run on the same data
        as CUDA tensors. dict_multimap then applies the comparison to the two
        result trees, requiring a maximum absolute difference below
        consts.eps.
        """

        def run_fsv(batch, pos, config):
            # The reference call is made from inside tests/test_data
            # (presumably so AlphaFold can open data files by relative path;
            # confirm against the AlphaFold sources). The path is relative,
            # so the test must be launched from the directory holding tests/.
            cwd = os.getcwd()
            os.chdir("tests/test_data")
            try:
                return alphafold.model.folding.find_structural_violations(
                    batch,
                    pos,
                    config,
                )
            finally:
                # Restore the working directory even if the reference call
                # raises, so a failure here cannot affect later tests.
                os.chdir(cwd)

        f = hk.transform(run_fsv)

        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": np.random.randint(
                0, 37, (n_res, 14)
            ).astype(np.int64),
        }

        pred_pos = np.random.rand(n_res, 14, 3)

        config = mlc.ConfigDict(
            {
                "clash_overlap_tolerance": 1.5,
                "violation_tolerance_factor": 12.0,
            }
        )

        # Empty parameter dict and no RNG key are passed to the transform.
        out_gt = f.apply({}, None, batch, pred_pos, config)
        # jax.tree_util.tree_map is used instead of the top-level jax.tree_map
        # alias, which newer JAX releases deprecate and remove.
        out_gt = jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.copy(x)), out_gt
        )

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        out_repro = find_structural_violations(
            batch,
            torch.tensor(pred_pos).cuda(),
            **config,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        def compare(out):
            gt, repro = out
            # A unittest assertion is used instead of a bare assert, which
            # would be stripped when Python runs with -O.
            self.assertTrue(torch.max(torch.abs(gt - repro)) < consts.eps)

        dict_multimap(compare, [out_gt, out_repro])

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_renamed_ground_truth_compare(self):
        """Compare compute_renamed_ground_truth with the AlphaFold reference.

        A random feature dict is extended with the outputs of
        data_transforms.make_atom14_masks and make_atom14_positions, then
        passed both to the reference implementation (through a Haiku
        transform) and, as CUDA tensors, to the local implementation. Every
        key of the local output must match the reference to within consts.eps
        in maximum absolute difference.
        """

        def run_crgt(batch, atom14_pred_pos):
            return alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_pos,
            )

        f = hk.transform(run_crgt)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Round-trip through torch so the local data transforms can add
            # their features, then convert everything back to NumPy.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        atom14_pred_pos = np.random.rand(n_res, 14, 3)

        # Empty parameter dict and no RNG key are passed to the transform.
        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        # jax.tree_util.tree_map is used instead of the top-level jax.tree_map
        # alias, which newer JAX releases deprecate and remove.
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.array(x)), out_gt
        )

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        # Iterate over the keys of the local output; each must also be present
        # in the reference output.
        for k in out_repro:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_msa_loss_compare(self):
        """Compare masked_msa_loss with AlphaFold's MaskedMsaHead loss.

        The reference "loss" entry, computed through a Haiku transform, must
        match the local loss, computed on CUDA tensors under torch.no_grad,
        to within consts.eps in maximum absolute difference.
        """

        def reference_loss(value, batch):
            af_config = compare_utils.get_alphafold_config()
            head = alphafold.model.modules.MaskedMsaHead(
                af_config.model.heads.masked_msa,
                af_config.model.global_config,
            )
            return head.loss(value, batch)

        def to_cuda(array):
            return torch.tensor(array).cuda()

        transformed = hk.transform(reference_loss)

        grid = (consts.n_res, consts.n_seq)

        value = {"logits": np.random.rand(*grid, 23).astype(np.float32)}
        batch = {
            "true_msa": np.random.randint(0, 21, grid),
            "bert_mask": np.random.randint(0, 2, grid).astype(np.float32),
        }

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(reference))

        value = tree_map(to_cuda, value, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)

        with torch.no_grad():
            out_repro = masked_msa_loss(value["logits"], **batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_distogram_loss_compare(self):
        """Compare distogram_loss with AlphaFold's DistogramHead loss.

        The reference "loss" entry, computed through a Haiku transform, must
        match the local loss, computed on CUDA tensors under torch.no_grad,
        to within consts.eps in maximum absolute difference. Bin settings for
        both sides come from the AlphaFold distogram head config.
        """
        af_config = compare_utils.get_alphafold_config()
        head_cfg = af_config.model.heads.distogram

        def reference_loss(value, batch):
            head = alphafold.model.modules.DistogramHead(
                head_cfg, af_config.model.global_config
            )
            return head.loss(value, batch)

        def to_cuda(array):
            return torch.tensor(array).cuda()

        transformed = hk.transform(reference_loss)

        n_res = consts.n_res

        logits_np = np.random.rand(n_res, n_res, head_cfg.num_bins)
        edges_np = np.linspace(
            head_cfg.first_break,
            head_cfg.last_break,
            head_cfg.num_bins,
        )
        value = {
            "logits": logits_np.astype(np.float32),
            "bin_edges": edges_np,
        }

        batch = {
            "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
        }

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(reference))

        value = tree_map(to_cuda, value, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)

        with torch.no_grad():
            out_repro = distogram_loss(
                logits=value["logits"],
                min_bin=head_cfg.first_break,
                max_bin=head_cfg.last_break,
                no_bins=head_cfg.num_bins,
                **batch,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_experimentally_resolved_loss_compare(self):
        """Compare experimentally_resolved_loss with the AlphaFold head loss.

        The reference "loss" entry from ExperimentallyResolvedHead, computed
        through a Haiku transform, must match the local loss, computed on
        CUDA tensors under torch.no_grad, to within consts.eps in maximum
        absolute difference.
        """
        af_config = compare_utils.get_alphafold_config()
        head_cfg = af_config.model.heads.experimentally_resolved

        def reference_loss(value, batch):
            head = alphafold.model.modules.ExperimentallyResolvedHead(
                head_cfg, af_config.model.global_config
            )
            return head.loss(value, batch)

        def to_cuda(array):
            return torch.tensor(array).cuda()

        transformed = hk.transform(reference_loss)

        per_atom = (consts.n_res, 37)

        value = {"logits": np.random.rand(*per_atom).astype(np.float32)}

        batch = {
            "all_atom_mask": np.random.randint(0, 2, per_atom),
            "atom37_atom_exists": np.random.randint(0, 2, per_atom),
            "resolution": np.array(1.0),
        }

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(reference))

        value = tree_map(to_cuda, value, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)

        with torch.no_grad():
            out_repro = experimentally_resolved_loss(
                logits=value["logits"],
                min_resolution=head_cfg.min_resolution,
                max_resolution=head_cfg.max_resolution,
                **batch,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_supervised_chi_loss_compare(self):
        """Compare supervised_chi_loss with the AlphaFold reference.

        The reference function writes into a result dict whose "loss" entry
        starts at zero; that entry is compared with the local loss, computed
        on CUDA tensors under torch.no_grad. The maximum absolute difference
        must be below consts.eps.
        """
        af_config = compare_utils.get_alphafold_config()
        sm_cfg = af_config.model.heads.structure_module

        def reference_loss(value, batch):
            result = {"loss": jax.numpy.array(0.0)}
            alphafold.model.folding.supervised_chi_loss(
                result, batch, value, sm_cfg
            )
            return result["loss"]

        def to_cuda(array):
            return torch.tensor(array).cuda()

        transformed = hk.transform(reference_loss)

        n_res = consts.n_res
        angle_shape = (8, n_res, 7, 2)

        value = {
            "sidechains": {
                "angles_sin_cos": np.random.rand(*angle_shape).astype(
                    np.float32
                ),
                "unnormalized_angles_sin_cos": np.random.rand(
                    *angle_shape
                ).astype(np.float32),
            }
        }

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "seq_mask": np.random.randint(0, 2, (n_res,)),
            "chi_mask": np.random.randint(0, 2, (n_res, 4)),
            "chi_angles": np.random.rand(n_res, 4).astype(np.float32),
        }

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(reference.block_until_ready()))

        value = tree_map(to_cuda, value, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)

        # The local loss is additionally given a (sin, cos) encoding of the
        # chi angles, stacked along a new trailing axis.
        chi = batch["chi_angles"]
        batch["chi_angles_sin_cos"] = torch.stack(
            (torch.sin(chi), torch.cos(chi)), dim=-1
        )

        # Sidechain entries override batch entries on a key collision.
        loss_inputs = dict(batch)
        loss_inputs.update(value["sidechains"])

        with torch.no_grad():
            out_repro = supervised_chi_loss(
                chi_weight=sm_cfg.chi_weight,
                angle_norm_weight=sm_cfg.angle_norm_weight,
                **loss_inputs,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss_compare(self):
        """Compare violation_loss with AlphaFold's structural violation loss.

        On the reference side, find_structural_violations feeds
        structural_violation_loss, which writes into a result dict whose
        "loss" entry starts at zero. On the local side,
        find_structural_violations feeds violation_loss. The two losses must
        agree to within consts.eps in maximum absolute difference.
        """
        af_config = compare_utils.get_alphafold_config()
        sm_cfg = af_config.model.heads.structure_module

        def reference_loss(batch, atom14_pred_pos):
            result = {"loss": np.array(0.0).astype(np.float32)}
            value = {
                "violations": (
                    alphafold.model.folding.find_structural_violations(
                        batch,
                        atom14_pred_pos,
                        sm_cfg,
                    )
                ),
            }
            alphafold.model.folding.structural_violation_loss(
                result,
                batch,
                value,
                sm_cfg,
            )
            return result["loss"]

        transformed = hk.transform(reference_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }
        # The return value is ignored here, so the reference transform is
        # relied on to add its entries to `batch` in place; every entry is
        # then converted to a NumPy array.
        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
        batch = {key: np.array(item) for key, item in batch.items()}

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, batch, atom14_pred_pos)
        out_gt = torch.tensor(np.array(reference.block_until_ready()))

        batch = tree_map(
            lambda array: torch.tensor(array).cuda(), batch, np.ndarray
        )
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        # Apply the local atom14 mask transform to the CUDA batch as well.
        batch = data_transforms.make_atom14_masks(batch)

        violations = find_structural_violations(
            batch, atom14_pred_pos, **sm_cfg
        )
        out_repro = violation_loss(violations, **batch).cpu()

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_lddt_loss_compare(self):
        """Compare lddt_loss with AlphaFold's PredictedLDDTHead loss.

        The reference "loss" entry, computed through a Haiku transform, must
        match the local loss, computed on CUDA tensors, to within consts.eps
        in maximum absolute difference. The head config entries are forwarded
        to the local loss as keyword arguments.
        """
        af_config = compare_utils.get_alphafold_config()
        head_cfg = af_config.model.heads.predicted_lddt

        def reference_loss(value, batch):
            plddt_head = alphafold.model.modules.PredictedLDDTHead(
                head_cfg, af_config.model.global_config
            )
            return plddt_head.loss(value, batch)

        def to_cuda(array):
            return torch.tensor(array).cuda()

        transformed = hk.transform(reference_loss)

        n_res = consts.n_res
        atom37_xyz = (n_res, 37, 3)

        value = {
            "predicted_lddt": {
                "logits": np.random.rand(n_res, head_cfg.num_bins).astype(
                    np.float32
                ),
            },
            "structure_module": {
                "final_atom_positions": np.random.rand(*atom37_xyz).astype(
                    np.float32
                ),
            },
        }

        batch = {
            "all_atom_positions": np.random.rand(*atom37_xyz).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(reference["loss"]))

        value = tree_map(to_cuda, value, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)

        # Config entries override batch entries on a key collision.
        loss_inputs = {**batch, **head_cfg}
        out_repro = lddt_loss(
            logits=value["predicted_lddt"]["logits"],
            all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
            **loss_inputs,
        ).cpu()

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_backbone_loss(self):
        """Compare backbone_loss with the AlphaFold reference.

        The reference function writes into a result dict whose "loss" entry
        starts at zero. The local side receives the same affines after they
        are converted with affine_vector_to_4x4. The two losses must agree to
        within consts.eps in maximum absolute difference.
        """
        af_config = compare_utils.get_alphafold_config()
        sm_cfg = af_config.model.heads.structure_module

        def reference_loss(batch, value):
            result = {"loss": np.array(0.0)}
            alphafold.model.folding.backbone_loss(result, batch, value, sm_cfg)
            return result["loss"]

        def to_cuda(array):
            return torch.tensor(array).cuda()

        transformed = hk.transform(reference_loss)

        n_res = consts.n_res

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "use_clamped_fape": np.array(0.0),
        }

        # One set of affines per structure module layer.
        value = {"traj": random_affines_vector((sm_cfg.num_layer, n_res))}

        # Empty parameter dict and no RNG key are passed to the transform.
        reference = transformed.apply({}, None, batch, value)
        out_gt = torch.tensor(np.array(reference.block_until_ready()))

        batch = tree_map(to_cuda, batch, np.ndarray)
        value = tree_map(to_cuda, value, np.ndarray)

        # The local loss is given the affines after conversion with
        # affine_vector_to_4x4.
        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        value["traj"] = affine_vector_to_4x4(value["traj"])

        # Config entries override batch entries on a key collision.
        loss_inputs = {**batch, **sm_cfg}
        out_repro = backbone_loss(traj=value["traj"], **loss_inputs).cpu()

        max_abs_err = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_sidechain_loss_compare(self):
        """Check OpenFold's sidechain_loss against AlphaFold's on random inputs.

        Random ground-truth features and random predicted sidechain frames
        and atom positions are fed to alphafold.model.folding.sidechain_loss
        and to OpenFold's sidechain_loss; the test passes when the largest
        absolute difference between the two results is below consts.eps.

        Skipped unless AlphaFold is installed. The OpenFold side moves its
        inputs to a CUDA device via ``.cuda()``.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_sidechain_loss(batch, value, atom14_pred_positions):
            # Extend the batch with the output of AlphaFold's
            # atom37_to_frames; its keys override same-named batch keys.
            batch = {
                **batch,
                **alphafold.model.all_atom.atom37_to_frames(
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                ),
            }
            # Re-wrap the raw arrays with AlphaFold's r3 helpers before they
            # are handed to the reference loss.
            v = {
                "sidechains": {
                    "frames": alphafold.model.r3.rigids_from_tensor4x4(
                        value["sidechains"]["frames"]
                    ),
                    "atom_pos": alphafold.model.r3.vecs_from_tensor(
                        value["sidechains"]["atom_pos"]
                    ),
                }
            }
            v.update(
                alphafold.model.folding.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )

            ret = alphafold.model.folding.sidechain_loss(batch, v, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # The data transforms are run on torch tensors, so round-trip:
            # numpy -> torch -> transforms -> numpy.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
                    np.float32
                ),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        # Reference value: applied with empty params and no RNG key.
        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        # Convert every numpy leaf to a CUDA tensor for the OpenFold side.
        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        # OpenFold counterparts of the two preprocessing steps performed
        # inside run_sidechain_loss above.
        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        # Batch entries and structure-module config entries are merged into
        # one set of keyword arguments; config keys win on a name collision.
        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_tm_loss_compare(self):
        """Check OpenFold's tm_loss against AlphaFold's PAE head loss.

        A random pair representation, batch and final affines are run
        through AlphaFold's PredictedAlignedErrorHead (forward pass plus
        ``loss``) and through the OpenFold model's ``aux_heads.tm`` followed
        by tm_loss; the test passes when the largest absolute difference
        between the two results is below consts.eps.

        Skipped unless AlphaFold is installed. The OpenFold side moves its
        inputs to a CUDA device via ``.cuda()``.
        """
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
            )
            # Copy so the caller's dict is not mutated when the head output
            # is added.
            v = {}
            v.update(value)
            # NOTE(review): the third positional argument is presumably the
            # is_training flag -- confirm against the AlphaFold head API.
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        n_res = consts.n_res

        representations = {
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            # Random 0/1 mask stored as float32.
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        # Parameters passed to the AlphaFold head, fetched by scope name.
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        # Convert every numpy leaf to a CUDA tensor for the OpenFold side.
        to_tensor = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        # The OpenFold call below is given the output of
        # affine_vector_to_4x4 rather than the raw affine vectors.
        batch["backbone_affine_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        value["structure_module"]["final_affines"] = affine_vector_to_4x4(
            value["structure_module"]["final_affines"]
        )

        # Logits come from the OpenFold model returned by
        # get_global_pretrained_openfold().
        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])

        # Batch entries and PAE-head config entries are merged into one set
        # of keyword arguments; config keys win on a name collision.
        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()