# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch

import unittest
from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym_entity_or_longest_length,
                                                    compute_permutation_alignment, split_ground_truth_labels,
                                                    merge_labels)


class TestPermutation(unittest.TestCase):
    """Tests for the multi-chain permutation utilities on a fake A2B3 hetero-pentamer.

    Chains A1/A2 (entity 1, ``chain_a_num_res`` residues each, asym ids 1-2) and
    chains B1/B2/B3 (entity 2, ``chain_b_num_res`` residues each, asym ids 3-5) are
    built from random coordinates. The "predicted" structure contains the same
    chains in a shuffled order, so the alignment code has to recover that order.
    """

    def setUp(self):
        """Create fake ground-truth chain features and 45-degree rotation matrices."""
        # The fake coordinates and aatypes below are random; seed the RNG so a
        # failing run can be reproduced exactly.
        torch.manual_seed(0)

        theta = math.pi / 4
        device = 'cpu'
        cos_t = math.cos(theta)
        sin_t = math.sin(theta)

        self.rotation_matrix_z = torch.tensor([
            [cos_t, -sin_t, 0],
            [sin_t, cos_t, 0],
            [0, 0, 1],
        ], device=device)
        self.rotation_matrix_x = torch.tensor([
            [1, 0, 0],
            [0, cos_t, -sin_t],
            [0, sin_t, cos_t],
        ], device=device)
        # Entry (2, 1) of a rotation about the y axis is 0. It used to be 1,
        # which made this matrix non-orthogonal (i.e. not a rotation at all).
        self.rotation_matrix_y = torch.tensor([
            [cos_t, 0, sin_t],
            [0, 1, 0],
            [-sin_t, 0, cos_t],
        ], device=device)

        self.chain_a_num_res = 9
        self.chain_b_num_res = 13
        # Total number of residues in the complex (2 A chains + 3 B chains).
        self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3

        # Default fake ground-truth features for a hetero-pentamer A2B3.
        # Residue numbering restarts at 0 for every chain.
        self.residue_index = list(range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3
        self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res
                                     + [3] * self.chain_b_num_res + [4] * self.chain_b_num_res
                                     + [5] * self.chain_b_num_res], device=device)
        # NOTE(review): sym_id simply reuses asym_id here rather than per-entity copy numbers.
        self.sym_id = self.asym_id
        self.entity_id = torch.tensor([[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)],
                                      device=device)

    def _make_gt_batch(self):
        """Return the ground-truth feature dict shared by the alignment tests.

        Atom positions and masks are added by each test.
        """
        return {
            'asym_id': self.asym_id.reshape(1, self.num_res),
            'sym_id': self.sym_id,
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([self.num_res]),
            'aatype': torch.randint(21, size=(1, self.num_res)),
            'residue_index': torch.tensor([self.residue_index]),
        }

    def _make_fake_chain_positions(self):
        """Return random atom positions for chains A1, A2, B1, B2, B3 (in that order).

        Each tensor has shape (1, chain_num_res, 37, 3). A2 is a rotated and
        translated copy of A1; B2 and B3 are rotated and translated copies of B1.
        """
        chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
                                     dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
        chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10

        chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
                                     dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
        chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
        chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
        return chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos

    def test_1_selecting_anchors(self):
        """The anchor must be an A chain (entity 1), the entity with the fewest copies."""
        batch = {
            'asym_id': self.asym_id,
            'sym_id': self.sym_id,
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([self.num_res]),
        }
        anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
        anchor_gt_asym = int(anchor_gt_asym)
        anchor_pred_asym = {int(i) for i in anchor_pred_asym}
        expected_anchors = {1, 2}
        expected_non_anchors = {3, 4, 5}

        self.assertIn(anchor_gt_asym, expected_anchors)
        self.assertNotIn(anchor_gt_asym, expected_non_anchors)
        # Every candidate predicted anchor must be one of the A chains.
        self.assertLessEqual(anchor_pred_asym, expected_anchors)
        self.assertTrue(anchor_pred_asym.isdisjoint(expected_non_anchors))

    def test_2_permutation_pentamer(self):
        """compute_permutation_alignment recovers the shuffled chain order."""
        batch = self._make_gt_batch()
        chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos = self._make_fake_chain_positions()

        # The predicted structure holds the same chains in a permuted order.
        out = {
            'final_atom_positions': torch.cat(
                (chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1),
            'final_atom_mask': torch.ones((1, self.num_res, 37)),
        }

        batch['all_atom_positions'] = torch.cat(
            (chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
        batch['all_atom_mask'] = torch.ones((1, self.num_res, 37))

        aligns, _ = compute_permutation_alignment(out, batch, batch)
        possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]]
        wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]]
        self.assertIn(aligns, possible_outcome)
        self.assertNotIn(aligns, wrong_outcome)

    def test_3_merge_labels(self):
        """Merging ground-truth labels with the alignment reproduces the predicted chain order."""
        crop_size = 325  # suppose the cropping size is 325
        nres_pad = crop_size - self.num_res

        batch = self._make_gt_batch()
        chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos = self._make_fake_chain_positions()

        # The predicted structure holds the same chains in a permuted order,
        # padded up to the crop size.
        pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
        out = {
            'final_atom_positions': pad_features(pred_atom_position, nres_pad, pad_dim=1),
            'final_atom_mask': pad_features(torch.ones((1, self.num_res, 37)), nres_pad, pad_dim=1),
        }

        true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
        true_atom_mask = torch.ones((1, self.num_res, 37))
        batch['all_atom_positions'] = true_atom_position
        batch['all_atom_mask'] = true_atom_mask

        # NOTE:
        # batch: simulates the ground-truth features.
        # fake_input_features: simulates the (padded) data that will be used as input for
        #     model.forward(fake_input_features).
        # out: simulates the output of model.forward(fake_input_features).
        fake_input_features = {
            'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1).reshape(1, crop_size),
            'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
            'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
            'aatype': torch.randint(21, size=(1, crop_size)),
            'seq_length': torch.tensor([self.num_res]),
            'residue_index': pad_features(torch.tensor(self.residue_index).reshape(1, self.num_res),
                                          nres_pad, pad_dim=1),
            'all_atom_positions': pad_features(true_atom_position, nres_pad, pad_dim=1),
            'all_atom_mask': pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1),
        }

        aligns, per_asym_residue_index = compute_permutation_alignment(out, fake_input_features, batch)
        labels = split_ground_truth_labels(batch)
        labels = merge_labels(per_asym_residue_index, labels, aligns,
                              original_nres=batch['aatype'].shape[-1])

        self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index']))

        expected_permuted_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
                                             dim=1)
        self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permuted_gt_pos))