# unifold_permutation.py
import torch
from openfold.np import residue_constants as rc
import logging
logger = logging.getLogger(__name__)
import sys
6
import random
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

def kabsch_rotation(P, Q):
    """
    Using the Kabsch algorithm with two sets of paired points P and Q, centered
    around the centroid. Each vector set is represented as an NxD
    matrix, where D is the dimension of the space.
    The algorithm works in three steps:
    - a centroid translation of P and Q (assumed done before this function
      call)
    - the computation of a covariance matrix C
    - computation of the optimal rotation matrix U
    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    The returned matrix follows the row-vector convention: ``P @ U``
    approximates ``Q``.

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.
    Returns
    -------
    U : matrix
        Rotation matrix (D,D). It lives on the CPU because both inputs are
        moved there before any computation.
    """
    # Move to CPU memory in case the decomposition would take up too much
    # GPU memory.
    P, Q = P.to('cpu'), Q.to('cpu')

    # Computation of the covariance matrix
    C = P.transpose(-1, -2) @ Q

    # Computation of the optimal rotation matrix.
    # torch.linalg.svd returns (U, S, Vh), so W below is already the
    # transposed right factor and C == V @ diag(S) @ W.
    V, _, W = torch.linalg.svd(C)

    # The sign of det(V) * det(W) tells whether V @ W would be a reflection;
    # if so, flip the last column of V to ensure a right-handed coordinate
    # system. The truth test below requires a single (non-batched) pair.
    needs_flip = (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0
    if needs_flip:
        V[:, -1] = -V[:, -1]

    # Create Rotation matrix U
    U = V @ W
    return U

def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor = None,
):
    """
    Compute the rigid transform that superimposes ``src_atoms`` onto
    ``tgt_atoms``.

    The result is meant to be applied as ``src_atoms @ r + x``.

    Parameters
    ----------
    src_atoms : torch.Tensor
        Source coordinates; the last dimension must be 3.
    tgt_atoms : torch.Tensor
        Target coordinates with exactly the same shape as ``src_atoms``.
    mask : torch.Tensor, optional
        Boolean tensor selecting the atoms used for the fit. Its last
        dimension must equal ``src_atoms.shape[-2]``. When no atom is
        selected, a single all-zero point is used for both sets instead.
    Returns
    -------
    r : torch.Tensor
        Rotation matrix produced by ``kabsch_rotation``.
    x : torch.Tensor
        Translation, computed on the CPU as
        ``tgt_center - src_center @ r``.
    """
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    if mask is not None:
        assert mask.dtype == torch.bool
        assert mask.shape[-1] == src_atoms.shape[-2]
        if mask.sum() == 0:
            # Nothing to fit on: fall back to one dummy point at the origin
            # so the centroid and SVD computations below stay well defined.
            src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
            tgt_atoms = src_atoms
        else:
            src_atoms = src_atoms[mask, :]
            tgt_atoms = tgt_atoms[mask, :]
    src_center = src_atoms.mean(-2, keepdim=True)
    tgt_center = tgt_atoms.mean(-2, keepdim=True)
    r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)
    # kabsch_rotation moves its inputs to the CPU, so bring the centroids to
    # the CPU as well before combining them with r.
    tgt_center, src_center = tgt_center.to('cpu'), src_center.to('cpu')
    x = tgt_center - src_center @ r
    return r, x


def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Root-mean-square deviation between two coordinate sets.

    Parameters
    ----------
    true_atom_pos : torch.Tensor
        Reference coordinates; the last dimension holds the components.
    pred_atom_pos : torch.Tensor
        Coordinates compared against ``true_atom_pos``.
    atom_mask : torch.Tensor, optional
        Index/mask applied to the per-atom squared distances before
        averaging.
    eps : float
        Added to the mean squared deviation before the square root.
    Returns
    -------
    torch.Tensor
        Scalar tensor ``sqrt(mean_squared_deviation + eps)``.
    """
    delta = true_atom_pos - pred_atom_pos
    per_atom_sq_dist = (delta * delta).sum(dim=-1)
    if atom_mask is not None:
        per_atom_sq_dist = per_atom_sq_dist[atom_mask]
    # An empty selection averages to NaN; map it to a huge value so such a
    # candidate never looks like a good match.
    mean_sq_dev = per_atom_sq_dist.mean().nan_to_num(nan=1e8)
    return (mean_sq_dev + eps).sqrt()

def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor,
):
    """
    RMSD between ``true_atom_pos`` and ``pred_atom_pos`` after optimally
    superimposing the former onto the latter.

    The transform is fitted by ``get_optimal_transform`` using
    ``atom_mask``, applied to ``true_atom_pos``, and the same mask is then
    used by ``compute_rmsd`` to score the result.
    """
    rotation, translation = get_optimal_transform(
        true_atom_pos, pred_atom_pos, atom_mask
    )
    superimposed_true = true_atom_pos @ rotation + translation
    return compute_rmsd(superimposed_true, pred_atom_pos, atom_mask)



def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
    """
    Pick the assignment of ground-truth chains to predicted chains that
    gives the lowest C-alpha RMSD, and return the labels merged in that
    order.

    For every candidate anchor chain of the prediction, the ground-truth
    chains are superimposed via the anchor, ``greedy_align`` is run on
    ``shuffle_times`` random orderings of the asym ids, and each resulting
    assignment is scored with ``kabsch_rmsd``.

    Parameters
    ----------
    out : dict
        Model output; reads ``final_atom_positions`` and
        ``final_atom_mask``.
    batch : dict
        Input features; reads ``asym_id``, ``entity_id`` and
        ``residue_index`` here (helpers read further keys).
    labels : list of dict
        One ground-truth label dict per chain, indexed by ``asym_id - 1``;
        each provides ``all_atom_positions`` and ``all_atom_mask``.
    shuffle_times : int
        Number of random asym-id orderings tried per anchor candidate.
    Returns
    -------
    dict or None
        The merged labels of the best-scoring assignment, or ``None`` when
        no assignment was evaluated (no anchor candidate or
        ``shuffle_times`` == 0).
    """
    assert isinstance(labels, list)
    ca_idx = rc.atom_order["CA"]
    pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float()  # [bsz, nres, 3]
    pred_ca_mask = out["final_atom_mask"][..., ca_idx].float()  # [bsz, nres]
    true_ca_poses = [
        label["all_atom_positions"][..., ca_idx, :].float() for label in labels
    ]  # list([nres, 3])
    true_ca_masks = [
        label["all_atom_mask"][..., ca_idx].float() for label in labels
    ]  # list([nres,])

    unique_asym_ids = torch.unique(batch["asym_id"])

    # Residue indices of every chain, keyed by its (integer) asym id.
    per_asym_residue_index = {}
    for cur_asym_id in unique_asym_ids:
        asym_mask = (batch["asym_id"] == cur_asym_id).bool()
        per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]

    anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
    logger.debug("anchor_gt_asym is %s", anchor_gt_asym)
    # The anchor id is used as a 1-based asym id, labels are 0-based.
    anchor_gt_idx = int(anchor_gt_asym) - 1

    best_rmsd = 1e9
    best_labels = None

    # Map every entity to the asym ids (chains) that share its sequence.
    unique_entity_ids = torch.unique(batch["entity_id"])
    entity_2_asym_list = {}
    for cur_ent_id in unique_entity_ids:
        ent_mask = batch["entity_id"] == cur_ent_id
        cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
        entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
    logger.debug("entity_2_asym_list is %s", entity_2_asym_list)

    for cur_asym_id in anchor_pred_asym:
        asym_mask = (batch["asym_id"] == cur_asym_id).bool()
        anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]

        anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
        anchor_pred_pos = pred_ca_pos[asym_mask]
        anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
        anchor_pred_mask = pred_ca_mask[asym_mask]
        # Superimpose the ground-truth anchor onto this predicted chain,
        # fitting only on residues resolved in both.
        r, x = get_optimal_transform(
            anchor_true_pos,
            anchor_pred_pos,
            (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
        )

        # Apply the anchor transform to every ground-truth chain (on CPU).
        aligned_true_ca_poses = [
            ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses
        ]

        for _ in range(shuffle_times):
            shuffle_idx = torch.randperm(
                unique_asym_ids.shape[0], device=unique_asym_ids.device
            )
            shuffled_asym_ids = unique_asym_ids[shuffle_idx]
            align = greedy_align(
                batch,
                per_asym_residue_index,
                shuffled_asym_ids,
                entity_2_asym_list,
                pred_ca_pos,
                pred_ca_mask,
                aligned_true_ca_poses,
                true_ca_masks,
            )
            merged_labels = merge_labels(
                batch,
                per_asym_residue_index,
                labels,
                align,
            )
            rmsd = kabsch_rmsd(
                merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
                pred_ca_pos,
                (pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
            )

            if rmsd < best_rmsd:
                best_rmsd = rmsd
                best_labels = merged_labels

    return best_labels


192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def get_least_asym_entity_or_longest_length(batch):
    """
    Choose the anchor chain for permutation alignment.

    First check how many subunit(s) each sequence (entity) has; if there is
    no tie, e.g. AABBB, select one of the A chains as anchor.

    If there is a tie, e.g. AABB, first check which sequence is the
    longer/longest, then choose one of the corresponding subunits as
    anchor. If the lengths tie as well, one of the tied entities is picked
    at random.

    Parameters
    ----------
    batch : dict
        Reads the tensors ``entity_id`` and ``asym_id``, which are indexed
        with each other's masks and therefore must have matching shapes.
    Returns
    -------
    anchor_gt_asym : int
        The smallest asym id belonging to the selected entity.
    best_pred_asym : torch.Tensor
        All unique asym ids of the selected entity (sorted by
        ``torch.unique``).
    """
    entity_ids = batch["entity_id"]
    asym_ids = batch["asym_id"]

    # NOTE(review): id 0 is treated as padding for asym ids elsewhere in this
    # file; it is not filtered here -- confirm padded batches never reach
    # this function.
    entity_asym_count = {}
    entity_length = {}
    for entity_id in torch.unique(entity_ids):
        entity_mask = entity_ids == entity_id
        key = int(entity_id)
        entity_asym_count[key] = len(torch.unique(asym_ids[entity_mask]))
        # Length = number of residues carrying this entity id.
        entity_length[key] = entity_mask.sum().item()

    min_asym_count = min(entity_asym_count.values())
    candidates = [
        entity
        for entity, count in entity_asym_count.items()
        if count == min_asym_count
    ]

    # If multiple entities have the least asym_id count, keep those with the
    # longest length.
    if len(candidates) > 1:
        max_length = max(entity_length[entity] for entity in candidates)
        candidates = [
            entity for entity in candidates if entity_length[entity] == max_length
        ]

    # If still multiple entities, pick a random one. random.choice returns a
    # single element, so keep it as a scalar rather than re-measuring a list.
    if len(candidates) > 1:
        anchor_entity = random.choice(candidates)
    else:
        anchor_entity = candidates[0]
    logging.getLogger(__name__).debug(
        "anchor entity is %s and entity_length is %s", anchor_entity, entity_length
    )

    best_pred_asym = torch.unique(asym_ids[entity_ids == anchor_entity])
    # Callers index per-chain labels with ``anchor - 1``, so the anchor must
    # be an asym id of the chosen entity, not the entity id itself.
    anchor_gt_asym = int(best_pred_asym[0])
    return anchor_gt_asym, best_pred_asym
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254


def greedy_align(
    batch,
    per_asym_residue_index,
    unique_asym_ids,
    entity_2_asym_list,
    pred_ca_pos,
    pred_ca_mask,
    true_ca_poses,
    true_ca_masks,
):
    """
    Greedily assign a ground-truth chain to every predicted chain.

    Chains are visited in the order given by ``unique_asym_ids``. A chain
    whose ``num_sym`` is 1 keeps its own label; otherwise the not-yet-used
    ground-truth chain of the same entity with the lowest RMSD to the
    prediction is taken.

    Parameters
    ----------
    batch : dict
        Reads ``asym_id``, ``num_sym`` and ``entity_id``.
    per_asym_residue_index : dict
        Maps an integer asym id to the residue indices of that chain.
    unique_asym_ids : iterable
        Asym ids to process; id 0 is skipped as padding.
    entity_2_asym_list : dict
        Maps an integer entity id to the asym ids sharing that entity.
    pred_ca_pos, pred_ca_mask : torch.Tensor
        Predicted C-alpha positions and mask, indexed with the asym mask.
    true_ca_poses, true_ca_masks : list of torch.Tensor
        Per-chain ground-truth C-alpha positions and masks, indexed by
        ``asym_id - 1``. The positions are expected to be already
        superimposed onto the prediction, since no fitting happens here.
    Returns
    -------
    list of tuple
        Pairs ``(i, j)``: predicted chain ``i`` receives the labels of
        ground-truth chain ``j`` (both 0-based).
    """
    used = [False] * len(true_ca_poses)
    align = []
    for cur_asym_id in unique_asym_ids:
        # skip padding
        if cur_asym_id == 0:
            continue
        i = int(cur_asym_id - 1)
        asym_mask = batch["asym_id"] == cur_asym_id
        num_sym = batch["num_sym"][asym_mask][0]
        # don't need to align
        if num_sym == 1:
            align.append((i, i))
            assert not used[i]
            used[i] = True
            continue
        cur_entity_ids = batch["entity_id"][asym_mask][0]
        best_rmsd = 1e20
        best_idx = None
        cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
        cur_residue_index = per_asym_residue_index[int(cur_asym_id)]

        cur_pred_pos = pred_ca_pos[asym_mask]
        cur_pred_mask = pred_ca_mask[asym_mask]
        for next_asym_id in cur_asym_list:
            if next_asym_id == 0:
                continue
            j = int(next_asym_id - 1)
            if used[j]:
                continue
            # possible candidate: crop the ground truth to the residues
            # present in this chain, exactly as done for its mask, so both
            # line up with the predicted positions.
            true_pos = true_ca_poses[j]
            cropped_pos = true_pos[cur_residue_index.to(true_pos.device)]
            mask = true_ca_masks[j][cur_residue_index]
            # NOTE(review): cropped_pos and cur_pred_pos are assumed to be
            # on the same device here -- confirm for GPU runs.
            rmsd = compute_rmsd(
                cropped_pos,
                cur_pred_pos,
                (cur_pred_mask.to('cpu') * mask.to('cpu')).bool(),
            )
            if rmsd < best_rmsd:
                best_rmsd = rmsd
                best_idx = j
        assert best_idx is not None
        used[best_idx] = True
        align.append((i, best_idx))
    logger.debug("align is %s", align)
    return align


def merge_labels(batch, per_asym_residue_index, labels, align):
    """
    Assemble one label dict for the whole complex from per-chain labels.

    Parameters
    ----------
    batch : dict
        Only ``msa_mask`` is read; its last dimension gives the target
        number of residues.
    per_asym_residue_index : dict
        Maps a 1-based asym id to the residue indices of that chain.
    labels : list of dict
        Per-chain label dicts, each with shape [nk, *]. The keys of
        ``labels[0]`` decide which features are merged; ``"resolution"`` is
        skipped.
    align : list of tuple
        Pairs ``(i, j)`` such as ``[(0, 1), (1, 0)]``: chain ``i`` of the
        output takes its values from ``labels[j]`` (both 0-based).
    Returns
    -------
    dict
        For every merged key, the per-chain pieces concatenated in
        ascending order of ``i`` and zero-padded along dimension 0 up to
        the target number of residues.
    """
    num_res = batch["msa_mask"].shape[-1]
    outs = {}
    for k in labels[0]:
        if k == "resolution":
            continue
        cur_out = {}
        for i, j in align:
            label = labels[j][k]
            # to 1-based
            cur_residue_index = per_asym_residue_index[i + 1]
            cur_out[i] = label[cur_residue_index]
        # Concatenate the chains in ascending order of their output index.
        cur_out = [x[1] for x in sorted(cur_out.items())]
        new_v = torch.concat(cur_out, dim=0)
        merged_nres = new_v.shape[0]
        assert (
            merged_nres <= num_res
        ), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
        if merged_nres < num_res:  # must pad
            pad_dim = new_v.shape[1:]
            pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
            new_v = torch.concat((new_v, pad_v), dim=0)
        outs[k] = new_v
    return outs