"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "c46e97ca756ee4e549ee72c6aab84451b073eb62"
Commit fe01bb0c authored by Geoffrey Yu
Browse files

fixed the index error. Now working on updating greedy_align

parent 2184eff0
...@@ -1700,9 +1700,6 @@ def compute_rmsd( ...@@ -1700,9 +1700,6 @@ def compute_rmsd(
atom_mask: torch.Tensor = None, atom_mask: torch.Tensor = None,
eps: float = 1e-6, eps: float = 1e-6,
) -> torch.Tensor: ) -> torch.Tensor:
# shape check
true_atom_pos = true_atom_pos
pred_atom_pos = pred_atom_pos
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos del true_atom_pos
del pred_atom_pos del pred_atom_pos
...@@ -1860,19 +1857,28 @@ def greedy_align( ...@@ -1860,19 +1857,28 @@ def greedy_align(
for next_asym_id in cur_asym_list: for next_asym_id in cur_asym_list:
j = int(next_asym_id - 1) j = int(next_asym_id - 1)
if not used[j]: # possible candidate if not used[j]: # possible candidate
cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index) cropped_pos = true_ca_poses[j]
mask = torch.index_select(true_ca_masks[j],1,cur_residue_index) cropped_pos = torch.squeeze(cropped_pos,0)
rmsd = compute_rmsd( if not cropped_pos.shape==cur_pred_pos.shape:
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0), # this means selected candidte is not the correct one. Skip
(cur_pred_mask * mask).bool() used[j] = True
) else:
if (rmsd is not None) and (rmsd < best_rmsd): mask = true_ca_masks[j]
best_rmsd = rmsd mask = torch.squeeze(mask,0)
best_idx = j print(f"cropped_pos shape: {cropped_pos.shape} cur_pred_pos shape: {cur_pred_pos.shape}")
print(f"mask shape: {mask.shape} and cur_pred_mask shape: {cur_pred_mask.shape} ")
rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool()
)
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
align.append((i, best_idx)) align.append((i, best_idx))
return align return align
...@@ -2065,7 +2071,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2065,7 +2071,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return labels return labels
@staticmethod @staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True): def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=False):
""" """
A class method that first permutate chains in ground truth first A class method that first permutate chains in ground truth first
before calculating the loss. before calculating the loss.
...@@ -2084,15 +2090,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2084,15 +2090,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
l["all_atom_positions"][..., ca_idx, :] for l in labels l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3]) ] # list([nres, 3])
true_ca_masks = [ true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,]) ] # list([nres,])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask) per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains: if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}") print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
...@@ -2105,11 +2116,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2105,11 +2116,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx) anchor_true_pos = true_ca_poses[anchor_gt_idx]
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]] anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx) anchor_true_mask = true_ca_masks[anchor_gt_idx]
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]] anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
input_mask = (anchor_true_mask * anchor_pred_mask).bool() input_mask = (anchor_true_mask * anchor_pred_mask).bool()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment