Commit 02ce77c5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

Fixed error when the selected anchor_asym is not in the cropped input features

parent 67f873e7
......@@ -1781,7 +1781,7 @@ def get_optimal_transform(
return r, x
def get_least_asym_entity_or_longest_length(batch):
def get_least_asym_entity_or_longest_length(batch,input_asym_id):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
......@@ -1818,13 +1818,15 @@ def get_least_asym_entity_or_longest_length(batch):
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly pick one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
while best_pred_asym not in input_asym_id:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym
def greedy_align(
batch,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
......@@ -1835,9 +1837,9 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
used = [False for _ in range(len(true_ca_poses))]
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
......@@ -1845,28 +1847,25 @@ def greedy_align(
best_rmsd = torch.inf
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
j = int(next_asym_id - 1)
if not used[j]: # possible candidate
cropped_pos = true_ca_poses[j]
cropped_pos = torch.squeeze(cropped_pos,0)
if cropped_pos.shape==cur_pred_pos.shape:
mask = true_ca_masks[j]
mask = torch.squeeze(mask,0)
rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool()
)
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
cropped_pos = torch.index_select(true_ca_poses[j],1,cur_residue_index)
mask = torch.index_select(true_ca_masks[j],1,cur_residue_index)
rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool()
)
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
return align
......@@ -1878,7 +1877,7 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(labels, align,original_nres):
def merge_labels(per_asym_residue_index,labels, align,original_nres):
"""
Merge ground truth labels according to the permutation results
......@@ -1895,13 +1894,14 @@ def merge_labels(labels, align,original_nres):
label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue
else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
if k =='all_atom_positions':
dimension_to_merge=1
cur_out[i] = label
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
......@@ -2100,8 +2100,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0)
print(f"##### line 2102 asym_mask is {asym_mask} and shape: {asym_mask.shape}")
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
print(f"##### line 2104 anchor_pred_mask:{anchor_pred_mask.shape} and anchor_true_mask : {anchor_true_mask.shape}")
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
......@@ -2137,20 +2139,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
feature, ground_truth = batch
print(f"###### line 2140 feature asym_id is :{feature['asym_id']}")
del batch
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
print(f"successfully split ground truth labels")
if permutate_chains:
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth)
print(f"###### anchor gt asym is: {anchor_gt_asym} and anchor pred asym is {anchor_pred_asym}")
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
print(f"###### anchor_gt_asym:{anchor_gt_asym} and anchor_pred_asym: {anchor_pred_asym}")
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
del ground_truth
anchor_gt_idx = int(anchor_gt_asym) - 1
asym_mask = (feature["asym_id"] == anchor_pred_asym).bool()
print(f"###### asym_mask is {asym_mask}")
# Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
......@@ -2165,7 +2167,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
anchor_gt_residue = per_asym_residue_index[int(anchor_gt_asym)]
print(f"######## per_asym_residue_index is {per_asym_residue_index}")
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask,
......@@ -2175,12 +2176,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses,r,x
gc.collect()
print(f"$$$$$$$ successfully calculated r and x")
import sys
sys.exit()
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
align = greedy_align(
batch,
feature,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
......@@ -2191,10 +2189,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask
gc.collect()
print(f"finished permutation align. Align is {align}")
else:
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
align = list(enumerate(range(len(labels))))
return align
return align, per_asym_residue_index
def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
"""
......@@ -2209,20 +2213,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
features, ground_truth = batch
del batch
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
print(f"asym_id is {features['asym_id']}")
if not is_monomer:
permutate_chains = True
# Then permutate ground truth chains before calculating the loss
align = AlphaFoldMultimerLoss.multi_chain_perm_align(out, (features,ground_truth),permutate_chains=permutate_chains)
# Then permutate ground truth chains before calculating the loss
align,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
(features,ground_truth),
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=[i for i in ground_truth.keys()])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment