Commit 02ce77c5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed error when the selected anchor_asym is not in the cropped input features

parent 67f873e7
...@@ -1781,7 +1781,7 @@ def get_optimal_transform( ...@@ -1781,7 +1781,7 @@ def get_optimal_transform(
return r, x return r, x
def get_least_asym_entity_or_longest_length(batch): def get_least_asym_entity_or_longest_length(batch,input_asym_id):
""" """
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor one of the A as anchor
...@@ -1818,13 +1818,15 @@ def get_least_asym_entity_or_longest_length(batch): ...@@ -1818,13 +1818,15 @@ def get_least_asym_entity_or_longest_length(batch):
# # If there is more than one chain in the predicted output that has the same sequence # # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly pick one # # as the chosen ground truth anchor, then randomly pick one
if len(best_pred_asym) > 1: if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym) while best_pred_asym not in input_asym_id:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym return least_asym_entities[0], best_pred_asym
def greedy_align( def greedy_align(
batch, batch,
per_asym_residue_index,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
...@@ -1835,9 +1837,9 @@ def greedy_align( ...@@ -1835,9 +1837,9 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
""" """
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1) i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id asym_mask = batch["asym_id"] == cur_asym_id
...@@ -1845,28 +1847,25 @@ def greedy_align( ...@@ -1845,28 +1847,25 @@ def greedy_align(
best_rmsd = torch.inf best_rmsd = torch.inf
best_idx = None best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)] cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask] cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask] cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list: for next_asym_id in cur_asym_list:
j = int(next_asym_id - 1) j = int(next_asym_id - 1)
if not used[j]: # possible candidate if not used[j]: # possible candidate
cropped_pos = true_ca_poses[j] cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index)
cropped_pos = torch.squeeze(cropped_pos,0) mask = torch.index_select(true_ca_masks[j],1,cur_residue_index)
if cropped_pos.shape==cur_pred_pos.shape: rmsd = compute_rmsd(
mask = true_ca_masks[j] torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
mask = torch.squeeze(mask,0) (cur_pred_mask * mask).bool()
rmsd = compute_rmsd( )
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0), if (rmsd is not None) and (rmsd < best_rmsd):
(cur_pred_mask * mask).bool() best_rmsd = rmsd
) best_idx = j
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
align.append((i, best_idx)) align.append((i, best_idx))
return align return align
...@@ -1878,7 +1877,7 @@ def pad_features(feature_tensor,nres_pad,pad_dim): ...@@ -1878,7 +1877,7 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim) return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(labels, align,original_nres): def merge_labels(per_asym_residue_index,labels, align,original_nres):
""" """
Merge ground truth labels according to the permutation results Merge ground truth labels according to the permutation results
...@@ -1895,13 +1894,14 @@ def merge_labels(labels, align,original_nres): ...@@ -1895,13 +1894,14 @@ def merge_labels(labels, align,original_nres):
label = labels[j][k] label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1] cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k : if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue continue
else: else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0 dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
if k =='all_atom_positions': if k =='all_atom_positions':
dimension_to_merge=1 dimension_to_merge=1
cur_out[i] = label cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge) new_v = torch.concat(cur_out, dim=dimension_to_merge)
...@@ -2100,8 +2100,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2100,8 +2100,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
""" """
pred_ca_mask = torch.squeeze(pred_ca_mask,0) pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0) asym_mask = torch.squeeze(asym_mask,0)
print(f"##### line 2102 asym_mask is {asym_mask} and shape: {asym_mask.shape}")
anchor_pred_mask = pred_ca_mask[asym_mask] anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue) anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
print(f"##### line 2104 anchor_pred_mask:{anchor_pred_mask.shape} and anchor_true_mask : {anchor_true_mask.shape}")
input_mask = (anchor_true_mask * anchor_pred_mask).bool() input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask return input_mask
...@@ -2137,20 +2139,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2137,20 +2139,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
""" """
feature, ground_truth = batch feature, ground_truth = batch
print(f"###### line 2140 feature asym_id is :{feature['asym_id']}")
del batch del batch
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
print(f"successfully split ground truth labels")
if permutate_chains: if permutate_chains:
# First select anchors from predicted structures and ground truths # First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
print(f"###### anchor gt asym is: {anchor_gt_asym} and anchor pred asym is {anchor_pred_asym}") print(f"###### anchor_gt_asym:{anchor_gt_asym} and anchor_pred_asym: {anchor_pred_asym}")
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
del ground_truth
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
asym_mask = (feature["asym_id"] == anchor_pred_asym).bool() asym_mask = (feature["asym_id"] == anchor_pred_asym).bool()
print(f"###### asym_mask is {asym_mask}")
# Then calculate optimal transform by aligning anchors # Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3] pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
...@@ -2165,7 +2167,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2165,7 +2167,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature) per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
anchor_gt_residue = per_asym_residue_index[int(anchor_gt_asym)] anchor_gt_residue = per_asym_residue_index[int(anchor_gt_asym)]
print(f"######## per_asym_residue_index is {per_asym_residue_index}")
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses, r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue, anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask, true_ca_masks,pred_ca_mask,
...@@ -2175,12 +2176,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2175,12 +2176,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses,r,x del true_ca_poses,r,x
gc.collect() gc.collect()
print(f"$$$$$$$ successfully calculated r and x")
import sys
sys.exit()
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
align = greedy_align( align = greedy_align(
batch, feature,
per_asym_residue_index,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
...@@ -2191,10 +2189,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2191,10 +2189,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del true_ca_masks,aligned_true_ca_poses del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask del pred_ca_pos, pred_ca_mask
gc.collect() gc.collect()
print(f"finished permutation align. Align is {align}")
else: else:
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
align = list(enumerate(range(len(labels)))) align = list(enumerate(range(len(labels))))
return align return align, per_asym_residue_index
def forward(self, out, batch, _return_breakdown=False,permutate_chains=True): def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
""" """
...@@ -2209,20 +2213,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2209,20 +2213,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
features, ground_truth = batch features, ground_truth = batch
del batch del batch
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1] is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
print(f"asym_id is {features['asym_id']}")
if not is_monomer: if not is_monomer:
permutate_chains = True permutate_chains = True
# Then permutate ground truth chains before calculating the loss # Then permutate ground truth chains before calculating the loss
align = AlphaFoldMultimerLoss.multi_chain_perm_align(out, (features,ground_truth),permutate_chains=permutate_chains) align,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
(features,ground_truth),
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) REQUIRED_FEATURES=[i for i in ground_truth.keys()])
# reorder ground truth labels according to permutation results
labels = merge_labels(labels,align, # reorder ground truth labels according to permutation results
original_nres=features['aatype'].shape[-1]) labels = merge_labels(per_asym_residue_index,labels,align,
features.update(labels) original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown) cum_loss = self.loss(out, features, _return_breakdown)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment