test_loss.py 39.5 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
import torch
import numpy as np
Christina Floristean's avatar
Christina Floristean committed
18
from pathlib import Path
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19
import unittest
20
import ml_collections as mlc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
21

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
22
from openfold.data import data_transforms
23
from openfold.np import residue_constants
24
25
26
27
from openfold.utils.rigid_utils import (
    Rotation,
    Rigid,
)
28
29
30
31
32
33
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
34
35
36
37
38
39
40
41
42
43
44
    compute_renamed_ground_truth,
    masked_msa_loss,
    distogram_loss,
    experimentally_resolved_loss,
    violation_loss,
    fape_loss,
    lddt_loss,
    supervised_chi_loss,
    backbone_loss,
    sidechain_loss,
    tm_loss,
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
45
    compute_plddt,
46
47
    compute_tm,
    chain_center_of_mass_loss
48
49
)
from openfold.utils.tensor_utils import (
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
50
51
    tree_map,
    tensor_tree_map,
52
    dict_multimap,
53
54
55
)
import tests.compare_utils as compare_utils
from tests.config import consts
56
from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
57

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
58
if compare_utils.alphafold_is_installed():
    # The comparison tests in this file check OpenFold outputs against the
    # original AlphaFold implementation. AlphaFold (and its JAX/Haiku
    # dependencies) is optional, so only import it when it is available;
    # the tests themselves are gated by skip_unless_alphafold_installed().
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
62
63


64
65
66
67
68
def affine_vector_to_4x4(affine):
    """Convert a 7-component affine tensor (quaternion + translation) into
    the equivalent 4x4 homogeneous transformation tensor."""
    return Rigid.from_tensor_7(affine).to_tensor_4x4()


69
70
71
72
73
74
75
76
77
78
def affine_vector_to_rigid(am_rigid, affine):
    """Build an AlphaFold ``Rigid3Array`` from a 7-component affine vector.

    The last axis of ``affine`` holds (qw, qx, qy, qz, tx, ty, tz). The
    quaternion part is normalized by the AlphaFold geometry module.

    Args:
        am_rigid: AlphaFold rigid-geometry module (``geometry`` or ``r3``)
            providing ``Rot3Array``, ``Vec3Array`` and ``Rigid3Array``.
        affine: numpy array whose trailing dimension is of size 7.
    """
    components = [c.squeeze(-1) for c in np.split(affine, 7, axis=-1)]
    quat, trans = components[:4], components[4:]
    rotation = am_rigid.Rot3Array.from_quaternion(*quat, normalize=True)
    return am_rigid.Rigid3Array(rotation, am_rigid.Vec3Array(*trans))


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
79
class TestLoss(unittest.TestCase):
80
81
    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold reference modules used by the comparison tests.

        Selects the multimer or monomer variants of the all-atom, folding,
        modules and rigid-geometry APIs based on the test configuration.
        No-op when AlphaFold is not installed (comparison tests are skipped
        in that case anyway).
        """
        if not compare_utils.alphafold_is_installed():
            return

        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3
93

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
94
    def test_run_torsion_angle_loss(self):
        """Smoke test: torsion_angle_loss runs on random sin/cos angle tensors."""
        shape = (consts.batch_size, consts.n_res, 7, 2)

        pred_angles = torch.rand(shape)
        gt_angles = torch.rand(shape)
        gt_angles_alt = torch.rand(shape)

        torsion_angle_loss(pred_angles, gt_angles, gt_angles_alt)

    def test_run_fape(self):
        """Smoke test: compute_fape runs on random frames and positions."""
        bs = consts.batch_size
        n_frames = 7
        n_atoms = 5

        def _random_rigid():
            # Random (unnormalized) rotation matrices are fine for a smoke test.
            rot_mats = torch.rand((bs, n_frames, 3, 3))
            trans = torch.rand((bs, n_frames, 3))
            return Rigid(Rotation(rot_mats=rot_mats), trans)

        compute_fape(
            pred_frames=_random_rigid(),
            target_frames=_random_rigid(),
            frames_mask=torch.randint(0, 2, (bs, n_frames)).float(),
            pred_positions=torch.rand((bs, n_atoms, 3)),
            target_positions=torch.rand((bs, n_atoms, 3)),
            positions_mask=torch.randint(0, 2, (bs, n_atoms)).float(),
            length_scale=10,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
130

131
132
133
    def test_run_between_residue_bond_loss(self):
        """Smoke test: between_residue_bond_loss runs on random atom14 inputs."""
        batch_size, n_res = consts.batch_size, consts.n_res

        positions = torch.rand(batch_size, n_res, 14, 3)
        atom_mask = torch.randint(0, 2, (batch_size, n_res, 14))
        res_index = torch.arange(n_res).unsqueeze(0)
        # 22 = 20 standard amino acids + unknown + gap.
        aatype = torch.randint(0, 22, (batch_size, n_res))

        between_residue_bond_loss(positions, atom_mask, res_index, aatype)

153
154
155
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_bond_loss_compare(self):
        """Check OpenFold's between_residue_bond_loss against AlphaFold's
        reference implementation on identical random inputs."""
        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
            # The multimer implementation expects positions as a Vec3Array.
            if consts.is_multimer:
                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)

            return self.am_atom.between_residue_bond_loss(
                pred_pos,
                pred_atom_mask,
                residue_index,
                aatype,
            )

        # Haiku transform: the function is stateless, so empty params / no RNG.
        f = hk.transform(run_brbl)

        n_res = consts.n_res
        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        residue_index = np.arange(n_res)
        aatype = np.random.randint(0, 22, (n_res,))

        # Ground truth from AlphaFold (JAX); force computation and convert
        # every leaf of the output tree to a torch tensor for comparison.
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        out_repro = between_residue_bond_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(pred_atom_mask).cuda(),
            torch.tensor(residue_index).cuda(),
            torch.tensor(aatype).cuda(),
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        # Every output entry must match within the configured tolerance.
        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

198
    def test_run_between_residue_clash_loss(self):
        """Smoke test: between_residue_clash_loss runs on random atom14 inputs."""
        batch_size, n_res = consts.batch_size, consts.n_res

        between_residue_clash_loss(
            torch.rand(batch_size, n_res, 14, 3),
            torch.randint(0, 2, (batch_size, n_res, 14)).float(),
            torch.rand(batch_size, n_res, 14),
            torch.arange(n_res).unsqueeze(0),
        )

214
215
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_clash_loss_compare(self):
        """Check OpenFold's between_residue_clash_loss against AlphaFold's
        reference implementation on identical random inputs."""
        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
            # Multimer variant takes a Vec3Array and an extra asym_id argument.
            if consts.is_multimer:
                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
                return self.am_atom.between_residue_clash_loss(
                    pred_pos,
                    atom_exists,
                    atom_radius,
                    res_ind,
                    asym_id
                )

            return self.am_atom.between_residue_clash_loss(
                pred_pos,
                atom_exists,
                atom_radius,
                res_ind
            )

        f = hk.transform(run_brcl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        res_ind = np.arange(
            n_res,
        )
        # Random atom14 -> atom37 index map, used only to look up radii below.
        residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)

        # Van der Waals radius per atom37 type, keyed by element (first
        # character of the atom name), gathered into atom14 layout and
        # zeroed for non-existent atoms.
        atomtype_radius = [
            residue_constants.van_der_waals_radius[name[0]]
            for name in residue_constants.atom_types
        ]
        atomtype_radius = np.array(atomtype_radius).astype(np.float32)
        atom_radius = (
                atom_exists
                * atomtype_radius[residx_atom14_to_atom37]
        )

        # Chain ids are only needed (and only defined) in the multimer case.
        asym_id = None
        if consts.is_multimer:
            asym_id = random_asym_ids(n_res)

        out_gt = f.apply(
            {},
            None,
            pred_pos,
            atom_exists,
            atom_radius,
            res_ind,
            asym_id
        )
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        out_repro = between_residue_clash_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(atom_exists).cuda(),
            torch.tensor(atom_radius).cuda(),
            torch.tensor(res_ind).cuda(),
            torch.tensor(asym_id).cuda() if asym_id is not None else None,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_plddt_compare(self):
        """Compare OpenFold's compute_plddt with the AlphaFold reference
        on the same random per-residue logits."""
        logits = np.random.rand(consts.n_res, 50)

        reference = torch.tensor(
            alphafold.common.confidence.compute_plddt(logits)
        )
        reproduced = compute_plddt(torch.tensor(logits))

        self.assertTrue(
            torch.max(torch.abs(reference - reproduced)) < consts.eps
        )

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_ptm_compare(self):
        """Compare OpenFold's compute_tm (pTM, and ipTM for multimer) with
        AlphaFold's predicted_tm_score on the same random logits."""
        n_res = consts.n_res
        max_bin = 31
        no_bins = 64

        logits = np.random.rand(n_res, n_res, no_bins)
        boundaries = np.linspace(0, max_bin, num=(no_bins - 1))
        logits_t = torch.tensor(logits)

        ptm_gt = torch.tensor(
            alphafold.common.confidence.predicted_tm_score(logits, boundaries)
        )
        ptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)
        self.assertTrue(torch.max(torch.abs(ptm_gt - ptm_repro)) < consts.eps)

        if not consts.is_multimer:
            return

        # Interface pTM (ipTM) additionally requires chain (asym) ids.
        asym_id = random_asym_ids(n_res)
        iptm_gt = torch.tensor(
            alphafold.common.confidence.predicted_tm_score(
                logits, boundaries, asym_id=asym_id, interface=True
            )
        )
        iptm_repro = compute_tm(
            logits_t,
            no_bins=no_bins,
            max_bin=max_bin,
            asym_id=torch.tensor(asym_id),
            interface=True,
        )
        self.assertTrue(torch.max(torch.abs(iptm_gt - iptm_repro)) < consts.eps)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
