test_loss.py 39.4 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
import torch
import numpy as np
Christina Floristean's avatar
Christina Floristean committed
18
from pathlib import Path
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import unittest
20
import ml_collections as mlc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
21

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
22
from openfold.data import data_transforms
23
from openfold.np import residue_constants
24
25
26
27
from openfold.utils.rigid_utils import (
    Rotation,
    Rigid,
)
28
29
30
31
32
33
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
34
35
36
37
38
39
40
41
42
43
44
    compute_renamed_ground_truth,
    masked_msa_loss,
    distogram_loss,
    experimentally_resolved_loss,
    violation_loss,
    fape_loss,
    lddt_loss,
    supervised_chi_loss,
    backbone_loss,
    sidechain_loss,
    tm_loss,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
45
    compute_plddt,
46
47
    compute_tm,
    chain_center_of_mass_loss
48
49
)
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
50
51
    tree_map,
    tensor_tree_map,
52
    dict_multimap,
53
54
55
)
import tests.compare_utils as compare_utils
from tests.config import consts
56
from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
57

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
58
# Conditionally import the reference AlphaFold implementation (and its JAX /
# Haiku dependencies) so the *_compare tests below can run against it. When
# AlphaFold is absent, those tests are skipped via the decorator instead.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
62
63


64
65
66
67
68
def affine_vector_to_4x4(affine):
    """Convert a 7-component (quaternion + translation) affine tensor into
    its equivalent (..., 4, 4) homogeneous transformation tensor."""
    return Rigid.from_tensor_7(affine).to_tensor_4x4()


