test_loss.py 3.02 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import numpy as np
import unittest

from alphafold.utils.loss import *
from alphafold.utils.utils import T


class TestLoss(unittest.TestCase):
    def test_run_torsion_angle_loss(self):
        batch_size = 2
        n = 5

        a = torch.rand((batch_size, n, 7, 2))
        a_gt = torch.rand((batch_size, n, 7, 2))
        a_alt_gt = torch.rand((batch_size, n, 7, 2))

        loss = torsion_angle_loss(a, a_gt, a_alt_gt)

    def test_run_fape(self):
        batch_size = 2
        n_frames = 7
        n_atoms = 5

        x = torch.rand((batch_size, n_atoms, 3))
        x_gt = torch.rand((batch_size, n_atoms, 3))
        rots = torch.rand((batch_size, n_frames, 3, 3))
        rots_gt = torch.rand((batch_size, n_frames, 3, 3))
        trans = torch.rand((batch_size, n_frames, 3))
        trans_gt = torch.rand((batch_size, n_frames, 3))
        t = T(rots, trans)
        t_gt = T(rots_gt, trans_gt)

        loss = compute_fape(t, x, t_gt, x_gt)

    def test_between_residue_bond_loss(self):
        bs = 2
        n = 10
        pred_pos = torch.rand(bs, n, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
        residue_index = torch.arange(n).unsqueeze(0)
        aatype = torch.randint(0, 22, (bs, n,))
        
        between_residue_bond_loss(
            pred_pos,
            pred_atom_mask,
            residue_index, 
            aatype,
        )

    def test_between_residue_clash_loss(self):
        bs = 2
        n = 10
        pred_pos = torch.rand(bs, n, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
        atom14_atom_radius = torch.rand(bs, n, 14)
        residue_index = torch.arange(n).unsqueeze(0)

        loss = between_residue_clash_loss(
            pred_pos,
            pred_atom_mask,
            atom14_atom_radius, 
            residue_index,
        )

    def test_find_structural_violations(self):
        n = 10

        batch = {
            "atom14_atom_exists": torch.randint(0, 2, (n, 14)),
            "residue_index": torch.arange(n),
            "aatype": torch.randint(0, 21, (n,)),
            "residx_atom14_to_atom37": torch.randint(0, 37, (n, 14)).long(),
        }

        pred_pos = torch.rand(n, 14, 3)
        
        config = ml_collections.ConfigDict({
            "clash_overlap_tolerance": 1.5,
            "violation_tolerance_factor": 12.0,
        })

        find_structural_violations(batch, pred_pos, config)


if __name__ == "__main__":
    unittest.main()