330
    def test_find_structural_violations(self):
        """Smoke test: find_structural_violations runs on a random batch."""
        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": torch.randint(0, 2, (n_res, 14)),
            "residue_index": torch.arange(n_res),
            "aatype": torch.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": torch.randint(0, 37, (n_res, 14)).long(),
        }
        predicted_positions = torch.rand(n_res, 14, 3)

        find_structural_violations(
            batch,
            predicted_positions,
            clash_overlap_tolerance=1.5,
            violation_tolerance_factor=12.0,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
348

349
350
351
352
    @compare_utils.skip_unless_alphafold_installed()
    def test_find_structural_violations_compare(self):
        """Check OpenFold's find_structural_violations against AlphaFold's
        reference implementation on identical random inputs."""
        def run_fsv(batch, pos, config):
            # AlphaFold resolves its stereo-chemical data files relative to
            # the working directory, so temporarily chdir into test_data.
            cwd = os.getcwd()
            fpath = Path(__file__).parent.resolve() / "test_data"
            os.chdir(str(fpath))
            try:
                if consts.is_multimer:
                    atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
                    return self.am_fold.find_structural_violations(
                        batch['aatype'],
                        batch['residue_index'],
                        batch['atom14_atom_exists'],
                        atom14_pred_pos,
                        config,
                        batch['asym_id']
                    )

                return self.am_fold.find_structural_violations(
                    batch,
                    pos,
                    config,
                )
            finally:
                # BUG FIX: the original restored the working directory only
                # on the monomer path; the multimer branch returned early and
                # left the process chdir'd into test_data (likewise on any
                # exception). Restore unconditionally.
                os.chdir(cwd)

        f = hk.transform(run_fsv)

        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": np.random.randint(
                0, 37, (n_res, 14)
            ).astype(np.int64),
        }

        # Chain ids are only needed in the multimer case.
        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        pred_pos = np.random.rand(n_res, 14, 3)

        config = mlc.ConfigDict(
            {
                "clash_overlap_tolerance": 1.5,
                "violation_tolerance_factor": 12.0,
            }
        )

        # Ground truth from AlphaFold (JAX); convert all leaves to torch.
        out_gt = f.apply({}, None, batch, pred_pos, config)
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        out_repro = find_structural_violations(
            batch,
            torch.tensor(pred_pos).cuda(),
            **config,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        def compare(out):
            gt, repro = out
            assert torch.max(torch.abs(gt - repro)) < consts.eps

        # Walk both output trees in lockstep and compare each leaf.
        dict_multimap(compare, [out_gt, out_repro])

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_renamed_ground_truth_compare(self):
        """Check OpenFold's compute_renamed_ground_truth against AlphaFold's
        (monomer) folding implementation on identical random inputs."""
        def run_crgt(batch, atom14_pred_pos):
            return alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_pos,
            )

        f = hk.transform(run_crgt)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Derive the extra atom14 features (masks, alt positions) that
            # compute_renamed_ground_truth expects, via OpenFold's torch
            # data transforms, then convert everything back to numpy.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        atom14_pred_pos = np.random.rand(n_res, 14, 3)

        # Ground truth from AlphaFold (JAX), converted to torch tensors.
        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k in out_repro:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_msa_loss_compare(self):
        """Check OpenFold's masked_msa_loss against the loss of AlphaFold's
        MaskedMsaHead on identical random logits and targets."""
        def run_msa_loss(value, batch):
            config = compare_utils.get_alphafold_config()
            msa_head = alphafold.model.modules.MaskedMsaHead(
                config.model.heads.masked_msa, config.model.global_config
            )
            return msa_head.loss(value, batch)

        f = hk.transform(run_msa_loss)

        n_res = consts.n_res
        n_seq = consts.n_seq

        value = {
            "logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
        }

        batch = {
            "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
            "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
                np.float32
            )
        }

        # Ground truth loss from AlphaFold (JAX).
        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = masked_msa_loss(
                value["logits"],
                batch["true_msa"],
                batch["bert_mask"],
                consts.msa_logits
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_distogram_loss_compare(self):
        """Check OpenFold's distogram_loss against the loss of AlphaFold's
        DistogramHead on identical random logits and pseudo-beta targets."""
        config = compare_utils.get_alphafold_config()
        c_distogram = config.model.heads.distogram

        def run_distogram_loss(value, batch):
            dist_head = alphafold.model.modules.DistogramHead(
                c_distogram, config.model.global_config
            )
            return dist_head.loss(value, batch)

        f = hk.transform(run_distogram_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(
                np.float32
            ),
            # Bin edges matching the head's configured distance range.
            "bin_edges": np.linspace(
                c_distogram.first_break,
                c_distogram.last_break,
                c_distogram.num_bins,
            ),
        }

        batch = {
            "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
        }

        # Ground truth loss from AlphaFold (JAX).
        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = distogram_loss(
                logits=value["logits"],
                min_bin=c_distogram.first_break,
                max_bin=c_distogram.last_break,
                no_bins=c_distogram.num_bins,
                **batch,
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_experimentally_resolved_loss_compare(self):
        """Check OpenFold's experimentally_resolved_loss against the loss of
        AlphaFold's ExperimentallyResolvedHead on identical random inputs."""
        config = compare_utils.get_alphafold_config()
        c_experimentally_resolved = config.model.heads.experimentally_resolved

        def run_experimentally_resolved_loss(value, batch):
            er_head = alphafold.model.modules.ExperimentallyResolvedHead(
                c_experimentally_resolved, config.model.global_config
            )
            return er_head.loss(value, batch)

        f = hk.transform(run_experimentally_resolved_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, 37).astype(np.float32),
        }

        batch = {
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
            "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
            # Resolution inside the head's configured [min, max] window so
            # the loss is not zeroed out by the resolution filter.
            "resolution": np.array(1.0),
        }

        # Ground truth loss from AlphaFold (JAX).
        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        # OpenFold reproduction (runs on GPU; test requires CUDA).
        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = experimentally_resolved_loss(
                logits=value["logits"],
                min_resolution=c_experimentally_resolved.min_resolution,
                max_resolution=c_experimentally_resolved.max_resolution,
                **batch,
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_supervised_chi_loss_compare(self):
        """Check OpenFold's supervised_chi_loss against AlphaFold's chi loss
        (which has different call signatures for monomer and multimer)."""
        config = compare_utils.get_alphafold_config()
        c_chi_loss = config.model.heads.structure_module

        def run_supervised_chi_loss(value, batch):
            if consts.is_multimer:
                # Multimer API: positional tensors, returns the loss directly.
                pred_angles = np.reshape(
                    value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])

                unnormed_angles = np.reshape(
                    value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])

                chi_loss, _, _ = self.am_fold.supervised_chi_loss(
                    batch['seq_mask'],
                    batch['chi_mask'],
                    batch['aatype'],
                    batch['chi_angles'],
                    pred_angles,
                    unnormed_angles,
                    c_chi_loss
                )
                return chi_loss

            # Monomer API: accumulates into the mutable ret["loss"] entry.
            ret = {
                "loss": jax.numpy.array(0.0),
            }
            self.am_fold.supervised_chi_loss(
                ret, batch, value, c_chi_loss
            )
            return ret["loss"]

        f = hk.transform(run_supervised_chi_loss)

        n_res = consts.n_res

        value = {
            "sidechains": {
                # Leading dim of 8 matches the number of structure-module
                # iterations in the predicted angle tensors.
                "angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
                    np.float32
                ),
                "unnormalized_angles_sin_cos": np.random.rand(
                    8, n_res, 7, 2
                ).astype(np.float32),
            }
        }

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "seq_mask": np.random.randint(0, 2, (n_res,)),
            "chi_mask": np.random.randint(0, 2, (n_res, 4)),
            "chi_angles": np.random.rand(n_res, 4).astype(np.float32),
        }

        # Ground truth loss from AlphaFold (JAX).
        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        # OpenFold takes chi angles as sin/cos pairs rather than radians.
        batch["chi_angles_sin_cos"] = torch.stack(
            [
                torch.sin(batch["chi_angles"]),
                torch.cos(batch["chi_angles"]),
            ],
            dim=-1,
        )

        with torch.no_grad():
            out_repro = supervised_chi_loss(
                chi_weight=c_chi_loss.chi_weight,
                angle_norm_weight=c_chi_loss.angle_norm_weight,
                **{**batch, **value["sidechains"]},
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss(self):
        """Smoke test: violation_loss runs with both clash-accounting modes
        (summed and averaged clashes)."""
        config = compare_utils.get_alphafold_config()
        c_viol = config.model.heads.structure_module
        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }
        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        batch = tree_map(lambda a: torch.tensor(a).cuda(), batch, np.ndarray)
        batch = data_transforms.make_atom14_masks(batch)

        atom14_pred_pos = torch.tensor(
            np.random.rand(n_res, 14, 3).astype(np.float32)
        ).cuda()

        for average_clashes in (False, True):
            violations = find_structural_violations(
                batch, atom14_pred_pos, **c_viol
            )
            loss = violation_loss(
                violations, average_clashes=average_clashes, **batch
            )
            loss.cpu()

719
720
721
722
    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss_compare(self):
        """Check that OpenFold's violation_loss reproduces the structural
        violation loss of the reference AlphaFold implementation on
        identical random inputs."""
        config = compare_utils.get_alphafold_config()
        c_viol = config.model.heads.structure_module

        def run_viol_loss(batch, atom14_pred_pos):
            # Reference (JAX/Haiku) computation of the violation loss.
            ret = {
                "loss": np.array(0.0).astype(np.float32),
            }

            if consts.is_multimer:
                # The multimer code path takes unpacked arguments and a
                # Vec3Array rather than a raw position tensor.
                atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
                viol = self.am_fold.find_structural_violations(
                    batch['aatype'],
                    batch['residue_index'],
                    batch['atom14_atom_exists'],
                    atom14_pred_pos,
                    c_viol,
                    batch['asym_id']
                )
                return self.am_fold.structural_violation_loss(
                    mask=batch['atom14_atom_exists'],
                    violations=viol,
                    config=c_viol,
                )

            # Monomer path: the reference implementation accumulates into ret.
            value = {
                "violations": self.am_fold.find_structural_violations(
                    batch,
                    atom14_pred_pos,
                    c_viol,
                )
            }
            self.am_fold.structural_violation_loss(ret, batch, value, c_viol)
            return ret["loss"]

        loss_fn = hk.transform(run_viol_loss)

        n_res = consts.n_res

        # Random inputs shared verbatim by both implementations.
        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,))
        }

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
        batch = {k: np.array(v) for k, v in batch.items()}

        out_gt = loss_fn.apply({}, None, batch, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        # Repeat the computation with the OpenFold (PyTorch) code on GPU.
        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        batch = data_transforms.make_atom14_masks(batch)

        out_repro = violation_loss(
            find_structural_violations(batch, atom14_pred_pos, **c_viol),
            **batch,
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_lddt_loss_compare(self):
        """Check that OpenFold's lddt_loss matches the loss computed by
        AlphaFold's PredictedLDDTHead on identical random inputs."""
        config = compare_utils.get_alphafold_config()
        c_plddt = config.model.heads.predicted_lddt

        def run_plddt_loss(value, batch):
            head = alphafold.model.modules.PredictedLDDTHead(
                c_plddt, config.model.global_config
            )
            return head.loss(value, batch)

        loss_fn = hk.transform(run_plddt_loss)

        n_res = consts.n_res

        # Helper: random float32 array of the given shape.
        rand_f32 = lambda *shape: np.random.rand(*shape).astype(np.float32)

        value = {
            "predicted_lddt": {
                "logits": rand_f32(n_res, c_plddt.num_bins),
            },
            "structure_module": {
                "final_atom_positions": rand_f32(n_res, 37, 3),
            },
        }

        batch = {
            "all_atom_positions": rand_f32(n_res, 37, 3),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        out_gt = loss_fn.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt["loss"]))

        # Move the same inputs to GPU tensors for the OpenFold version.
        to_cuda = lambda t: torch.tensor(t).cuda()
        value = tree_map(to_cuda, value, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)

        out_repro = lddt_loss(
            logits=value["predicted_lddt"]["logits"],
            all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
            **{**batch, **c_plddt},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_backbone_loss_compare(self):
        """Check that OpenFold's backbone_loss matches the reference
        AlphaFold backbone FAPE loss on identical random inputs."""
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_bb_loss(batch, value):
            if consts.is_multimer:
                # Multimer splits the loss into intra-chain and interface
                # FAPE terms, selected via complementary pair masks.
                intra_chain_mask = (
                    batch["asym_id"][..., None] == batch["asym_id"][..., None, :]
                ).astype(np.float32)
                gt_rigid = affine_vector_to_rigid(
                    self.am_rigid, batch["backbone_affine_tensor"]
                )
                target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
                intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
                    gt_rigid=gt_rigid,
                    gt_frames_mask=batch["backbone_affine_mask"],
                    gt_positions_mask=batch["backbone_affine_mask"],
                    target_rigid=target_rigid,
                    config=c_sm.intra_chain_fape,
                    pair_mask=intra_chain_mask)
                interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
                    gt_rigid=gt_rigid,
                    gt_frames_mask=batch["backbone_affine_mask"],
                    gt_positions_mask=batch["backbone_affine_mask"],
                    target_rigid=target_rigid,
                    config=c_sm.interface_fape,
                    pair_mask=1. - intra_chain_mask)

                return intra_chain_bb_loss + interface_bb_loss

            # Monomer path: the reference implementation accumulates into ret.
            ret = {
                "loss": np.array(0.0),
            }
            self.am_fold.backbone_loss(ret, batch, value, c_sm)
            return ret["loss"]

        loss_fn = hk.transform(run_bb_loss)

        n_res = consts.n_res

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "use_clamped_fape": np.array(0.0)
        }

        value = {
            "traj": random_affines_vector(
                (
                    c_sm.num_layer,
                    n_res,
                )
            ),
        }

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)

        out_gt = loss_fn.apply({}, None, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_cuda = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_cuda, batch, np.ndarray)
        value = tree_map(to_cuda, value, np.ndarray)

        # OpenFold consumes 4x4 rigid tensors rather than affine vectors.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        if consts.is_multimer:
            intra_chain_mask = (batch["asym_id"][..., None]
                                == batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
            intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
                                            **{**batch, **c_sm.intra_chain_fape})
            interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
                                          **{**batch, **c_sm.interface_fape})
            out_repro = (intra_chain_out + interface_out).cpu()
        else:
            out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm}).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_sidechain_loss_compare(self):
        """Check that OpenFold's sidechain_loss matches the reference
        AlphaFold sidechain FAPE loss on identical random inputs."""
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_sidechain_loss(batch, value, atom14_pred_positions):
            if consts.is_multimer:
                # Multimer path: build Vec3Array/Rigid3Array inputs and
                # resolve the alternative (symmetric) ground-truth naming
                # before calling the loss directly.
                atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
                all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
                gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
                    aatype=batch["aatype"], all_atom_positions=all_atom_positions,
                    all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
                pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
                pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
                gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
                    aatype=batch["aatype"],
                    all_atom_positions=all_atom_positions,
                    all_atom_mask=batch["all_atom_mask"],
                    use_alt=alt_naming_is_better)
                return self.am_fold.sidechain_loss(gt_sc_frames,
                                                   gt_sc_frames_mask,
                                                   gt_positions,
                                                   gt_mask,
                                                   pred_frames,
                                                   pred_positions,
                                                   c_sm)['loss']

            # Monomer path: derive ground-truth frames and renamed ground
            # truth, then evaluate the reference loss.
            batch = {
                **batch,
                **self.am_atom.atom37_to_frames(
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                ),
            }

            v = {
                "sidechains": {
                    "frames": self.am_rigid.rigids_from_tensor4x4(
                        value["sidechains"]["frames"]
                    ),
                    "atom_pos": self.am_rigid.vecs_from_tensor(
                        value["sidechains"]["atom_pos"]
                    ),
                },
            }
            v.update(
                self.am_fold.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v

            ret = self.am_fold.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        loss_fn = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Round-trip through torch to reuse the OpenFold data transforms,
            # then convert everything back to numpy for the JAX side.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
                    np.float32
                ),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = loss_fn.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        # Move identical inputs to GPU tensors for the OpenFold version.
        to_cuda = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_cuda, batch, np.ndarray)
        value = tree_map(to_cuda, value, np.ndarray)
        atom14_pred_pos = to_cuda(atom14_pred_pos)

        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    @unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
    def test_tm_loss_compare(self):
        """Check that OpenFold's tm_loss matches the loss of AlphaFold's
        PredictedAlignedErrorHead, using the pretrained head weights on
        identical random inputs."""
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
            )
            # Shallow-copy value so the head's output can be attached.
            v = dict(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        loss_fn = hk.transform(run_tm_loss)

        np.random.seed(42)

        n_res = consts.n_res

        representations = {
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = loss_fn.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_cuda = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_cuda, representations, np.ndarray)
        batch = tree_map(to_cuda, batch, np.ndarray)
        value = tree_map(to_cuda, value, np.ndarray)

        # OpenFold consumes 4x4 rigid tensors rather than affine vectors.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        # Produce logits with the pretrained OpenFold TM head.
        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])

        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
    @compare_utils.skip_unless_alphafold_installed()
    def test_chain_center_of_mass_loss(self):
        """Smoke-test chain_center_of_mass_loss on random multimer input.

        The previous version of this test computed the loss but asserted
        nothing; it now checks that the result is finite.
        """
        batch_size = consts.batch_size
        n_res = consts.n_res

        batch = {
            "all_atom_positions": np.random.rand(batch_size, n_res, 37, 3).astype(np.float32) * 10.0,
            "all_atom_mask": np.random.randint(0, 2, (batch_size, n_res, 37)).astype(np.float32),
            "asym_id": np.stack([random_asym_ids(n_res) for _ in range(batch_size)])
        }

        config = {
            "weight": 0.05,
            "clamp_distance": -4.0,
        }

        final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = chain_center_of_mass_loss(
            all_atom_pred_pos=final_atom_positions,
            **{**batch, **config},
        )
        out_repro = out_repro.cpu()

        # A well-formed random input must never produce NaN/Inf.
        self.assertTrue(torch.isfinite(out_repro).all())

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1142
1143
1144

# Allow running this module directly (e.g. `python test_loss.py`) in addition
# to discovery by a unittest/pytest runner.
if __name__ == "__main__":
    unittest.main()