# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import numpy as np
import unittest
import ml_collections as mlc

from openfold.data import data_transforms
from openfold.utils.rigid_utils import (
    Rotation,
    Rigid,
)
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
    compute_renamed_ground_truth,
    masked_msa_loss,
    distogram_loss,
    experimentally_resolved_loss,
    violation_loss,
    fape_loss,
    lddt_loss,
    supervised_chi_loss,
    backbone_loss,
    sidechain_loss,
    tm_loss,
    compute_plddt,
)
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    dict_multimap,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids

# AlphaFold (and its jax / haiku stack) is optional: it is only needed by the
# ``*_compare`` tests, which are skipped when it is not installed.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


def affine_vector_to_4x4(affine):
    """Return the 4x4 homogeneous-matrix form of a 7-component affine tensor.

    ``affine`` must use the layout accepted by ``Rigid.from_tensor_7``.
    """
    return Rigid.from_tensor_7(affine).to_tensor_4x4()


def affine_vector_to_rigid(am_rigid, affine):
    """Build an AlphaFold ``Rigid3Array`` from a packed 7-vector.

    The last axis of ``affine`` is split into a quaternion
    ``(qw, qx, qy, qz)`` followed by a translation ``(tx, ty, tz)``.
    ``am_rigid`` is the AlphaFold geometry module providing ``Rot3Array``,
    ``Vec3Array`` and ``Rigid3Array``. The quaternion is normalized on
    construction.
    """
    components = [c.squeeze(-1) for c in np.split(affine, 7, axis=-1)]
    quaternion, translation = components[:4], components[4:]
    rotation = am_rigid.Rot3Array.from_quaternion(*quaternion, normalize=True)
    return am_rigid.Rigid3Array(rotation, am_rigid.Vec3Array(*translation))


class TestLoss(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold submodules that match the configured mode.

        The handles are used only by the ``*_compare`` tests, which are
        decorated with ``skip_unless_alphafold_installed``. ``alphafold`` is
        bound at module level only when AlphaFold is installed. Without the
        early return below, this method would raise ``NameError`` and error
        out the AlphaFold-independent tests as well.
        """
        if not compare_utils.alphafold_is_installed():
            return

        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    def test_run_torsion_angle_loss(self):
        """Smoke test: torsion_angle_loss runs on random inputs."""
        batch_size = consts.batch_size
        n_res = consts.n_res

        # Predicted, ground-truth and alternative ground-truth angles, each
        # shaped (batch, n_res, 7, 2).
        a = torch.rand((batch_size, n_res, 7, 2))
        a_gt = torch.rand((batch_size, n_res, 7, 2))
        a_alt_gt = torch.rand((batch_size, n_res, 7, 2))

        torsion_angle_loss(a, a_gt, a_alt_gt)

    def test_run_fape(self):
        """Smoke test: compute_fape runs on random frames and positions."""
        batch_size = consts.batch_size
        n_frames = 7
        n_atoms = 5

        x = torch.rand((batch_size, n_atoms, 3))
        x_gt = torch.rand((batch_size, n_atoms, 3))
        rots = torch.rand((batch_size, n_frames, 3, 3))
        rots_gt = torch.rand((batch_size, n_frames, 3, 3))
        trans = torch.rand((batch_size, n_frames, 3))
        trans_gt = torch.rand((batch_size, n_frames, 3))

        # The rotation matrices are arbitrary uniform noise, not orthonormal.
        # This test only checks that the call executes.
        t = Rigid(Rotation(rot_mats=rots), trans)
        t_gt = Rigid(Rotation(rot_mats=rots_gt), trans_gt)

        frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
        positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
        length_scale = 10

        compute_fape(
            pred_frames=t,
            target_frames=t_gt,
            frames_mask=frames_mask,
            pred_positions=x,
            target_positions=x_gt,
            positions_mask=positions_mask,
            length_scale=length_scale,
        )

    def test_run_between_residue_bond_loss(self):
        """Smoke test: between_residue_bond_loss runs on random inputs."""
        bs = consts.batch_size
        n = consts.n_res

        pred_pos = torch.rand(bs, n, 14, 3)
        # Float mask, matching test_run_between_residue_clash_loss.
        pred_atom_mask = torch.randint(0, 2, (bs, n, 14)).float()
        # Shape (1, n): the same residue numbering is shared across the batch.
        residue_index = torch.arange(n).unsqueeze(0)
        aatype = torch.randint(0, 22, (bs, n))

        between_residue_bond_loss(
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_bond_loss_compare(self):
        """OpenFold's between_residue_bond_loss must match AlphaFold's."""

        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
            # In multimer mode the AlphaFold call is given a Vec3Array
            # rather than a raw array.
            if consts.is_multimer:
                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)

            return self.am_atom.between_residue_bond_loss(
                pred_pos,
                pred_atom_mask,
                residue_index,
                aatype,
            )

        f = hk.transform(run_brbl)

        n_res = consts.n_res
        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        residue_index = np.arange(n_res)
        aatype = np.random.randint(0, 22, (n_res,))

        # No parameters and no RNG are needed: apply({}, None, ...).
        out_gt = f.apply(
            {},
            None,
            pred_pos,
            pred_atom_mask,
            residue_index,
            aatype,
        )
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        out_gt = jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_util.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        out_repro = between_residue_bond_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(pred_atom_mask).cuda(),
            torch.tensor(residue_index).cuda(),
            torch.tensor(aatype).cuda(),
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    def test_run_between_residue_clash_loss(self):
        """Smoke test: between_residue_clash_loss runs on random inputs."""
        bs = consts.batch_size
        n = consts.n_res

        pred_pos = torch.rand(bs, n, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (bs, n, 14)).float()
        atom14_atom_radius = torch.rand(bs, n, 14)
        # Shape (1, n): the same residue numbering is shared across the batch.
        residue_index = torch.arange(n).unsqueeze(0)

        between_residue_clash_loss(
            pred_pos,
            pred_atom_mask,
            atom14_atom_radius,
            residue_index,
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_clash_loss_compare(self):
        """OpenFold's between_residue_clash_loss must match AlphaFold's."""

        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
            if consts.is_multimer:
                # The multimer implementation takes a Vec3Array plus chain ids.
                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
                return self.am_atom.between_residue_clash_loss(
                    pred_pos,
                    atom_exists,
                    atom_radius,
                    res_ind,
                    asym_id,
                )

            return self.am_atom.between_residue_clash_loss(
                pred_pos,
                atom_exists,
                atom_radius,
                res_ind,
            )

        f = hk.transform(run_brcl)

        n_res = consts.n_res

        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        atom_radius = np.random.rand(n_res, 14).astype(np.float32)
        res_ind = np.arange(n_res)
        asym_id = random_asym_ids(n_res)

        out_gt = f.apply(
            {},
            None,
            pred_pos,
            atom_exists,
            atom_radius,
            res_ind,
            asym_id,
        )
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        out_gt = jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_util.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        # In multimer mode the AlphaFold reference receives the chain ids, so
        # the OpenFold call must receive the same asym_id. Otherwise the two
        # computations may differ. It is passed by keyword and only in
        # multimer mode, so the monomer call is unchanged.
        # NOTE(review): assumes openfold's between_residue_clash_loss accepts
        # an ``asym_id`` keyword. Confirm against openfold/utils/loss.py.
        extra_kwargs = {}
        if consts.is_multimer:
            extra_kwargs["asym_id"] = torch.tensor(asym_id).cuda()

        out_repro = between_residue_clash_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(atom_exists).cuda(),
            torch.tensor(atom_radius).cuda(),
            torch.tensor(res_ind).cuda(),
            **extra_kwargs,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_plddt_compare(self):
        """compute_plddt must agree with AlphaFold's confidence.compute_plddt."""
        logits = np.random.rand(consts.n_res, 50)

        expected = torch.tensor(alphafold.common.confidence.compute_plddt(logits))
        actual = compute_plddt(torch.tensor(logits))

        max_abs_diff = torch.max(torch.abs(expected - actual))
        self.assertTrue(max_abs_diff < consts.eps)

    def test_find_structural_violations(self):
        """Smoke test: find_structural_violations runs on a random batch."""
        n = consts.n_res

        batch = {
            "atom14_atom_exists": torch.randint(0, 2, (n, 14)),
            "residue_index": torch.arange(n),
            "aatype": torch.randint(0, 20, (n,)),
            "residx_atom14_to_atom37": torch.randint(0, 37, (n, 14)).long(),
        }

        pred_pos = torch.rand(n, 14, 3)

        config = {
            "clash_overlap_tolerance": 1.5,
            "violation_tolerance_factor": 12.0,
        }

        find_structural_violations(batch, pred_pos, **config)

    @compare_utils.skip_unless_alphafold_installed()
    def test_find_structural_violations_compare(self):
        """OpenFold's find_structural_violations must match AlphaFold's."""

        def run_fsv(batch, pos, config):
            # The AlphaFold call is run from tests/test_data; presumably it
            # resolves data files relative to the CWD (TODO confirm). The
            # original CWD is restored in ``finally`` so that neither the
            # multimer early return nor an exception leaks the directory
            # change into later tests.
            cwd = os.getcwd()
            os.chdir("tests/test_data")
            try:
                if consts.is_multimer:
                    atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
                    return self.am_fold.find_structural_violations(
                        batch["aatype"],
                        batch["residue_index"],
                        batch["atom14_atom_exists"],
                        atom14_pred_pos,
                        config,
                        batch["asym_id"],
                    )

                return self.am_fold.find_structural_violations(
                    batch,
                    pos,
                    config,
                )
            finally:
                os.chdir(cwd)

        f = hk.transform(run_fsv)

        n_res = consts.n_res

        batch = {
            "atom14_atom_exists": np.random.randint(0, 2, (n_res, 14)),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "residx_atom14_to_atom37": np.random.randint(
                0, 37, (n_res, 14)
            ).astype(np.int64),
            "asym_id": random_asym_ids(n_res),
        }

        pred_pos = np.random.rand(n_res, 14, 3)

        config = mlc.ConfigDict(
            {
                "clash_overlap_tolerance": 1.5,
                "violation_tolerance_factor": 12.0,
            }
        )

        out_gt = f.apply({}, None, batch, pred_pos, config)
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        out_gt = jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_util.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        out_repro = find_structural_violations(
            batch,
            torch.tensor(pred_pos).cuda(),
            **config,
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)

        def compare(out):
            # ``out`` is the [ground truth, reproduction] pair for one leaf.
            gt, repro = out
            assert torch.max(torch.abs(gt - repro)) < consts.eps

        dict_multimap(compare, [out_gt, out_repro])

    @compare_utils.skip_unless_alphafold_installed()
    def test_compute_renamed_ground_truth_compare(self):
        """OpenFold's compute_renamed_ground_truth must match AlphaFold's.

        AlphaFold's monomer ``folding`` module is called directly here, in
        both monomer and multimer mode.
        """

        def run_crgt(batch, atom14_pred_pos):
            return alphafold.model.folding.compute_renamed_ground_truth(
                batch,
                atom14_pred_pos,
            )

        f = hk.transform(run_crgt)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Derive the atom14 mask/position features with OpenFold's
            # transforms, then convert everything back to numpy so that both
            # implementations see identical inputs.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        atom14_pred_pos = np.random.rand(n_res, 14, 3)

        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        out_gt = jax.tree_util.tree_map(lambda x: torch.tensor(np.array(x)), out_gt)

        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        out_repro = compute_renamed_ground_truth(batch, atom14_pred_pos)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k in out_repro:
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    @compare_utils.skip_unless_alphafold_installed()
    def test_msa_loss_compare(self):
        """masked_msa_loss must match AlphaFold's MaskedMsaHead.loss."""

        def run_msa_loss(value, batch):
            config = compare_utils.get_alphafold_config()
            msa_head = alphafold.model.modules.MaskedMsaHead(
                config.model.heads.masked_msa, config.model.global_config
            )
            return msa_head.loss(value, batch)

        f = hk.transform(run_msa_loss)

        n_res = consts.n_res
        n_seq = consts.n_seq

        value = {
            "logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
        }

        batch = {
            "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
            "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = masked_msa_loss(
                value["logits"],
                **batch,
            )
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_distogram_loss_compare(self):
        """distogram_loss must match AlphaFold's DistogramHead.loss."""
        config = compare_utils.get_alphafold_config()
        c_distogram = config.model.heads.distogram

        def run_distogram_loss(value, batch):
            dist_head = alphafold.model.modules.DistogramHead(
                c_distogram, config.model.global_config
            )
            return dist_head.loss(value, batch)

        f = hk.transform(run_distogram_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, n_res, c_distogram.num_bins).astype(
                np.float32
            ),
            "bin_edges": np.linspace(
                c_distogram.first_break,
                c_distogram.last_break,
                c_distogram.num_bins,
            ),
        }

        batch = {
            "pseudo_beta": np.random.rand(n_res, 3).astype(np.float32),
            "pseudo_beta_mask": np.random.randint(0, 2, (n_res,)),
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        # OpenFold takes the binning parameters explicitly, rather than the
        # precomputed bin_edges that the AlphaFold head receives in ``value``.
        with torch.no_grad():
            out_repro = distogram_loss(
                logits=value["logits"],
                min_bin=c_distogram.first_break,
                max_bin=c_distogram.last_break,
                no_bins=c_distogram.num_bins,
                **batch,
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_experimentally_resolved_loss_compare(self):
        """experimentally_resolved_loss must match AlphaFold's head loss."""
        config = compare_utils.get_alphafold_config()
        c_experimentally_resolved = config.model.heads.experimentally_resolved

        def run_experimentally_resolved_loss(value, batch):
            er_head = alphafold.model.modules.ExperimentallyResolvedHead(
                c_experimentally_resolved, config.model.global_config
            )
            return er_head.loss(value, batch)

        f = hk.transform(run_experimentally_resolved_loss)

        n_res = consts.n_res

        value = {
            "logits": np.random.rand(n_res, 37).astype(np.float32),
        }

        batch = {
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)),
            "atom37_atom_exists": np.random.randint(0, 2, (n_res, 37)),
            "resolution": np.array(1.0),
        }

        out_gt = f.apply({}, None, value, batch)["loss"]
        out_gt = torch.tensor(np.array(out_gt))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        with torch.no_grad():
            out_repro = experimentally_resolved_loss(
                logits=value["logits"],
                min_resolution=c_experimentally_resolved.min_resolution,
                max_resolution=c_experimentally_resolved.max_resolution,
                **batch,
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_supervised_chi_loss_compare(self):
        """supervised_chi_loss must match AlphaFold's supervised_chi_loss."""
        config = compare_utils.get_alphafold_config()
        c_chi_loss = config.model.heads.structure_module

        def run_supervised_chi_loss(value, batch):
            if consts.is_multimer:
                # The multimer API takes the (possibly stacked) angle tensors
                # directly, reshaped to (-1, n_res, 7, 2), and returns a tuple
                # whose first element is the chi loss.
                pred_angles = np.reshape(
                    value["sidechains"]["angles_sin_cos"],
                    [-1, consts.n_res, 7, 2],
                )
                unnormed_angles = np.reshape(
                    value["sidechains"]["unnormalized_angles_sin_cos"],
                    [-1, consts.n_res, 7, 2],
                )

                chi_loss, _, _ = self.am_fold.supervised_chi_loss(
                    batch["seq_mask"],
                    batch["chi_mask"],
                    batch["aatype"],
                    batch["chi_angles"],
                    pred_angles,
                    unnormed_angles,
                    c_chi_loss,
                )
                return chi_loss

            # The monomer API accumulates into ret["loss"] in place.
            ret = {
                "loss": jax.numpy.array(0.0),
            }
            self.am_fold.supervised_chi_loss(
                ret, batch, value, c_chi_loss
            )
            return ret["loss"]

        f = hk.transform(run_supervised_chi_loss)

        n_res = consts.n_res

        value = {
            "sidechains": {
                "angles_sin_cos": np.random.rand(8, n_res, 7, 2).astype(
                    np.float32
                ),
                "unnormalized_angles_sin_cos": np.random.rand(
                    8, n_res, 7, 2
                ).astype(np.float32),
            }
        }

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "seq_mask": np.random.randint(0, 2, (n_res,)),
            "chi_mask": np.random.randint(0, 2, (n_res, 4)),
            "chi_angles": np.random.rand(n_res, 4).astype(np.float32),
        }

        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        value = tree_map(lambda x: torch.tensor(x).cuda(), value, np.ndarray)
        batch = tree_map(lambda x: torch.tensor(x).cuda(), batch, np.ndarray)

        # OpenFold expects the ground-truth chi angles as stacked sin/cos.
        batch["chi_angles_sin_cos"] = torch.stack(
            [
                torch.sin(batch["chi_angles"]),
                torch.cos(batch["chi_angles"]),
            ],
            dim=-1,
        )

        with torch.no_grad():
            out_repro = supervised_chi_loss(
                chi_weight=c_chi_loss.chi_weight,
                angle_norm_weight=c_chi_loss.angle_norm_weight,
                **{**batch, **value["sidechains"]},
            )

        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_violation_loss_compare(self):
        """violation_loss must match AlphaFold's structural_violation_loss."""
        config = compare_utils.get_alphafold_config()
        c_viol = config.model.heads.structure_module

        def run_viol_loss(batch, atom14_pred_pos):
            if consts.is_multimer:
                atom14_pred_pos = self.am_rigid.Vec3Array.from_array(
                    atom14_pred_pos
                )
                viol = self.am_fold.find_structural_violations(
                    batch["aatype"],
                    batch["residue_index"],
                    batch["atom14_atom_exists"],
                    atom14_pred_pos,
                    c_viol,
                    batch["asym_id"],
                )
                return self.am_fold.structural_violation_loss(
                    mask=batch["atom14_atom_exists"],
                    violations=viol,
                    config=c_viol,
                )

            # The monomer API accumulates into ret["loss"] in place.
            ret = {
                "loss": np.array(0.0).astype(np.float32),
            }
            value = {}
            value["violations"] = self.am_fold.find_structural_violations(
                batch,
                atom14_pred_pos,
                c_viol,
            )

            self.am_fold.structural_violation_loss(
                ret,
                batch,
                value,
                c_viol,
            )
            return ret["loss"]

        f = hk.transform(run_viol_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "residue_index": np.arange(n_res),
            "aatype": np.random.randint(0, 21, (n_res,)),
            "asym_id": random_asym_ids(n_res),
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        # AlphaFold's TF transform adds the atom14 mask features to ``batch``
        # in place; the results are then converted to numpy arrays.
        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
        batch = {k: np.array(v) for k, v in batch.items()}

        out_gt = f.apply({}, None, batch, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

        batch = data_transforms.make_atom14_masks(batch)

        out_repro = violation_loss(
            find_structural_violations(batch, atom14_pred_pos, **c_viol),
            **batch,
        )
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_lddt_loss_compare(self):
        """Compare OpenFold's ``lddt_loss`` with AlphaFold's pLDDT head loss.

        Random pLDDT logits, predicted atom37 positions and ground-truth
        atom37 positions/masks are fed to AlphaFold's
        ``PredictedLDDTHead.loss`` (reference) and to OpenFold's
        ``lddt_loss`` (reproduction, run on CUDA). The test asserts that the
        maximum absolute difference is below ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_plddt = config.model.heads.predicted_lddt

        def run_plddt_loss(value, batch):
            head = alphafold.model.modules.PredictedLDDTHead(
                c_plddt, config.model.global_config
            )
            return head.loss(value, batch)

        f = hk.transform(run_plddt_loss)

        n_res = consts.n_res

        value = {
            "predicted_lddt": {
                "logits": np.random.rand(n_res, c_plddt.num_bins).astype(
                    np.float32
                ),
            },
            "structure_module": {
                "final_atom_positions": np.random.rand(n_res, 37, 3).astype(
                    np.float32
                ),
            },
        }

        batch = {
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        # Reference value from AlphaFold (no parameters are needed for the loss).
        out_gt = f.apply({}, None, value, batch)
        out_gt = torch.tensor(np.array(out_gt["loss"]))

        to_tensor = lambda t: torch.tensor(t).cuda()
        value = tree_map(to_tensor, value, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)

        # The head config entries are splatted in as keyword arguments
        # alongside the batch features.
        out_repro = lddt_loss(
            logits=value["predicted_lddt"]["logits"],
            all_atom_pred_pos=value["structure_module"]["final_atom_positions"],
            **{**batch, **c_plddt},
        )
        out_repro = out_repro.cpu()

        # assertLess reports the observed difference on failure, unlike a
        # bare assertTrue(x < eps).
        max_diff = torch.max(torch.abs(out_gt - out_repro)).item()
        self.assertLess(max_diff, consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_backbone_loss_compare(self):
        """Compare OpenFold's ``backbone_loss`` with AlphaFold's backbone FAPE loss.

        Random ground-truth backbone affines, a random affine mask and a
        random trajectory of predicted affines (one per structure-module
        layer) are scored by both implementations. In multimer mode the
        AlphaFold reference is the sum of an intra-chain and an interface
        backbone loss, split by comparing ``asym_id`` values. The test
        asserts that the maximum absolute difference is below ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_bb_loss(batch, value):
            if consts.is_multimer:
                # 1.0 where two residues belong to the same chain.
                intra_chain_mask = (
                    batch["asym_id"][..., None] == batch["asym_id"][..., None, :]
                ).astype(np.float32)
                gt_rigid = affine_vector_to_rigid(
                    self.am_rigid, batch["backbone_affine_tensor"]
                )
                target_rigid = affine_vector_to_rigid(self.am_rigid, value["traj"])
                # Only the loss component is compared; the per-layer FAPE
                # values are discarded.
                intra_chain_bb_loss, _ = self.am_fold.backbone_loss(
                    gt_rigid=gt_rigid,
                    gt_frames_mask=batch["backbone_affine_mask"],
                    gt_positions_mask=batch["backbone_affine_mask"],
                    target_rigid=target_rigid,
                    config=c_sm.intra_chain_fape,
                    pair_mask=intra_chain_mask,
                )
                interface_bb_loss, _ = self.am_fold.backbone_loss(
                    gt_rigid=gt_rigid,
                    gt_frames_mask=batch["backbone_affine_mask"],
                    gt_positions_mask=batch["backbone_affine_mask"],
                    target_rigid=target_rigid,
                    config=c_sm.interface_fape,
                    pair_mask=1. - intra_chain_mask,
                )

                return intra_chain_bb_loss + interface_bb_loss

            # The monomer AlphaFold API accumulates into ret["loss"] in place.
            ret = {
                "loss": np.array(0.0),
            }
            self.am_fold.backbone_loss(ret, batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_bb_loss)

        n_res = consts.n_res

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "use_clamped_fape": np.array(0.0),
            "asym_id": random_asym_ids(n_res),
        }

        value = {
            "traj": random_affines_vector(
                (
                    c_sm.num_layer,
                    n_res,
                )
            ),
        }

        out_gt = f.apply({}, None, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        # OpenFold expects the ground-truth backbone frames as 4x4 rigid
        # tensors under the "rigid" key names.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
        out_repro = out_repro.cpu()

        # assertLess reports the observed difference on failure.
        max_diff = torch.max(torch.abs(out_gt - out_repro)).item()
        self.assertLess(max_diff, consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_sidechain_loss_compare(self):
        """Compare OpenFold's ``sidechain_loss`` with AlphaFold's sidechain FAPE loss.

        A random batch of atom14/atom37 ground-truth features is augmented
        with OpenFold's ``make_atom14_masks`` / ``make_atom14_positions``.
        Random predicted sidechain frames and atom positions are then scored
        by AlphaFold (reference) and by OpenFold (reproduction, on CUDA),
        after ground-truth renaming against ``atom14_pred_pos`` on both
        sides. The test asserts that the maximum absolute difference is below
        ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module

        def run_sidechain_loss(batch, value, atom14_pred_positions):
            if consts.is_multimer:
                atom14_pred_positions = self.am_rigid.Vec3Array.from_array(
                    atom14_pred_positions
                )
                all_atom_positions = self.am_rigid.Vec3Array.from_array(
                    batch["all_atom_positions"]
                )
                gt_positions, gt_mask, alt_naming_is_better = (
                    self.am_fold.compute_atom14_gt(
                        aatype=batch["aatype"],
                        all_atom_positions=all_atom_positions,
                        all_atom_mask=batch["all_atom_mask"],
                        pred_pos=atom14_pred_positions,
                    )
                )
                pred_frames = self.am_rigid.Rigid3Array.from_array4x4(
                    value["sidechains"]["frames"]
                )
                pred_positions = self.am_rigid.Vec3Array.from_array(
                    value["sidechains"]["atom_pos"]
                )
                gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
                    aatype=batch["aatype"],
                    all_atom_positions=all_atom_positions,
                    all_atom_mask=batch["all_atom_mask"],
                    use_alt=alt_naming_is_better,
                )
                return self.am_fold.sidechain_loss(
                    gt_sc_frames,
                    gt_sc_frames_mask,
                    gt_positions,
                    gt_mask,
                    pred_frames,
                    pred_positions,
                    c_sm,
                )["loss"]

            # Monomer path: add rigid-group frame features to the batch.
            batch = {
                **batch,
                **self.am_atom.atom37_to_frames(
                    batch["aatype"],
                    batch["all_atom_positions"],
                    batch["all_atom_mask"],
                ),
            }
            # Convert raw predictions into AlphaFold's r3 containers, then
            # attach the renamed ground truth computed from the predictions.
            v = {
                "sidechains": {
                    "frames": alphafold.model.r3.rigids_from_tensor4x4(
                        value["sidechains"]["frames"]
                    ),
                    "atom_pos": alphafold.model.r3.vecs_from_tensor(
                        value["sidechains"]["atom_pos"]
                    ),
                },
            }
            v.update(
                alphafold.model.folding.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v

            ret = self.am_fold.sidechain_loss(batch, value, c_sm)
            return ret["loss"]

        f = hk.transform(run_sidechain_loss)

        n_res = consts.n_res

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 20, (n_res,)),
            "atom14_gt_positions": np.random.rand(n_res, 14, 3).astype(
                np.float32
            ),
            "atom14_gt_exists": np.random.randint(0, 2, (n_res, 14)).astype(
                np.float32
            ),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        def _build_extra_feats_np():
            # Run OpenFold's torch data transforms, then convert back to NumPy
            # so the same features can be fed to the AlphaFold reference.
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b = data_transforms.make_atom14_masks(b)
            b = data_transforms.make_atom14_positions(b)
            return tensor_tree_map(lambda t: np.array(t), b)

        batch = _build_extra_feats_np()

        value = {
            "sidechains": {
                "frames": random_affines_4x4((c_sm.num_layer, n_res, 8)),
                "atom_pos": np.random.rand(c_sm.num_layer, n_res, 14, 3).astype(
                    np.float32
                ),
            }
        }

        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)

        out_gt = f.apply({}, None, batch, value, atom14_pred_pos)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda t: torch.tensor(t).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)
        atom14_pred_pos = to_tensor(atom14_pred_pos)

        batch = data_transforms.atom37_to_frames(batch)
        batch.update(compute_renamed_ground_truth(batch, atom14_pred_pos))

        out_repro = sidechain_loss(
            sidechain_frames=value["sidechains"]["frames"],
            sidechain_atom_pos=value["sidechains"]["atom_pos"],
            **{**batch, **c_sm},
        )
        out_repro = out_repro.cpu()

        # assertLess reports the observed difference on failure.
        max_diff = torch.max(torch.abs(out_gt - out_repro)).item()
        self.assertLess(max_diff, consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    @unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
    def test_tm_loss_compare(self):
        """Compare OpenFold's ``tm_loss`` with AlphaFold's predicted-aligned-error loss.

        AlphaFold's ``PredictedAlignedErrorHead`` is run with pretrained
        weights on a random pair representation. OpenFold's pretrained
        ``aux_heads.tm`` produces logits from the same representation, which
        are then scored with ``tm_loss``. The test asserts that the maximum
        absolute difference is below ``consts.eps``.
        """
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error

        def run_tm_loss(representations, batch, value):
            head = alphafold.model.modules.PredictedAlignedErrorHead(
                c_tm, config.model.global_config
            )
            # Shallow copy so the caller's dict is not mutated.
            v = dict(value)
            v["predicted_aligned_error"] = head(representations, batch, False)
            return head.loss(v, batch)["loss"]

        f = hk.transform(run_tm_loss)

        # NOTE(review): this seeds NumPy's *global* RNG, which also affects
        # tests that run later in the same process. It is presumably here to
        # make the pretrained-weight comparison reproducible — confirm.
        np.random.seed(42)

        n_res = consts.n_res

        representations = {
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            "backbone_affine_tensor": random_affines_vector((n_res,)),
            "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                np.float32
            ),
            "resolution": np.array(1.0).astype(np.float32),
        }

        value = {
            "structure_module": {
                "final_affines": random_affines_vector((n_res,)),
            }
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/predicted_aligned_error_head"
        )

        out_gt = f.apply(params, None, representations, batch, value)
        out_gt = torch.tensor(np.array(out_gt.block_until_ready()))

        to_tensor = lambda n: torch.tensor(n).cuda()
        representations = tree_map(to_tensor, representations, np.ndarray)
        batch = tree_map(to_tensor, batch, np.ndarray)
        value = tree_map(to_tensor, value, np.ndarray)

        # OpenFold expects the ground-truth backbone frames as 4x4 rigid
        # tensors under the "rigid" key names.
        batch["backbone_rigid_tensor"] = affine_vector_to_4x4(
            batch["backbone_affine_tensor"]
        )
        batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]

        model = compare_utils.get_global_pretrained_openfold()
        logits = model.aux_heads.tm(representations["pair"])

        out_repro = tm_loss(
            logits=logits,
            final_affine_tensor=value["structure_module"]["final_affines"],
            **{**batch, **c_tm},
        )
        out_repro = out_repro.cpu()

        # assertLess reports the observed difference on failure.
        max_diff = torch.max(torch.abs(out_gt - out_repro)).item()
        self.assertLess(max_diff, consts.eps)

if __name__ == "__main__":
    # Run this module's tests through unittest's command-line runner when the
    # file is executed directly as a script.
    unittest.main()