69
70
71
72
73
74
75
76
77
78
def affine_vector_to_rigid(am_rigid, affine):
    """Build an AlphaFold ``Rigid3Array`` from a 7-component affine vector.

    ``affine`` holds a quaternion (w, x, y, z) followed by a translation
    (x, y, z) along its last axis; ``am_rigid`` is the AlphaFold rigid/geometry
    module providing Rot3Array, Vec3Array and Rigid3Array.
    """
    components = [c.squeeze(-1) for c in np.split(affine, 7, axis=-1)]
    qw, qx, qy, qz, tx, ty, tz = components
    rot = am_rigid.Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
    return am_rigid.Rigid3Array(rot, am_rigid.Vec3Array(tx, ty, tz))


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
79
class TestLoss(unittest.TestCase):
80
81
82
83
84
85
86
87
88
89
90
91
92
    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold reference modules matching the test mode.

        Multimer runs compare against the *_multimer modules (and the
        ``geometry`` rigid implementation); monomer runs use the originals
        (and ``r3``).
        """
        multimer = consts.is_multimer
        m = alphafold.model
        cls.am_atom = m.all_atom_multimer if multimer else m.all_atom
        cls.am_fold = m.folding_multimer if multimer else m.folding
        cls.am_modules = m.modules_multimer if multimer else m.modules
        cls.am_rigid = m.geometry if multimer else m.r3
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
93
    def test_run_torsion_angle_loss(self):
        """Smoke test: torsion_angle_loss runs on random sin/cos angle tensors."""
        angle_shape = (consts.batch_size, consts.n_res, 7, 2)

        # Predicted, ground-truth and alternate ground-truth angles,
        # drawn in the same order as before so RNG consumption is identical.
        pred, gt, alt_gt = (torch.rand(angle_shape) for _ in range(3))

        torsion_angle_loss(pred, gt, alt_gt)

    def test_run_fape(self):
        """Smoke test: compute_fape runs on random frames and positions."""
        bs = consts.batch_size
        n_frames, n_atoms = 7, 5

        # Random positions, rotations and translations for predicted and
        # ground-truth structures (draw order matches the original test).
        pred_pts = torch.rand((bs, n_atoms, 3))
        gt_pts = torch.rand((bs, n_atoms, 3))
        pred_rots = torch.rand((bs, n_frames, 3, 3))
        gt_rots = torch.rand((bs, n_frames, 3, 3))
        pred_trans = torch.rand((bs, n_frames, 3))
        gt_trans = torch.rand((bs, n_frames, 3))

        pred_frames = Rigid(Rotation(rot_mats=pred_rots), pred_trans)
        gt_frames = Rigid(Rotation(rot_mats=gt_rots), gt_trans)

        frames_mask = torch.randint(0, 2, (bs, n_frames)).float()
        positions_mask = torch.randint(0, 2, (bs, n_atoms)).float()

        compute_fape(
            pred_frames=pred_frames,
            target_frames=gt_frames,
            frames_mask=frames_mask,
            pred_positions=pred_pts,
            target_positions=gt_pts,
            positions_mask=positions_mask,
            length_scale=10,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
129

130
131
132
    def test_run_between_residue_bond_loss(self):
        """Smoke test: between_residue_bond_loss runs on random atom14 inputs."""
        bs, n = consts.batch_size, consts.n_res

        positions = torch.rand(bs, n, 14, 3)
        atom_mask = torch.randint(0, 2, (bs, n, 14))
        res_index = torch.arange(n).unsqueeze(0)
        # Residue types in [0, 22) (20 amino acids + unknown + gap).
        aa = torch.randint(0, 22, (bs, n))

        between_residue_bond_loss(positions, atom_mask, res_index, aa)

152
153
154
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_bond_loss_compare(self):
        """Check OpenFold's between_residue_bond_loss against AlphaFold's.

        Runs the reference (JAX) implementation and the PyTorch port on the
        same random inputs and asserts every returned entry agrees to within
        consts.eps.
        """
        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
            # The multimer reference expects positions wrapped in a Vec3Array
            # rather than a plain array.
            if consts.is_multimer:
                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)

            return self.am_atom.between_residue_bond_loss(
                pred_pos,
                pred_atom_mask,
                residue_index,
                aatype,
            )

        # Haiku transform so the reference function can be applied purely
        # (empty params, no RNG).
        f = hk.transform(run_brbl)

        n_res = consts.n_res
        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        residue_index = np.arange(n_res)
        aatype = np.random.randint(0, 22, (n_res,))

        # Ground truth from the reference implementation.
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        # Reproduction with the PyTorch port (on GPU), then moved to CPU
        # for comparison.
        out_repro = between_residue_bond_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(pred_atom_mask).cuda(),
            torch.tensor(residue_index).cuda(),
            torch.tensor(aatype).cuda(),
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        # Elementwise agreement for every key the reference returns.
        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

197
    def test_run_between_residue_clash_loss(self):
        """Smoke test: between_residue_clash_loss runs on random atom14 inputs."""
        bs, n = consts.batch_size, consts.n_res

        positions = torch.rand(bs, n, 14, 3)
        atom_mask = torch.randint(0, 2, (bs, n, 14)).float()
        radii = torch.rand(bs, n, 14)
        res_index = torch.arange(n).unsqueeze(0)

        between_residue_clash_loss(positions, atom_mask, radii, res_index)

213
214
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_clash_loss_compare(self):
        """Check OpenFold's between_residue_clash_loss against AlphaFold's.

        Builds per-atom van der Waals radii from residue_constants, runs
        the reference (JAX) and the PyTorch port on identical inputs, and
        asserts every returned entry agrees to within consts.eps.
        """
        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
            # Multimer reference: positions as Vec3Array plus an asym_id arg.
            if consts.is_multimer:
                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
                return self.am_atom.between_residue_clash_loss(
                    pred_pos,
                    atom_exists,
                    atom_radius,
                    res_ind,
                    asym_id
                )

            # Monomer reference takes no asym_id.
            return self.am_atom.between_residue_clash_loss(
                pred_pos,
                atom_exists,
                atom_radius,
                res_ind
            )

        f = hk.transform(run_brcl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        res_ind = np.arange(
            n_res,
        )
        residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)

        # Van der Waals radius per atom37 type, keyed by the element
        # (first character of the atom name), mapped into atom14 layout
        # and masked by atom existence.
        atomtype_radius = [
            residue_constants.van_der_waals_radius[name[0]]
            for name in residue_constants.atom_types
        ]
        atomtype_radius = np.array(atomtype_radius).astype(np.float32)
        atom_radius = (
                atom_exists
                * atomtype_radius[residx_atom14_to_atom37]
        )

        # Chain IDs only exist in multimer mode.
        asym_id = None
        if consts.is_multimer:
            asym_id = random_asym_ids(n_res)

        out_gt = f.apply(
            {},
            None,
            pred_pos,
            atom_exists,
            atom_radius,
            res_ind,
            asym_id
        )
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        out_repro = between_residue_clash_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(atom_exists).cuda(),
            torch.tensor(atom_radius).cuda(),
            torch.tensor(res_ind).cuda(),
            torch.tensor(asym_id).cuda() if asym_id is not None else None,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_plddt_compare(self):
        """Compare OpenFold's compute_plddt against AlphaFold's on the
        same random per-residue 50-bin logits."""
        logits = np.random.rand(consts.n_res, 50)

        reference = torch.tensor(
            alphafold.common.confidence.compute_plddt(logits)
        )
        ours = compute_plddt(torch.tensor(logits))

        self.assertTrue(torch.max(torch.abs(reference - ours)) < consts.eps)

299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_ptm_compare(self):
        """Compare OpenFold's compute_tm (pTM, and ipTM for multimer)
        against AlphaFold's predicted_tm_score on identical logits."""
        n_res = consts.n_res
        max_bin = 31
        no_bins = 64

        logits = np.random.rand(n_res, n_res, no_bins)
        # Bin boundaries matching the reference's expectation of
        # no_bins - 1 break points in [0, max_bin].
        boundaries = np.linspace(0, max_bin, num=(no_bins - 1))

        ptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries)
        ptm_gt = torch.tensor(ptm_gt)
        logits_t = torch.tensor(logits)
        ptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)

        self.assertTrue(
            torch.max(torch.abs(ptm_gt - ptm_repro)) < consts.eps
        )

        # Interface pTM (ipTM) additionally needs per-residue chain IDs,
        # so it is only exercised in multimer mode.
        if consts.is_multimer:
            asym_id = random_asym_ids(n_res)
            iptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries,
                                                                     asym_id=asym_id, interface=True)
            iptm_gt = torch.tensor(iptm_gt)
            iptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin,
                                    asym_id=torch.tensor(asym_id), interface=True)

            self.assertTrue(
                torch.max(torch.abs(iptm_gt - iptm_repro)) < consts.eps
            )

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
329
    def test_find_structural_violations(self):
        """Smoke test: find_structural_violations runs on a random batch."""
        n = consts.n_res

        # Minimal feature dict the violation computation reads
        # (keys kept in the original order so RNG draws are identical).
        batch = {
            "atom14_atom_exists": torch.randint(0, 2, (n, 14)),
            "residue_index": torch.arange(n),
            "aatype": torch.randint(0, 20, (n,)),
            "residx_atom14_to_atom37": torch.randint(0, 37, (n, 14)).long(),
        }
        positions = torch.rand(n, 14, 3)

        find_structural_violations(
            batch,
            positions,
            clash_overlap_tolerance=1.5,
            violation_tolerance_factor=12.0,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
347

348
349
350
351
    @compare_utils.skip_unless_alphafold_installed()
    def test_find_structural_violations_compare(self):
        """Compare find_structural_violations against AlphaFold's reference.

        Fix over the original: ``run_fsv`` chdirs into ``test_data`` (the
        reference implementation reads stereochemical data relative to the
        cwd) but only restored the working directory on the monomer path —
        the multimer early return and any raised exception left the process
        chdir'd into ``test_data``. The restore is now in a try/finally so
        it runs on every exit path.
        """
        def run_fsv(batch, pos, config):
            cwd = os.getcwd()
            fpath = Path(__file__).parent.resolve() / "test_data"
            os.chdir(str(fpath))
            try:
                if consts.is_multimer:
                    # Multimer reference takes unpacked features and a
                    # Vec3Array of positions, plus per-residue chain IDs.
                    atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
                    return self.am_fold.find_structural_violations(
                        batch['aatype'],
                        batch['residue_index'],
                        batch['atom14_atom_exists'],
                        atom14_pred_pos,
                        config,
                        batch['asym_id']
                    )

                return self.am_fold.find_structural_violations(
                    batch,
                    pos,
                    config,
                )
            finally:
                # Always restore the working directory for later tests.
                os.chdir(cwd)

        f = hk.transform(run_fsv)

        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": np.random.randint(
                0, 37, (n_res, 14)
            ).astype(np.int64),
        }

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        pred_pos = np.random.rand(n_res, 14, 3)

        config = mlc.ConfigDict(
            {
                "clash_overlap_tolerance": 1.5,
                "violation_tolerance_factor": 12.0,
            }
        )

        # Reference output, materialized and converted to torch tensors.
        out_gt = f.apply({}, None, batch, pred_pos, config)
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        out_repro = find_structural_violations(
            batch,
            torch.tensor(pred_pos).cuda(),
            **config,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        # Assert elementwise agreement across the (nested) output dicts.
        def compare(out):
            gt, repro = out
            assert torch.max(torch.abs(gt - repro)) < consts.eps

        dict_multimap(compare, [out_gt, out_repro])

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_renamed_ground_truth_compare(self):
        """Compare compute_renamed_ground_truth against AlphaFold's version
        on a random batch augmented with atom14 masks/positions."""
        def run_crgt(batch, atom14_pred_pos):
            return alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_pos,
            )

        f = hk.transform(run_crgt)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Round-trip through torch so OpenFold's data transforms can add
            # the atom14 mask/position features the loss expects, then return
            # everything to numpy for the JAX reference.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        atom14_pred_pos = np.random.rand(n_res, 14, 3)

        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        # Iterate the port's keys: only compare entries both produce.
        for k in out_repro:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_msa_loss_compare(self):
        """Compare masked_msa_loss against AlphaFold's MaskedMsaHead loss."""
        def run_msa_loss(value, batch):
            config = compare_utils.get_alphafold_config()
            msa_head = alphafold.model.modules.MaskedMsaHead(
                config.model.heads.masked_msa, config.model.global_config
            )
            return msa_head.loss(value, batch)

        f = hk.transform(run_msa_loss)

        n_res = consts.n_res
        n_seq = consts.n_seq

        value = {
            "logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
        }

        batch = {
            "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
            # Which MSA entries were BERT-masked (and thus scored).
            "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
                np.float32
            )
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = masked_msa_loss(
                value["logits"],
                batch["true_msa"],
                batch["bert_mask"],
                consts.msa_logits
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_distogram_loss_compare(self):
        """Compare distogram_loss against AlphaFold's DistogramHead loss."""
        config = compare_utils.get_alphafold_config()
        c_distogram = config.model.heads.distogram

        def run_distogram_loss(value, batch):
            dist_head = alphafold.model.modules.DistogramHead(
                c_distogram, config.model.global_config
            )
            return dist_head.loss(value, batch)

        f = hk.transform(run_distogram_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(
                np.float32
            ),
            # Distance-bin edges mirroring the head's configured breaks.
            "bin_edges": np.linspace(
                c_distogram.first_break,
                c_distogram.last_break,
                c_distogram.num_bins,
            ),
        }

        batch = {
            "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        # The port takes the binning parameters explicitly instead of
        # precomputed bin edges.
        with torch.no_grad():
            out_repro = distogram_loss(
                logits=value["logits"],
                min_bin=c_distogram.first_break,
                max_bin=c_distogram.last_break,
                no_bins=c_distogram.num_bins,
                **batch,
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_experimentally_resolved_loss_compare(self):
        """Compare experimentally_resolved_loss against AlphaFold's
        ExperimentallyResolvedHead loss."""
        config = compare_utils.get_alphafold_config()
        c_experimentally_resolved = config.model.heads.experimentally_resolved

        def run_experimentally_resolved_loss(value, batch):
            er_head = alphafold.model.modules.ExperimentallyResolvedHead(
                c_experimentally_resolved, config.model.global_config
            )
            return er_head.loss(value, batch)

        f = hk.transform(run_experimentally_resolved_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, 37).astype(np.float32),
        }

        batch = {
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
            "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
            # Resolution chosen so the loss's resolution gating is active.
            "resolution": np.array(1.0),
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = experimentally_resolved_loss(
                logits=value["logits"],
                min_resolution=c_experimentally_resolved.min_resolution,
                max_resolution=c_experimentally_resolved.max_resolution,
                **batch,
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_supervised_chi_loss_compare(self):
        """Compare supervised_chi_loss against AlphaFold's implementation.

        The monomer and multimer references have different signatures:
        the multimer version takes unpacked features and returns the loss,
        while the monomer version mutates a ``ret`` dict in place.
        """
        config = compare_utils.get_alphafold_config()
        c_chi_loss = config.model.heads.structure_module

        def run_supervised_chi_loss(value, batch):
            if consts.is_multimer:
                # Reshape the (8 * n_res, ...) stacked angles back to
                # per-iteration [no_blocks, n_res, 7, 2] layout.
                pred_angles = np.reshape(
                    value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])

                unnormed_angles = np.reshape(
                    value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])

                chi_loss, _, _ = self.am_fold.supervised_chi_loss(
                    batch['seq_mask'],
                    batch['chi_mask'],
                    batch['aatype'],
                    batch['chi_angles'],
                    pred_angles,
                    unnormed_angles,
                    c_chi_loss
                )
                return chi_loss

            # Monomer reference accumulates into ret["loss"] in place.
            ret = {
                "loss": jax.numpy.array(0.0),
            }
            self.am_fold.supervised_chi_loss(
                ret, batch, value, c_chi_loss
            )
            return ret["loss"]

        f = hk.transform(run_supervised_chi_loss)

        n_res = consts.n_res

        value = {
            "sidechains": {
                "angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
                    np.float32
                ),
                "unnormalized_angles_sin_cos": np.random.rand(
                    8, n_res, 7, 2
                ).astype(np.float32),
            }
        }

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "seq_mask": np.random.randint(0, 2, (n_res,)),
            "chi_mask": np.random.randint(0, 2, (n_res, 4)),
            "chi_angles": np.random.rand(n_res, 4).astype(np.float32),
        }

        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        # The port consumes sin/cos of the chi angles rather than raw angles.
        batch["chi_angles_sin_cos"] = torch.stack(
            [
                torch.sin(batch["chi_angles"]),
                torch.cos(batch["chi_angles"]),
            ],
            dim=-1,
        )

        with torch.no_grad():
            out_repro = supervised_chi_loss(
                chi_weight=c_chi_loss.chi_weight,
                angle_norm_weight=c_chi_loss.angle_norm_weight,
                **{**batch, **value["sidechains"]},
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss(self):
        """Smoke test: violation_loss runs with both clash-averaging modes."""
        config = compare_utils.get_alphafold_config()
        c_viol = config.model.heads.structure_module
        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        # Augment the batch with the atom14 masks the violation code reads.
        batch = data_transforms.make_atom14_masks(batch)

        # Exercise both summed and per-atom-averaged clash accounting,
        # in the same order as before (False, then True).
        losses = []
        for avg in (False, True):
            viol = find_structural_violations(batch, atom14_pred_pos, **c_viol)
            losses.append(
                violation_loss(viol, average_clashes=avg, **batch).cpu()
            )
        loss_sum_clash, loss_avg_clash = losses

718
719
720
721
    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss_compare(self):
        """Compare OpenFold's violation_loss against the AlphaFold reference.

        Builds a random batch, computes AlphaFold's structural violation
        loss through a haiku-transformed wrapper, then computes OpenFold's
        violation_loss on the same inputs and asserts agreement to within
        ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_viol = config.model.heads.structure_module

        def run_viol_loss(batch, atom14_pred_pos):
            ret = {
                "loss": np.array(0.0).astype(np.float32),
            }

            if consts.is_multimer:
                # The multimer implementation takes unpacked arguments and
                # returns the loss directly instead of accumulating in `ret`.
                atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
                viol = self.am_fold.find_structural_violations(
                    batch['aatype'],
                    batch['residue_index'],
                    batch['atom14_atom_exists'],
                    atom14_pred_pos,
                    c_viol,
                    batch['asym_id']
                )
                return self.am_fold.structural_violation_loss(mask=batch['atom14_atom_exists'],
                                                              violations=viol,
                                                              config=c_viol)

            value = {}
            value[
                "violations"
            ] = self.am_fold.find_structural_violations(
                batch,
                atom14_pred_pos,
                c_viol,
            )

            # The monomer implementation adds its contribution to ret["loss"].
            self.am_fold.structural_violation_loss(
                ret,
                batch,
                value,
                c_viol,
            )
            return ret["loss"]

        f = hk.transform(run_viol_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,))
        }

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        # AlphaFold's TF transform adds the atom14 masks in place; re-wrap
        # the results as plain numpy arrays afterwards.
        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
        batch = {k: np.array(v) for k, v in batch.items()}

        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        batch = data_transforms.make_atom14_masks(batch)

        out_repro = violation_loss(
            find_structural_violations(batch, atom14_pred_pos, **c_viol),
            **batch,
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_lddt_loss_compare(self):
        """Compare OpenFold's lddt_loss against AlphaFold's PredictedLDDTHead.

        Runs AlphaFold's head loss (via haiku) and OpenFold's lddt_loss on
        identical random logits/positions and asserts agreement to within
        ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_plddt = config.model.heads.predicted_lddt

        def run_plddt_loss(value, batch):
            head = alphafold.model.modules.PredictedLDDTHead(
                c_plddt, config.model.global_config
            )
            return head.loss(value, batch)

        f = hk.transform(run_plddt_loss)

        n_res = consts.n_res

        # Random model outputs: per-residue pLDDT logits and predicted
        # all-atom (atom37) positions.
        value = {
            "predicted_lddt": {
                "logits": np.random.rand(n_res, c_plddt.num_bins).astype(
                    np.float32
                ),
            },
            "structure_module": {
                "final_atom_positions": np.random.rand(n_res, 37, 3).astype(
                    np.float32
                ),
            },
        }

        # Random ground truth: atom37 positions, atom mask, and resolution.
        batch = {
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt["loss"]))

        to_tensor = lambda t: torch.tensor(t).cuda()
        value = tree_map(to_tensor, value, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = lddt_loss(
            logits=value["predicted_lddt"]["logits"],
            all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
            **{**batch, **c_plddt},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_backbone_loss_compare(self):
        """Compare OpenFold's backbone_loss against the AlphaFold reference.

        For multimer models the comparison covers the intra-chain and
        interface FAPE terms separately (summed); for monomer models the
        single backbone FAPE term is compared. Agreement is required to
        within ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_bb_loss(batch, value):
            if consts.is_multimer:
                # Multimer splits the backbone FAPE into intra-chain and
                # interface components via a pair mask over asym_ids.
                intra_chain_mask = (batch["asym_id"][..., None] == batch["asym_id"][..., None, :]).astype(np.float32)
                gt_rigid = affine_vector_to_rigid(self.am_rigid, batch["backbone_affine_tensor"])
                target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
                intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
                    gt_rigid=gt_rigid,
                    gt_frames_mask=batch["backbone_affine_mask"],
                    gt_positions_mask=batch["backbone_affine_mask"],
                    target_rigid=target_rigid,
                    config=c_sm.intra_chain_fape,
                    pair_mask=intra_chain_mask)
                interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
                    gt_rigid=gt_rigid,
                    gt_frames_mask=batch["backbone_affine_mask"],
                    gt_positions_mask=batch["backbone_affine_mask"],
                    target_rigid=target_rigid,
                    config=c_sm.interface_fape,
                    pair_mask=1. - intra_chain_mask)

                return intra_chain_bb_loss + interface_bb_loss

            # The monomer implementation accumulates into ret["loss"].
            ret = {
                "loss": np.array(0.0),
            }
            self.am_fold.backbone_loss(ret, batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_bb_loss)

        n_res = consts.n_res

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "use_clamped_fape": np.array(0.0)
        }

        value = {
            "traj": random_affines_vector(
                (
                    c_sm.num_layer,
                    n_res,
                )
            ),
        }

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        out_gt = f.apply({}, None, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        # OpenFold consumes 4x4 rigid tensors rather than affine vectors.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        if consts.is_multimer:
            intra_chain_mask = (batch["asym_id"][..., None]
                                == batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
            intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
                                            **{**batch, **c_sm.intra_chain_fape})
            interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
                                          **{**batch, **c_sm.interface_fape})
            out_repro = intra_chain_out + interface_out
            out_repro = out_repro.cpu()
        else:
            out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
            out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_sidechain_loss_compare(self):
        """Compare OpenFold's sidechain_loss against the AlphaFold reference.

        Builds a random batch (including atom14 features derived via the
        data transforms), evaluates AlphaFold's sidechain loss through a
        haiku wrapper, then evaluates OpenFold's sidechain_loss on the same
        inputs and asserts agreement to within ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_sidechain_loss(batch, value, atom14_pred_positions):
            if consts.is_multimer:
                # Multimer path: compute renamed ground truth and frames
                # explicitly, then call the functional sidechain_loss.
                atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
                all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
                gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
                    aatype=batch["aatype"], all_atom_positions=all_atom_positions,
                    all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
                pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
                pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
                gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
                    aatype=batch["aatype"],
                    all_atom_positions=all_atom_positions,
                    all_atom_mask=batch["all_atom_mask"],
                    use_alt=alt_naming_is_better)
                return self.am_fold.sidechain_loss(gt_sc_frames,
                                                   gt_sc_frames_mask,
                                                   gt_positions,
                                                   gt_mask,
                                                   pred_frames,
                                                   pred_positions,
                                                   c_sm)['loss']

            # Monomer path: derive ground-truth frames from atom37 data...
            batch = {
                **batch,
                **self.am_atom.atom37_to_frames(
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                ),
            }
            # ...and wrap the predicted frames/positions in rigid objects.
            v = {}
            v["sidechains"] = {}
            v["sidechains"][
                "frames"
            ] = self.am_rigid.rigids_from_tensor4x4(
                value["sidechains"]["frames"]
            )
            v["sidechains"]["atom_pos"] = self.am_rigid.vecs_from_tensor(
                value["sidechains"]["atom_pos"]
            )
            v.update(
                self.am_fold.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v

            ret = self.am_fold.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Round-trip through torch to reuse OpenFold's data transforms
            # for the derived atom14 masks/positions, returning numpy.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
                    np.float32
                ),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    @unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
    def test_tm_loss_compare(self):
        """Compare OpenFold's tm_loss against AlphaFold's PredictedAlignedErrorHead.

        Uses pretrained head weights on both sides: AlphaFold's head runs
        via haiku with fetched parameters, while OpenFold's logits come
        from the globally loaded pretrained model. Agreement is required to
        within ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
            )
            v = {}
            v.update(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        # Fixed seed so both implementations see identical random inputs.
        np.random.seed(42)

        n_res = consts.n_res

        representations = {
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        # OpenFold consumes 4x4 rigid tensors rather than affine vectors.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])

        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
    @compare_utils.skip_unless_alphafold_installed()
    def test_chain_center_of_mass_loss(self):
        """Smoke-test chain_center_of_mass_loss on random multi-chain input."""
        bs = consts.batch_size
        n_res = consts.n_res

        # Random ground-truth atom37 positions/mask plus per-residue chain ids.
        np_batch = {
            "all_atom_positions": np.random.rand(bs, n_res, 37, 3).astype(np.float32) * 10.0,
            "all_atom_mask": np.random.randint(0, 2, (bs, n_res, 37)).astype(np.float32),
            "asym_id": np.stack([random_asym_ids(n_res) for _ in range(bs)])
        }

        loss_cfg = {
            "weight": 0.05,
            "clamp_distance": -4.0,
        }

        pred_pos = torch.rand(bs, n_res, 37, 3).cuda()

        batch = tree_map(lambda a: torch.tensor(a).cuda(), np_batch, np.ndarray)

        out_repro = chain_center_of_mass_loss(
            all_atom_pred_pos=pred_pos,
            **{**batch, **loss_cfg},
        ).cpu()

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1141
1142
1143

# Allow running this test module directly (`python test_loss.py`) in
# addition to discovery through a test runner.
if __name__ == "__main__":
    unittest.main()