"vscode:/vscode.git/clone" did not exist on "3aa76e1cf98ed1db17be6f931b11777a3f42ebb6"
test_loss.py 28.8 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
18
19
import math
import torch
import numpy as np
import unittest
20
import ml_collections as mlc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
21

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
22
from openfold.data import data_transforms
23
24
25
26
from openfold.utils.rigid_utils import (
    Rotation,
    Rigid,
)
27
import openfold.utils.feats as feats
28
29
30
31
32
33
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
34
35
36
37
38
39
40
41
42
43
44
    compute_renamed_ground_truth,
    masked_msa_loss,
    distogram_loss,
    experimentally_resolved_loss,
    violation_loss,
    fape_loss,
    lddt_loss,
    supervised_chi_loss,
    backbone_loss,
    sidechain_loss,
    tm_loss,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
45
    compute_plddt,
46
47
)
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
48
49
    tree_map,
    tensor_tree_map,
50
    dict_multimap,
51
52
53
)
import tests.compare_utils as compare_utils
from tests.config import consts
54
from tests.data_utils import random_affines_vector, random_affines_4x4
55

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
56
if compare_utils.alphafold_is_installed():
57
58
59
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
60
61


62
63
64
65
66
def affine_vector_to_4x4(affine):
    """Return the 4x4 tensor form of an affine given in 7-vector form.

    The input is wrapped with ``Rigid.from_tensor_7`` and immediately
    re-exported with ``Rigid.to_tensor_4x4``.
    """
    # NOTE(review): presumably the last dim of `affine` has size 7 — confirm
    # against Rigid.from_tensor_7.
    return Rigid.from_tensor_7(affine).to_tensor_4x4()


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
67
68
class TestLoss(unittest.TestCase):
    def test_run_torsion_angle_loss(self):
        """Smoke test: torsion_angle_loss runs on random inputs."""
        # The prediction and both ground-truth tensors share one shape.
        angle_shape = (consts.batch_size, consts.n_res, 7, 2)

        pred_angles = torch.rand(angle_shape)
        gt_angles = torch.rand(angle_shape)
        alt_gt_angles = torch.rand(angle_shape)

        # Nothing is asserted on the result; not raising is the check.
        torsion_angle_loss(pred_angles, gt_angles, alt_gt_angles)

    def test_run_fape(self):
        """Smoke test: compute_fape runs on random frames and positions."""
        n_batch = consts.batch_size
        n_frames, n_atoms = 7, 5

        # Draw order matches the original so the RNG stream is consumed
        # identically: positions, rotations, then translations.
        pred_xyz = torch.rand((n_batch, n_atoms, 3))
        gt_xyz = torch.rand((n_batch, n_atoms, 3))
        pred_rots = torch.rand((n_batch, n_frames, 3, 3))
        gt_rots = torch.rand((n_batch, n_frames, 3, 3))
        pred_trans = torch.rand((n_batch, n_frames, 3))
        gt_trans = torch.rand((n_batch, n_frames, 3))

        pred_frames = Rigid(Rotation(rot_mats=pred_rots), pred_trans)
        gt_frames = Rigid(Rotation(rot_mats=gt_rots), gt_trans)

        # Random 0/1 masks, cast to float.
        frame_mask = torch.randint(0, 2, (n_batch, n_frames)).float()
        atom_mask = torch.randint(0, 2, (n_batch, n_atoms)).float()

        # Nothing is asserted on the result; not raising is the check.
        compute_fape(
            pred_frames=pred_frames,
            target_frames=gt_frames,
            frames_mask=frame_mask,
            pred_positions=pred_xyz,
            target_positions=gt_xyz,
            positions_mask=atom_mask,
            length_scale=10,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
104

105
106
107
    def test_run_between_residue_bond_loss(self):
        """Smoke test: between_residue_bond_loss runs on random inputs."""
        n_batch, n_res = consts.batch_size, consts.n_res

        positions = torch.rand(n_batch, n_res, 14, 3)
        # Integer 0/1 mask (left as the integer dtype randint returns).
        atom_mask = torch.randint(0, 2, (n_batch, n_res, 14))
        # One row of consecutive indices with a leading singleton dimension.
        res_idx = torch.arange(n_res).unsqueeze(0)
        restypes = torch.randint(0, 22, (n_batch, n_res))

        # Nothing is asserted on the result; not raising is the check.
        between_residue_bond_loss(positions, atom_mask, res_idx, restypes)

127
128
129
130
131
132
133
134
135
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_bond_loss_compare(self):
        """Compare between_residue_bond_loss with the AlphaFold reference.

        The same random NumPy inputs are fed to
        ``alphafold.model.all_atom.between_residue_bond_loss`` (through a
        Haiku transform) and to the local implementation on CUDA. Every
        entry of the reference output must match the local output to within
        ``consts.eps`` in max absolute difference.
        """
        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
            return alphafold.model.all_atom.between_residue_bond_loss(
                pred_pos,
                pred_atom_mask,
                residue_index,
                aatype,
            )

        f = hk.transform(run_brbl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        residue_index = np.arange(n_res)
        aatype = np.random.randint(0, 22, (n_res,))

        # Empty params and no RNG key are passed to the transformed function.
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )
        # jax.tree_util.tree_map is used instead of the top-level
        # jax.tree_map alias, which is deprecated and absent from newer JAX
        # releases.
        out_gt = jax.tree_util.tree_map(
            lambda x: x.block_until_ready(), out_gt
        )
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.copy(x)), out_gt
        )

        out_repro = between_residue_bond_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(pred_atom_mask).cuda(),
            torch.tensor(residue_index).cuda(),
            torch.tensor(aatype).cuda(),
        )
        # Bring results back to the CPU so they can be compared to out_gt.
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

169
    def test_run_between_residue_clash_loss(self):
        """Smoke test: between_residue_clash_loss runs on random inputs."""
        n_batch, n_res = consts.batch_size, consts.n_res

        positions = torch.rand(n_batch, n_res, 14, 3)
        # Random 0/1 mask, cast to float.
        atom_mask = torch.randint(0, 2, (n_batch, n_res, 14)).float()
        radii = torch.rand(n_batch, n_res, 14)
        # One row of consecutive indices with a leading singleton dimension.
        res_idx = torch.arange(n_res).unsqueeze(0)

        # Nothing is asserted on the result; not raising is the check.
        between_residue_clash_loss(positions, atom_mask, radii, res_idx)

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_clash_loss_compare(self):
        """Compare between_residue_clash_loss with the AlphaFold reference.

        The same random NumPy inputs are fed to
        ``alphafold.model.all_atom.between_residue_clash_loss`` (through a
        Haiku transform) and to the local implementation on CUDA. Every
        entry of the reference output must match the local output to within
        ``consts.eps`` in max absolute difference.
        """
        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
            return alphafold.model.all_atom.between_residue_clash_loss(
                pred_pos,
                atom_exists,
                atom_radius,
                res_ind,
            )

        f = hk.transform(run_brcl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        atom_radius = np.random.rand(n_res, 14).astype(np.float32)
        res_ind = np.arange(n_res)

        # Empty params and no RNG key are passed to the transformed function.
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            atom_exists,
            atom_radius,
            res_ind,
        )
        # jax.tree_util.tree_map is used instead of the top-level
        # jax.tree_map alias, which is deprecated and absent from newer JAX
        # releases.
        out_gt = jax.tree_util.tree_map(
            lambda x: x.block_until_ready(), out_gt
        )
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.copy(x)), out_gt
        )

        out_repro = between_residue_clash_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(atom_exists).cuda(),
            torch.tensor(atom_radius).cuda(),
            torch.tensor(res_ind).cuda(),
        )
        # Bring results back to the CPU so they can be compared to out_gt.
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps,
                msg=f"mismatch for output key {k!r}",
            )

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_plddt_compare(self):
        """Compare compute_plddt with AlphaFold's confidence.compute_plddt.

        Both implementations receive the same random ``(n_res, 50)`` logits;
        their outputs must agree to within ``consts.eps``.
        """
        raw_logits = np.random.rand(consts.n_res, 50)

        expected = torch.tensor(
            alphafold.common.confidence.compute_plddt(raw_logits)
        )
        actual = compute_plddt(torch.tensor(raw_logits))

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
245
    def test_find_structural_violations(self):
        """Smoke test: find_structural_violations runs on random features."""
        n_res = consts.n_res

        # Entries are built in the original order so the RNG stream is
        # consumed identically.
        features = {
            "atom14_atom_exists": torch.randint(0, 2, (n_res, 14)),
            "residue_index": torch.arange(n_res),
            "aatype": torch.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": torch.randint(
                0, 37, (n_res, 14)
            ).long(),
        }

        positions = torch.rand(n_res, 14, 3)

        # Nothing is asserted on the result; not raising is the check.
        find_structural_violations(
            features,
            positions,
            clash_overlap_tolerance=1.5,
            violation_tolerance_factor=12.0,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
263

264
265
266
267
268
269
270
271
272
273
274
275
276
277
    @compare_utils.skip_unless_alphafold_installed()
    def test_find_structural_violations_compare(self):
        """Compare find_structural_violations with the AlphaFold reference.

        The same random batch and positions are given to
        ``alphafold.model.folding.find_structural_violations`` and to the
        local implementation on CUDA. Every leaf of the two (possibly
        nested) result dicts must agree to within ``consts.eps``.
        """
        def run_fsv(batch, pos, config):
            # The reference call is made from inside tests/test_data.
            # NOTE(review): presumably it loads a data file relative to the
            # working directory — confirm.
            cwd = os.getcwd()
            os.chdir("tests/test_data")
            try:
                return alphafold.model.folding.find_structural_violations(
                    batch,
                    pos,
                    config,
                )
            finally:
                # Restore the working directory even if the reference call
                # raises, so a failure here cannot leak into later tests.
                os.chdir(cwd)

        f = hk.transform(run_fsv)

        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": np.random.randint(
                0, 37, (n_res, 14)
            ).astype(np.int64),
        }

        pred_pos = np.random.rand(n_res, 14, 3)

        config = mlc.ConfigDict(
            {
                "clash_overlap_tolerance": 1.5,
                "violation_tolerance_factor": 12.0,
            }
        )

        # Empty params and no RNG key are passed to the transformed function.
        out_gt = f.apply({}, None, batch, pred_pos, config)
        # jax.tree_util.tree_map is used instead of the top-level
        # jax.tree_map alias, which is deprecated and absent from newer JAX
        # releases.
        out_gt = jax.tree_util.tree_map(
            lambda x: x.block_until_ready(), out_gt
        )
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.copy(x)), out_gt
        )

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        out_repro = find_structural_violations(
            batch,
            torch.tensor(pred_pos).cuda(),
            **config,
        )
        # Bring results back to the CPU so they can be compared to out_gt.
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        def compare(out):
            # `out` holds the reference leaf followed by the local leaf.
            gt, repro = out
            # self.assertTrue (not a bare assert) so the check is not
            # stripped under `python -O` and matches the other tests here.
            self.assertTrue(torch.max(torch.abs(gt - repro)) < consts.eps)

        dict_multimap(compare, [out_gt, out_repro])

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_renamed_ground_truth_compare(self):
        """Compare compute_renamed_ground_truth with the AlphaFold reference.

        A random batch is extended with the atom14 features produced by
        ``data_transforms.make_atom14_masks`` and ``make_atom14_positions``,
        then passed to both implementations. Every key of the local output
        must match the reference to within ``consts.eps``.
        """
        def run_crgt(batch, atom14_pred_pos):
            return alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_pos,
            )

        f = hk.transform(run_crgt)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Round-trip through torch so the data transforms can add their
            # derived features, then return NumPy arrays for the JAX side.
            b = tree_map(lambda arr: torch.tensor(arr), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        atom14_pred_pos = np.random.rand(n_res, 14, 3)

        # Empty params and no RNG key are passed to the transformed function.
        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        # jax.tree_util.tree_map is used instead of the top-level
        # jax.tree_map alias, which is deprecated and absent from newer JAX
        # releases.
        out_gt = jax.tree_util.tree_map(
            lambda x: torch.tensor(np.array(x)), out_gt
        )

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
        # Bring results back to the CPU so they can be compared to out_gt.
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k in out_repro:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_msa_loss_compare(self):
        """Compare masked_msa_loss with AlphaFold's MaskedMsaHead.loss.

        Both implementations receive the same random logits, ``true_msa``
        and ``bert_mask``; the two losses must agree to within
        ``consts.eps``.
        """
        def run_msa_loss(value, batch):
            af_config = compare_utils.get_alphafold_config()
            head = alphafold.model.modules.MaskedMsaHead(
                af_config.model.heads.masked_msa,
                af_config.model.global_config,
            )
            return head.loss(value, batch)

        reference = hk.transform(run_msa_loss)

        n_res, n_seq = consts.n_res, consts.n_seq

        head_out = {
            "logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
        }
        features = {
            "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
            "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
                np.float32
            ),
        }

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, head_out, features)["loss"]
        expected = torch.tensor(np.array(expected))

        def to_cuda(arr):
            return torch.tensor(arr).cuda()

        head_out = tree_map(to_cuda, head_out, np.ndarray)
        features = tree_map(to_cuda, features, np.ndarray)

        with torch.no_grad():
            actual = masked_msa_loss(head_out["logits"], **features)
        # Bring the result back to the CPU for comparison.
        actual = tensor_tree_map(lambda t: t.cpu(), actual)

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_distogram_loss_compare(self):
        """Compare distogram_loss with AlphaFold's DistogramHead.loss.

        Bin settings come from the AlphaFold distogram head config. Both
        implementations see the same random logits and pseudo-beta
        features; the two losses must agree to within ``consts.eps``.
        """
        af_config = compare_utils.get_alphafold_config()
        dist_cfg = af_config.model.heads.distogram

        def run_distogram_loss(value, batch):
            head = alphafold.model.modules.DistogramHead(
                dist_cfg, af_config.model.global_config
            )
            return head.loss(value, batch)

        reference = hk.transform(run_distogram_loss)

        n_res = consts.n_res

        head_out = {
            "logits": np.random.rand(n_res, n_res, dist_cfg.num_bins).astype(
                np.float32
            ),
            "bin_edges": np.linspace(
                dist_cfg.first_break,
                dist_cfg.last_break,
                dist_cfg.num_bins,
            ),
        }
        features = {
            "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
        }

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, head_out, features)["loss"]
        expected = torch.tensor(np.array(expected))

        def to_cuda(arr):
            return torch.tensor(arr).cuda()

        head_out = tree_map(to_cuda, head_out, np.ndarray)
        features = tree_map(to_cuda, features, np.ndarray)

        with torch.no_grad():
            actual = distogram_loss(
                logits=head_out["logits"],
                min_bin=dist_cfg.first_break,
                max_bin=dist_cfg.last_break,
                no_bins=dist_cfg.num_bins,
                **features,
            )
        # Bring the result back to the CPU for comparison.
        actual = tensor_tree_map(lambda t: t.cpu(), actual)

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_experimentally_resolved_loss_compare(self):
        """Compare experimentally_resolved_loss with the AlphaFold head loss.

        Resolution bounds come from the AlphaFold experimentally-resolved
        head config. Both implementations see the same random logits, atom
        masks and resolution; the losses must agree to within
        ``consts.eps``.
        """
        af_config = compare_utils.get_alphafold_config()
        er_cfg = af_config.model.heads.experimentally_resolved

        def run_experimentally_resolved_loss(value, batch):
            head = alphafold.model.modules.ExperimentallyResolvedHead(
                er_cfg, af_config.model.global_config
            )
            return head.loss(value, batch)

        reference = hk.transform(run_experimentally_resolved_loss)

        n_res = consts.n_res

        head_out = {
            "logits": np.random.rand(n_res, 37).astype(np.float32),
        }
        features = {
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
            "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
            "resolution": np.array(1.0),
        }

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, head_out, features)["loss"]
        expected = torch.tensor(np.array(expected))

        def to_cuda(arr):
            return torch.tensor(arr).cuda()

        head_out = tree_map(to_cuda, head_out, np.ndarray)
        features = tree_map(to_cuda, features, np.ndarray)

        with torch.no_grad():
            actual = experimentally_resolved_loss(
                logits=head_out["logits"],
                min_resolution=er_cfg.min_resolution,
                max_resolution=er_cfg.max_resolution,
                **features,
            )
        # Bring the result back to the CPU for comparison.
        actual = tensor_tree_map(lambda t: t.cpu(), actual)

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_supervised_chi_loss_compare(self):
        """Compare supervised_chi_loss with the AlphaFold reference.

        The reference accumulates into ``ret["loss"]``; the local function
        additionally receives ``chi_angles_sin_cos``, built here from
        ``chi_angles``. The two losses must agree to within ``consts.eps``.
        """
        af_config = compare_utils.get_alphafold_config()
        sm_cfg = af_config.model.heads.structure_module

        def run_supervised_chi_loss(value, batch):
            # The reference call's return value is not used; the loss is
            # read back from `ret` afterwards.
            ret = {"loss": jax.numpy.array(0.0)}
            alphafold.model.folding.supervised_chi_loss(
                ret, batch, value, sm_cfg
            )
            return ret["loss"]

        reference = hk.transform(run_supervised_chi_loss)

        n_res = consts.n_res

        head_out = {
            "sidechains": {
                "angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
                    np.float32
                ),
                "unnormalized_angles_sin_cos": np.random.rand(
                    8, n_res, 7, 2
                ).astype(np.float32),
            }
        }
        features = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "seq_mask": np.random.randint(0, 2, (n_res,)),
            "chi_mask": np.random.randint(0, 2, (n_res, 4)),
            "chi_angles": np.random.rand(n_res, 4).astype(np.float32),
        }

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, head_out, features)
        expected = torch.tensor(np.array(expected.block_until_ready()))

        def to_cuda(arr):
            return torch.tensor(arr).cuda()

        head_out = tree_map(to_cuda, head_out, np.ndarray)
        features = tree_map(to_cuda, features, np.ndarray)

        # Derived feature: (sin, cos) of each chi angle on a new last axis.
        chi = features["chi_angles"]
        features["chi_angles_sin_cos"] = torch.stack(
            [torch.sin(chi), torch.cos(chi)],
            dim=-1,
        )

        with torch.no_grad():
            actual = supervised_chi_loss(
                chi_weight=sm_cfg.chi_weight,
                angle_norm_weight=sm_cfg.angle_norm_weight,
                **{**features, **head_out["sidechains"]},
            )
        # Bring the result back to the CPU for comparison.
        actual = tensor_tree_map(lambda t: t.cpu(), actual)

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss_compare(self):
        """Compare violation_loss with AlphaFold's structural_violation_loss.

        The reference runs ``find_structural_violations`` followed by
        ``structural_violation_loss`` and reads the loss back from ``ret``.
        The local path runs ``find_structural_violations`` followed by
        ``violation_loss``. The two losses must agree to within
        ``consts.eps``.
        """
        af_config = compare_utils.get_alphafold_config()
        sm_cfg = af_config.model.heads.structure_module

        def run_viol_loss(batch, atom14_pred_pos):
            ret = {"loss": np.array(0.0).astype(np.float32)}
            value = {
                "violations": alphafold.model.folding.find_structural_violations(
                    batch,
                    atom14_pred_pos,
                    sm_cfg,
                ),
            }
            # The call's return value is not used; the loss is read back
            # from `ret` afterwards.
            alphafold.model.folding.structural_violation_loss(
                ret,
                batch,
                value,
                sm_cfg,
            )
            return ret["loss"]

        reference = hk.transform(run_viol_loss)

        n_res = consts.n_res

        features = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }
        # The return value is ignored here.
        # NOTE(review): presumably this adds the atom14 mask features to
        # `features` in place — confirm.
        alphafold.model.tf.data_transforms.make_atom14_masks(features)
        features = {key: np.array(val) for key, val in features.items()}

        pred_xyz = np.random.rand(n_res, 14, 3).astype(np.float32)

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, features, pred_xyz)
        expected = torch.tensor(np.array(expected.block_until_ready()))

        features = tree_map(
            lambda arr: torch.tensor(arr).cuda(), features, np.ndarray
        )
        pred_xyz = torch.tensor(pred_xyz).cuda()

        features = data_transforms.make_atom14_masks(features)

        violations = find_structural_violations(features, pred_xyz, **sm_cfg)
        actual = violation_loss(violations, **features).cpu()

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_lddt_loss_compare(self):
        """Compare lddt_loss with AlphaFold's PredictedLDDTHead.loss.

        Both implementations see the same random pLDDT logits, predicted
        atom positions and ground-truth batch; the two losses must agree to
        within ``consts.eps``.
        """
        af_config = compare_utils.get_alphafold_config()
        plddt_cfg = af_config.model.heads.predicted_lddt

        def run_plddt_loss(value, batch):
            head = alphafold.model.modules.PredictedLDDTHead(
                plddt_cfg, af_config.model.global_config
            )
            return head.loss(value, batch)

        reference = hk.transform(run_plddt_loss)

        n_res = consts.n_res

        head_out = {
            "predicted_lddt": {
                "logits": np.random.rand(n_res, plddt_cfg.num_bins).astype(
                    np.float32
                ),
            },
            "structure_module": {
                "final_atom_positions": np.random.rand(n_res, 37, 3).astype(
                    np.float32
                ),
            },
        }
        features = {
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, head_out, features)
        expected = torch.tensor(np.array(expected["loss"]))

        def to_cuda(arr):
            return torch.tensor(arr).cuda()

        head_out = tree_map(to_cuda, head_out, np.ndarray)
        features = tree_map(to_cuda, features, np.ndarray)

        # The head config entries are forwarded as keyword arguments
        # alongside the batch features.
        actual = lddt_loss(
            logits=head_out["predicted_lddt"]["logits"],
            all_atom_pred_pos=head_out["structure_module"][
                "final_atom_positions"
            ],
            **{**features, **plddt_cfg},
        ).cpu()

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_backbone_loss_compare(self):
        """Compare backbone_loss with AlphaFold's folding.backbone_loss.

        The reference consumes 7-vector affines; the local function is
        given the same ground-truth frames converted to 4x4 form under the
        ``backbone_rigid_*`` keys. The two losses must agree to within
        ``consts.eps``.
        """
        af_config = compare_utils.get_alphafold_config()
        sm_cfg = af_config.model.heads.structure_module

        def run_bb_loss(batch, value):
            # The reference call's return value is not used; the loss is
            # read back from `ret` afterwards.
            ret = {"loss": np.array(0.0)}
            alphafold.model.folding.backbone_loss(ret, batch, value, sm_cfg)
            return ret["loss"]

        reference = hk.transform(run_bb_loss)

        n_res = consts.n_res

        features = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "use_clamped_fape": np.array(0.0),
        }
        head_out = {
            "traj": random_affines_vector((sm_cfg.num_layer, n_res)),
        }

        # Empty params and no RNG key are passed to the transformed function.
        expected = reference.apply({}, None, features, head_out)
        expected = torch.tensor(np.array(expected.block_until_ready()))

        def to_cuda(arr):
            return torch.tensor(arr).cuda()

        features = tree_map(to_cuda, features, np.ndarray)
        head_out = tree_map(to_cuda, head_out, np.ndarray)

        # Add the keys the local loss reads: frames in 4x4 form, with the
        # same mask under the rigid name.
        features["backbone_rigid_tensor"] = affine_vector_to_4x4(
            features["backbone_affine_tensor"]
        )
        features["backbone_rigid_mask"] = features["backbone_affine_mask"]

        actual = backbone_loss(
            traj=head_out["traj"], **{**features, **sm_cfg}
        ).cpu()

        max_abs_err = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_err < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_sidechain_loss_compare(self):
        """Check sidechain_loss against AlphaFold's sidechain_loss on random inputs.

        The reference value is computed through a Haiku-transformed wrapper
        around ``alphafold.model.folding.sidechain_loss``; the same NumPy
        inputs are then moved to the GPU and passed to ``sidechain_loss``.
        The test passes when the maximum absolute difference is below
        ``consts.eps``. The tensors are moved with ``.cuda()``, so a CUDA
        device is required.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_sidechain_loss(batch, value, atom14_pred_positions):
            # Extend the batch with the outputs of AlphaFold's
            # atom37_to_frames before computing the loss.
            batch = {
                **batch,
                **alphafold.model.all_atom.atom37_to_frames(
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                ),
            }

            # Convert the raw arrays into AlphaFold's r3 structures.
            v = {}
            v["sidechains"] = {}
            v["sidechains"][
                "frames"
            ] = alphafold.model.r3.rigids_from_tensor4x4(
                value["sidechains"]["frames"]
            )
            v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
                value["sidechains"]["atom_pos"]
            )
            v.update(
                alphafold.model.folding.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v

            ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Run the torch-based atom14 transforms on the enclosing batch,
            # then convert every tensor back to NumPy for the reference path.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
                    np.float32
                ),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        # Reference result; no parameters are needed, hence the empty dict.
        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        def to_tensor(t):
            # Convert a NumPy leaf into a torch tensor on the GPU.
            return torch.tensor(t).cuda()

        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        # Mirror the preprocessing done inside the reference wrapper.
        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        max_diff = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(
            max_diff < consts.eps,
            f"max abs difference {max_diff.item()} is not below {consts.eps}",
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_tm_loss_compare(self):
        """Check tm_loss against AlphaFold's PredictedAlignedErrorHead loss.

        The reference value comes from a Haiku-transformed
        ``PredictedAlignedErrorHead`` run with weights fetched through
        ``compare_utils.fetch_alphafold_module_weights``. The value under
        test comes from ``tm_loss`` applied to logits produced by
        ``aux_heads.tm`` of the model returned by
        ``compare_utils.get_global_pretrained_openfold()``. The test passes
        when the maximum absolute difference is below ``consts.eps``. The
        tensors are moved with ``.cuda()``, so a CUDA device is required.
        """
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
            )
            # Copy the input dict so the caller's dict is not mutated.
            v = {}
            v.update(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        # Fixed seed so the random inputs below are reproducible.
        np.random.seed(42)

        n_res = consts.n_res

        representations = {
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        def to_tensor(n):
            # Convert a NumPy leaf into a torch tensor on the GPU.
            return torch.tensor(n).cuda()

        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        # tm_loss is called with "rigid"-named keyword arguments, so derive
        # them from the "affine"-named entries used by the reference.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        model = compare_utils.get_global_pretrained_openfold()
        # Only the forward value is compared; skip autograd bookkeeping.
        with torch.no_grad():
            logits = model.aux_heads.tm(representations["pair"])

        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        max_diff = torch.max(torch.abs(out_gt - out_repro))
        self.assertTrue(
            max_diff < consts.eps,
            f"max abs difference {max_diff.item()} is not below {consts.eps}",
        )

# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()