Commit d909c707 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

further cleaned up functions

parent faca088f
......@@ -2095,14 +2095,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
@staticmethod
def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_residue_idx,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos):
true_ca_masks,ca_idx,
out,
asym_mask):
pred_ca_mask = out["final_atom_mask"][..., ca_idx] # [bsz, nres]
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
anchor_gt_idx,asym_mask,
pred_ca_mask,anchor_residue_idx)
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
print(f"line 2109 is nan {torch.isnan(pred_ca_pos).any()} is inf : {torch.isinf(pred_ca_pos).any()}")
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
r, x = get_optimal_transform(
anchor_pred_pos, anchor_true_pos[0],
......@@ -2124,12 +2128,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
......@@ -2141,18 +2139,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
true_ca_poses = [l["all_atom_positions"][..., ca_idx, :] for l in labels] # list([nres, 3])
true_ca_masks = [l["all_atom_mask"][..., ca_idx].long() for l in labels] # list([nres,])
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_residue_idx,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos)
true_ca_masks,ca_idx,out,
asym_mask)
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
gc.collect()
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
pred_ca_mask = out["final_atom_mask"][..., ca_idx]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
print(f"line 2157 is nan {torch.isnan(pred_ca_pos).any()} is inf : is nan {torch.isnan(pred_ca_pos).any()} is nan {torch.isinf(pred_ca_pos).any()}")
align = greedy_align(
batch,
per_asym_residue_index,
......@@ -2165,7 +2168,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del aligned_true_ca_poses, true_ca_masks
del r, x
del pred_ca_pos, pred_ca_mask
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment