test_loss.py 5.35 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import numpy as np
import unittest

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from openfold.utils.loss import (
    torsion_angle_loss,
    compute_fape,
    between_residue_bond_loss,
    between_residue_clash_loss,
    find_structural_violations,
)
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import tensor_tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

if(compare_utils.alphafold_is_installed()):
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
36
37
38
39


class TestLoss(unittest.TestCase):
    def test_run_torsion_angle_loss(self):
40
41
        batch_size = consts.batch_size
        n_res = consts.n_res
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
42

43
44
45
        a = torch.rand((batch_size, n_res, 7, 2))
        a_gt = torch.rand((batch_size, n_res, 7, 2))
        a_alt_gt = torch.rand((batch_size, n_res, 7, 2))
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
46
47
48
49

        loss = torsion_angle_loss(a, a_gt, a_alt_gt)

    def test_run_fape(self):
50
        batch_size = consts.batch_size
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
51
52
53
54
55
56
57
58
59
60
61
        n_frames = 7
        n_atoms = 5

        x = torch.rand((batch_size, n_atoms, 3))
        x_gt = torch.rand((batch_size, n_atoms, 3))
        rots = torch.rand((batch_size, n_frames, 3, 3))
        rots_gt = torch.rand((batch_size, n_frames, 3, 3))
        trans = torch.rand((batch_size, n_frames, 3))
        trans_gt = torch.rand((batch_size, n_frames, 3))
        t = T(rots, trans)
        t_gt = T(rots_gt, trans_gt)
62
63
64
65
66
67
68
69
70
71
72
73
74
        frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
        positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
        length_scale = 10

        loss = compute_fape(
            pred_frames=t,
            target_frames=t_gt,
            frames_mask=frames_mask,
            pred_positions=x,
            target_positions=x_gt,
            positions_mask=positions_mask,
            length_scale=length_scale,
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
75

76
77
78
    def test_run_between_residue_bond_loss(self):
        bs = consts.batch_size
        n = consts.n_res
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
79
80
81
82
83
84
85
86
87
88
89
90
        pred_pos = torch.rand(bs, n, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
        residue_index = torch.arange(n).unsqueeze(0)
        aatype = torch.randint(0, 22, (bs, n,))
        
        between_residue_bond_loss(
            pred_pos,
            pred_atom_mask,
            residue_index, 
            aatype,
        )

91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    @compare_utils.skip_unless_alphafold_installed()
    def test_between_residue_bond_loss_compare(self):
        def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
            return alphafold.model.all_atom.between_residue_bond_loss(
                pred_pos,
                pred_atom_mask,
                residue_index,
                aatype,
            )
    
        f = hk.transform(run_brbl)
    
        n_res = consts.n_res 
        pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
        pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
        residue_index = np.arange(n_res)
        aatype = np.random.randint(0, 22, (n_res,))
        
        out_gt = f.apply(
            {}, None, 
            pred_pos, 
            pred_atom_mask, 
            residue_index,
            aatype,
        )
        out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
    
        out_repro = between_residue_bond_loss(
            torch.tensor(pred_pos).cuda(),
            torch.tensor(pred_atom_mask).cuda(),
            torch.tensor(residue_index).cuda(),
            torch.tensor(aatype).cuda(),
        )
        out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
    
        for k in out_gt.keys():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
133
    def test_between_residue_clash_loss(self):
134
135
136
        bs = consts.batch_size
        n = consts.n_res

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
137
138
139
140
141
142
143
144
145
146
147
148
149
        pred_pos = torch.rand(bs, n, 14, 3)
        pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
        atom14_atom_radius = torch.rand(bs, n, 14)
        residue_index = torch.arange(n).unsqueeze(0)

        loss = between_residue_clash_loss(
            pred_pos,
            pred_atom_mask,
            atom14_atom_radius, 
            residue_index,
        )

    def test_find_structural_violations(self):
150
        n = consts.n_res
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
151
152
153
154
155
156
157
158
159
160

        batch = {
            "atom14_atom_exists": torch.randint(0, 2, (n, 14)),
            "residue_index": torch.arange(n),
            "aatype": torch.randint(0, 21, (n,)),
            "residx_atom14_to_atom37": torch.randint(0, 37, (n, 14)).long(),
        }

        pred_pos = torch.rand(n, 14, 3)
        
161
        config = {
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
162
163
            "clash_overlap_tolerance": 1.5,
            "violation_tolerance_factor": 12.0,
164
        }
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
165

166
        find_structural_violations(batch, pred_pos, **config)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
167
168
169
170
171


if __name__ == "__main__":
    unittest.main()