Commit 54755901 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

finished codes for num_sym 1

parent f563944a
...@@ -118,14 +118,11 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2): ...@@ -118,14 +118,11 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
print(f"line 121 asym_mask is {asym_mask}")
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_anchor_candidates( anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
batch, per_asym_residue_index, true_ca_masks batch, per_asym_residue_index, true_ca_masks
) )
print(f"anchor_gt_asym is {anchor_gt_asym}, anchor_pred_asym is {anchor_pred_asym}") print(f"anchor_gt_asym is {anchor_gt_asym}, anchor_pred_asym is {anchor_pred_asym}")
import sys
sys.exit()
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e9 best_rmsd = 1e9
...@@ -141,27 +138,19 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2): ...@@ -141,27 +138,19 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
for cur_asym_id in anchor_pred_asym: for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)] anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
print(f"anchor_residue_idx:{anchor_residue_idx},anchor_gt_idx:{anchor_gt_idx}\n")
anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx[:,0]] # somehow need to only use the first column in anchor_residue_idx anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
asym_mask = asym_mask[:,0] # somehow need to adjust the asym_mask shape
anchor_pred_pos = pred_ca_pos[asym_mask] anchor_pred_pos = pred_ca_pos[asym_mask]
print(f"true_ca_masks:\n") anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
print(true_ca_masks[anchor_gt_idx].bool())
print(f"pred_ca_mask\n")
print(pred_ca_mask.bool())
anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx[:,0]]
anchor_pred_mask = pred_ca_mask[asym_mask] anchor_pred_mask = pred_ca_mask[asym_mask]
print(f"anchor_true_mask:\n")
print(anchor_true_mask.shape)
print(f"anchor_pred_mask:\n")
print(anchor_pred_mask.shape)
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_true_pos, anchor_true_pos,
anchor_pred_pos, anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(), (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
) )
print(f"finished getting optimal transform")
import sys
sys.exit()
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times): for _ in range(shuffle_times):
shuffle_idx = torch.randperm( shuffle_idx = torch.randperm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment