Commit be127915 authored by Geoffrey Yu

revert back to e097da95

parent 0d98466d
@@ -2079,10 +2079,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
         ca_idx = rc.atom_order["CA"]
         pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
         pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
-        pred_ca_pos = pred_ca_pos.detach().to('cpu')
-        pred_ca_mask = pred_ca_mask.detach().to('cpu')
-        print(f"@@@@@@@ line 2082 pred_ca_pos isinf: {torch.isinf(pred_ca_pos).any()} isnan: {torch.isnan(pred_ca_pos).any()}")
-        print(f"@@@@@@@@ line 2083 pred_ca_mask isinf: {torch.isinf(pred_ca_mask).any()} isnan: {torch.isnan(pred_ca_mask).any()}")
         true_ca_poses = [
             l["all_atom_positions"][..., ca_idx, :] for l in labels
         ] # list([nres, 3])
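
Aside: the four dropped lines above were ad-hoc NaN/Inf probes that also detached the CA tensors to the CPU inside the loss. If such a probe is ever needed again, a minimal on-device sketch of the same check (the helper name check_finite is hypothetical, not from this repository):

    import torch

    def check_finite(t: torch.Tensor, name: str) -> None:
        # torch.isfinite is False for both inf and nan, so a single
        # reduction covers the paired isinf/isnan prints removed above.
        if not torch.isfinite(t).all():
            print(f"{name}: isinf={torch.isinf(t).any().item()} "
                  f"isnan={torch.isnan(t).any().item()}")

    # e.g. check_finite(pred_ca_pos, "pred_ca_pos")

Keeping the check on the original device avoids copying the full tensors to host memory; the truth-value test still synchronizes, but it moves only a single scalar.
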
@@ -2098,7 +2095,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
             per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
         if permutate_chains:
             anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
-            torch.cuda.empty_cache()
             print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
             anchor_gt_idx = int(anchor_gt_asym) - 1
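
The anchor pair above comes from get_least_asym_entity_or_longest_length. Judging only from its name, it plausibly prefers the entity with the fewest symmetric chain copies and breaks ties by length, so the anchor alignment is as unambiguous as possible. A hedged sketch of that assumed selection rule (not the repository's implementation):

    import torch

    def pick_anchor_entity(batch: dict) -> torch.Tensor:
        # Assumed semantics: fewest chain copies first, longest entity as
        # tie-breaker; returns the candidate asym ids of the chosen entity.
        best_key, best_asym = None, None
        for ent_id in torch.unique(batch["entity_id"]):
            ent_mask = batch["entity_id"] == ent_id
            asym_ids = torch.unique(batch["asym_id"][ent_mask])
            key = (len(asym_ids), -int(ent_mask.sum()))
            if best_key is None or key < best_key:
                best_key, best_asym = key, asym_ids
        return best_asym
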
@@ -2109,46 +2105,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
                 cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
                 entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
             asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
             anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
             anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
-            anchor_true_pos = torch.squeeze(anchor_true_pos,0)
-            print(f"@@@@@@@ line 2115 pred_ca_pos isinf: {torch.isinf(pred_ca_pos).any()} asym_mask isnan: {torch.isnan(pred_ca_pos).any()}")
-            print(f"@@@@@@@ line 2116 pred_ca_mask isinf: {torch.isinf(pred_ca_mask).any()} pred_ca_mask isnan: {torch.isnan(pred_ca_mask).any()}")
-            pred_ca_pos = torch.squeeze(pred_ca_pos,0)
-            asym_mask = torch.squeeze(asym_mask,0)
-            asym_mask = asym_mask.detach().to('cpu')
-            pred_ca_mask = torch.squeeze(pred_ca_mask,0)
-            print(f"@@@@@@@ line 2120 pred_ca_pos isinf: {torch.isinf(pred_ca_pos).any()} asym_mask isnan: {torch.isnan(pred_ca_pos).any()}")
-            print(f"@@@@@@@ line 2121 pred_ca_mask isinf: {torch.isinf(pred_ca_mask).any()} pred_ca_mask isnan: {torch.isnan(pred_ca_mask).any()}")
-            anchor_pred_pos = pred_ca_pos[asym_mask]
+            anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
             anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
-            anchor_pred_mask = pred_ca_mask[asym_mask]
-            anchor_pred_mask = torch.unsqueeze(anchor_pred_mask,0)
-            anchor_pred_mask = anchor_pred_mask.to(anchor_true_mask.device)
-            pred_ca_pos = torch.unsqueeze(pred_ca_pos,0)
-            pred_ca_mask = torch.unsqueeze(pred_ca_mask,0)
+            anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
             input_mask = (anchor_true_mask * anchor_pred_mask).bool()
-            anchor_pred_pos = anchor_pred_pos.to(anchor_true_pos.device)
             r, x = get_optimal_transform(
-                anchor_pred_pos, anchor_true_pos,
-                mask=torch.squeeze(input_mask,0)
+                anchor_pred_pos, anchor_true_pos[0],
+                mask=input_mask[0]
             )
-            del input_mask # just to save memory
-            del anchor_pred_mask
-            del anchor_true_mask
-            gc.collect()
             aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
-            del true_ca_poses
-            gc.collect()
-            pred_ca_pos = pred_ca_pos.to(anchor_true_mask.device)
-            pred_ca_mask = pred_ca_mask.to(anchor_true_mask.device)
-            del anchor_true_mask
-            gc.collect()
-            torch.cuda.empty_cache()
             align = greedy_align(
                 batch,
                 per_asym_residue_index,
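
The surviving lines fit a rigid transform between the anchor chain's predicted and true CA atoms and then apply the same (r, x) to every ground-truth chain, so all candidate chain assignments are scored in a common frame. Assuming get_optimal_transform follows the usual masked Kabsch convention (returning r, x with src @ r + x approximating tgt; the repository's mapping direction may differ), a self-contained sketch:

    import torch

    def optimal_transform(src: torch.Tensor, tgt: torch.Tensor, mask: torch.Tensor):
        # src, tgt: [n, 3] CA coordinates; mask: [n] bool selecting residues
        # resolved in both structures.
        src, tgt = src[mask], tgt[mask]
        src_c = src - src.mean(dim=0)
        tgt_c = tgt - tgt.mean(dim=0)
        u, _, vt = torch.linalg.svd(src_c.T @ tgt_c)  # SVD of the 3x3 covariance
        if torch.linalg.det(u @ vt) < 0:
            # Flip the weakest singular direction to avoid a reflection.
            u = torch.cat([u[:, :-1], -u[:, -1:]], dim=-1)
        r = u @ vt
        x = tgt.mean(dim=0) - src.mean(dim=0) @ r
        return r, x

The aligned_true_ca_poses comprehension above then applies the one (r, x) fitted on the anchor to every ground-truth chain before the greedy chain matching.
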
@@ -2165,7 +2140,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
             del pred_ca_pos, pred_ca_mask
             del anchor_pred_pos, anchor_true_pos
             gc.collect()
-            torch.cuda.empty_cache()
             print(f"finished multi-chain permutation and final align is {align}")
         else:
             align = list(enumerate(range(len(labels))))
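
When chain permutation is disabled, the else branch simply pairs chains in their original order. A tiny illustration of the shape of align (the pair semantics of greedy_align's output is assumed to match this identity fallback; the label values are stand-ins):

    labels = ["chain_A", "chain_B", "chain_C"]    # illustrative stand-ins
    align = list(enumerate(range(len(labels))))   # [(0, 0), (1, 1), (2, 2)]
    permuted = [labels[j] for _, j in align]      # identity permutation here
    assert permuted == labels
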