# test_permutation.py
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import unittest

import torch

from openfold.utils.multi_chain_permutation import (
    compute_permutation_alignment,
    get_least_asym_entity_or_longest_length,
    merge_labels,
    pad_features,
    split_ground_truth_labels,
)


@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase):
    """Tests for multi-chain permutation alignment on a fake A2B3 hetero-pentamer.

    The fixture has two copies of chain A (entity 1, asym_id 1-2) and three
    copies of chain B (entity 2, asym_id 3-5). The fake prediction is the ground
    truth with its chains reordered, and the tests check that the permutation
    utilities recover a consistent chain mapping.
    """

    # Atom slots per residue, matching the (.., 37, 3) position tensors below.
    num_atoms = 37
    # Pretend cropping size used by test_3 to exercise padded inputs.
    crop_size = 325

    def setUp(self):
        """
        Create fake input structure features
        and rotation matrices.
        """
        theta = math.pi / 4
        device = 'cpu'
        cos_t, sin_t = math.cos(theta), math.sin(theta)

        # Seeded, test-local generator: the fake coordinates are reproducible
        # between runs without touching the global torch RNG state.
        self.generator = torch.Generator().manual_seed(0)

        # Proper rotation matrices (orthogonal, determinant +1) about z, x and y.
        self.rotation_matrix_z = torch.tensor([
            [cos_t, -sin_t, 0],
            [sin_t, cos_t, 0],
            [0, 0, 1],
        ], device=device)
        self.rotation_matrix_x = torch.tensor([
            [1, 0, 0],
            [0, cos_t, -sin_t],
            [0, sin_t, cos_t],
        ], device=device)
        self.rotation_matrix_y = torch.tensor([
            [cos_t, 0, sin_t],
            [0, 1, 0],
            # The middle entry must be 0. A 1 here makes the matrix a shear, so
            # chain B2 would not be a rigid copy of chain B1.
            [-sin_t, 0, cos_t],
        ], device=device)

        self.chain_a_num_res = 9
        self.chain_b_num_res = 13

        # Default fake ground-truth features for a hetero-pentamer A2B3.
        self.residue_index = (list(range(self.chain_a_num_res)) * 2
                              + list(range(self.chain_b_num_res)) * 3)
        self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3
        self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res
                                     + [3] * self.chain_b_num_res + [4] * self.chain_b_num_res
                                     + [5] * self.chain_b_num_res], device=device)
        # NOTE(review): sym_id is set equal to asym_id here. In multimer features
        # sym_id usually numbers the copies *within* an entity; confirm this is intended.
        self.sym_id = self.asym_id
        self.entity_id = torch.tensor([[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)],
                                      device=device)

    def _make_chain_positions(self):
        """Return fake ground-truth atom positions for chains (A1, A2, B1, B2, B3).

        Each tensor has shape (1, n_res_of_chain, num_atoms, 3). A2 is a
        rotated and translated copy of A1. B2 and B3 are rotated and translated
        copies of B1, each moved by a different offset.
        """
        a1 = torch.randint(15, (self.chain_a_num_res, 3 * self.num_atoms),
                           dtype=torch.float, generator=self.generator
                           ).reshape(1, self.chain_a_num_res, self.num_atoms, 3)
        a2 = torch.matmul(a1, self.rotation_matrix_x) + 10

        b1 = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * self.num_atoms),
                           dtype=torch.float, generator=self.generator
                           ).reshape(1, self.chain_b_num_res, self.num_atoms, 3)
        b2 = torch.matmul(b1, self.rotation_matrix_y) + 10
        b3 = torch.matmul(torch.matmul(b1, self.rotation_matrix_z), self.rotation_matrix_x) + 30
        return a1, a2, b1, b2, b3

    def _make_true_atom_mask(self):
        """Return an all-ones ground-truth atom mask of shape (1, num_res, num_atoms)."""
        return torch.ones((1, self.num_res, self.num_atoms))

    @staticmethod
    def _permute_like_prediction(a1, a2, b1, b2, b3):
        """Concatenate chains in the fake prediction's order: (A2, A1, B2, B3, B1)."""
        return torch.cat((a2, a1, b2, b3, b1), dim=1)

    def test_1_selecting_anchors(self):
        """Both anchors (ground truth and prediction) are expected to be chain-A copies."""
        batch = {
            'asym_id': self.asym_id,
            'sym_id': self.sym_id,
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([self.num_res]),
        }
        anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])

        chain_a_asyms, chain_b_asyms = [1, 2], [3, 4, 5]
        for anchor in (anchor_gt_asym, anchor_pred_asym):
            self.assertIn(int(anchor), chain_a_asyms)
            self.assertNotIn(int(anchor), chain_b_asyms)

    def test_2_permutation_pentamer(self):
        """The alignment must map the reordered predicted chains back onto the ground truth."""
        batch = {
            'asym_id': self.asym_id,
            'sym_id': self.sym_id,
            'entity_id': self.entity_id,
            'seq_length': torch.tensor([self.num_res]),
            'aatype': torch.randint(21, size=(1, self.num_res), generator=self.generator),
            'residue_index': torch.tensor([self.residue_index]),
        }

        a1, a2, b1, b2, b3 = self._make_chain_positions()
        # The prediction is the ground truth with its chains reordered.
        out = {
            'final_atom_positions': self._permute_like_prediction(a1, a2, b1, b2, b3),
            'final_atom_mask': torch.ones((1, self.num_res, self.num_atoms)),
        }
        batch['all_atom_positions'] = torch.cat((a1, a2, b1, b2, b3), dim=1)
        batch['all_atom_mask'] = self._make_true_atom_mask()

        aligns, _ = compute_permutation_alignment(out, batch, batch)

        possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]]
        wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]]
        self.assertIn(aligns, possible_outcome)
        self.assertNotIn(aligns, wrong_outcome)

    def test_3_merge_labels(self):
        """Merged labels on padded inputs must follow the prediction's chain order."""
        nres_pad = self.crop_size - self.num_res
        batch = {
            'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
            'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
            'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
            'aatype': torch.randint(21, size=(1, self.crop_size), generator=self.generator),
            'seq_length': torch.tensor([self.num_res]),
        }
        batch['asym_id'] = batch['asym_id'].reshape(1, self.crop_size)
        batch['residue_index'] = pad_features(torch.tensor(self.residue_index).reshape(1, self.num_res),
                                              nres_pad, pad_dim=1)

        a1, a2, b1, b2, b3 = self._make_chain_positions()
        # The prediction is the ground truth with its chains reordered, then padded.
        out = {
            'final_atom_positions': pad_features(self._permute_like_prediction(a1, a2, b1, b2, b3),
                                                 nres_pad, pad_dim=1),
            'final_atom_mask': pad_features(torch.ones((1, self.num_res, self.num_atoms)),
                                            nres_pad, pad_dim=1),
        }
        batch['all_atom_positions'] = pad_features(torch.cat((a1, a2, b1, b2, b3), dim=1), nres_pad, pad_dim=1)
        batch['all_atom_mask'] = pad_features(self._make_true_atom_mask(), nres_pad, pad_dim=1)

        aligns, per_asym_residue_index = compute_permutation_alignment(out, batch, batch)

        labels = split_ground_truth_labels(batch)
        labels = merge_labels(per_asym_residue_index, labels, aligns,
                              original_nres=batch['aatype'].shape[-1])

        self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index']))

        expected_permuted_gt_pos = pad_features(self._permute_like_prediction(a1, a2, b1, b2, b3),
                                                nres_pad, pad_dim=1)
        self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permuted_gt_pos))