Commit a0f8a057 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

Now use the new way of selecting anchors

parent 54755901
......@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label
"""
self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data")
self.label_ids = ['label_1','label_2']
self.label_ids = ['label_1','label_2','label_2']
def test_dry_run(self):
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res
n_res = consts.n_res +13
n_extra_seq = consts.n_extra
c = model_config(consts.model, train=True)
......@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains
# #
asym_id = [1]*9+[2]*13
asym_id = [1]*9+[2]*13+[3]*13
batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch['entity_id'] = torch.tensor(asym_id,dtype=torch.float64)
batch['entity_id'] = torch.tensor([1]*9+[2]*26,dtype=torch.float64)
batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64)
batch["num_sym"] = torch.tensor([2]*22,dtype=torch.int64) # currently there are just 2 chains
batch["num_sym"] = torch.tensor([1]*9+[2]*26,dtype=torch.int64) # currently there are just 2 chains
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
......
......@@ -3,6 +3,7 @@ from openfold.np import residue_constants as rc
import logging
logger = logging.getLogger(__name__)
import sys
import random
def kabsch_rotation(P, Q):
"""
......@@ -119,10 +120,13 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
batch, per_asym_residue_index, true_ca_masks
)
print(f"anchor_gt_asym is {anchor_gt_asym}, anchor_pred_asym is {anchor_pred_asym}")
# anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
# batch, per_asym_residue_index, true_ca_masks
# )
anchor_gt_asym, anchor_pred_asym=get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is {anchor_gt_asym}")
import sys
sys.exit()
anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e9
......@@ -134,7 +138,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
print(f"entity_2_asym_list is {entity_2_asym_list}")
for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
......@@ -148,9 +152,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
)
print(f"finished getting optimal transform")
import sys
sys.exit()
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times):
shuffle_idx = torch.randperm(
......@@ -167,57 +169,61 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :] @ r + x,
merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
pred_ca_pos,
(pred_ca_mask * merged_labels["all_atom_mask"][..., ca_idx]).bool(),
(pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_labels = merged_labels
print(f"finished kabsh_rmsd")
return best_labels
def get_anchor_candidates(batch, per_asym_residue_index, true_masks):
    """
    Choose the ground-truth anchor chain and its candidate predicted chains.

    Chains are examined in order of increasing positive ``num_sym``; within
    one ``num_sym`` group the chain whose indexed mask is longest wins (the
    first one on ties). The search stops as soon as the best length reaches 3.

    Args:
        batch: mapping with "asym_id", "entity_id" and "num_sym" tensors that
            are indexed with one another's boolean masks.
        per_asym_residue_index: mapping from ``int`` asym id to the index
            tensor used to select that chain's entries in its mask.
        true_masks: sequence of per-chain masks; the chain with asym id ``a``
            is looked up at position ``a - 1``.

    Returns:
        ``(best_gt_asym, best_pred_asym)``: the selected asym id (an element
        of ``batch["asym_id"]``) and the unique asym ids sharing its entity.

    Raises:
        ValueError: if no element of ``batch["num_sym"]`` is positive.
    """
    asym_id = batch["asym_id"]
    num_sym = batch["num_sym"]

    def find_by_num_sym(min_num_sym):
        """Return ``(asym_id, length)`` of the longest chain with this num_sym."""
        group_best_len = -1
        group_best_asym = None
        for cur_asym_id in torch.unique(asym_id[num_sym == min_num_sym]):
            assert cur_asym_id > 0, "asym_id 0 must not have a positive num_sym"
            cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
            cur_true_mask = true_masks[int(cur_asym_id - 1)][cur_residue_index]
            # NOTE(review): this is the number of indexed positions, not the
            # number of set mask entries -- confirm that is what is intended.
            cur_len = cur_true_mask.shape[0]
            if cur_len > group_best_len:
                group_best_len = cur_len
                group_best_asym = cur_asym_id
        return group_best_asym, group_best_len

    best_gt_asym = None
    best_len = -1
    # torch.unique returns sorted values, so each distinct positive num_sym is
    # visited once in ascending order instead of once per residue.
    for cur_num_sym in torch.unique(num_sym[num_sym > 0]):
        cur_gt_asym, cur_len = find_by_num_sym(cur_num_sym)
        if cur_len > best_len:
            best_len = cur_len
            best_gt_asym = cur_gt_asym
        if best_len >= 3:
            break

    if best_gt_asym is None:
        raise ValueError("batch['num_sym'] contains no positive entries")

    best_entity = batch["entity_id"][asym_id == best_gt_asym][0]
    best_pred_asym = torch.unique(asym_id[batch["entity_id"] == best_entity])
    return best_gt_asym, best_pred_asym
def get_least_asym_entity_or_longest_length(batch):
    """
    Pick the anchor entity for multi-chain permutation alignment.

    The entity with the fewest distinct ``asym_id`` values (copies) is
    preferred, e.g. for AABBB one of the A is used. If several entities tie
    on the copy count, e.g. AABB, the one covering the most elements wins;
    if that is still a tie, one of them is drawn at random.

    Args:
        batch: mapping with "entity_id" and "asym_id" tensors that are
            aligned element-wise (presumably one value per residue).

    Returns:
        A tuple ``(entity, best_pred_asym)``: the chosen entity id as an
        ``int`` and a tensor holding the unique ``asym_id`` values of that
        entity.

    Raises:
        ValueError: if ``batch["entity_id"]`` is empty.
    """
    entity_ids = batch["entity_id"]
    asym_ids = batch["asym_id"]

    # Per entity: number of distinct chains and total number of elements.
    # NOTE(review): id 0 is not skipped here although other code in this file
    # skips asym id 0 -- confirm whether padded inputs can reach this point.
    entity_asym_count = {}
    entity_length = {}
    for entity_id in torch.unique(entity_ids):
        entity_mask = entity_ids == entity_id
        key = int(entity_id)
        entity_asym_count[key] = len(torch.unique(asym_ids[entity_mask]))
        entity_length[key] = int(entity_mask.sum().item())

    if not entity_asym_count:
        raise ValueError("batch['entity_id'] is empty; cannot select an anchor")

    min_asym_count = min(entity_asym_count.values())
    candidates = [
        entity
        for entity, count in entity_asym_count.items()
        if count == min_asym_count
    ]

    # Several entities share the smallest chain count: keep the longest ones.
    if len(candidates) > 1:
        max_length = max(entity_length[entity] for entity in candidates)
        candidates = [
            entity for entity in candidates if entity_length[entity] == max_length
        ]

    # Still tied: pick one at random. The previous code overwrote the list
    # with the bare int returned by random.choice and then crashed on len().
    if len(candidates) == 1:
        anchor_entity = candidates[0]
    else:
        anchor_entity = random.choice(candidates)

    logging.getLogger(__name__).debug(
        "anchor entity is %s; entity lengths are %s", anchor_entity, entity_length
    )

    # NOTE(review): the first return value is an entity id; the visible caller
    # converts it with ``int(anchor_gt_asym) - 1`` as if it were an asym id --
    # confirm which one is intended.
    best_pred_asym = torch.unique(asym_ids[entity_ids == anchor_entity])
    return anchor_entity, best_pred_asym
def greedy_align(
......@@ -251,27 +257,21 @@ def greedy_align(
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask[:,0],:] # only need the first 1 column of asym_mask
print(f"line 266 pred_ca_pos shape: {pred_ca_pos.shape} cur_pred_pos shape: {cur_pred_pos.shape} and pred_ca_pos is {pred_ca_pos.shape}")
cur_pred_mask = pred_ca_mask[asym_mask[:,0]]# only need the first column of asym_mask
cur_pred_pos = pred_ca_pos[asym_mask] # only need the first 1 column of asym_mask
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
if next_asym_id == 0:
continue
j = int(next_asym_id - 1)
if not used[j]: # possible candidate
print(f"line 265 curr_residue_index is {cur_residue_index} and j is {j}")
print(f"true_ca_poses shape: {true_ca_poses[j].shape}")
cropped_pos = true_ca_poses[j]
mask = true_ca_masks[j][cur_residue_index[:,0]]
print(f"line 278 cur_pred_mask shape: {cur_pred_mask.shape}\n mask shape: {mask.shape}")
print(f"cropped_pos shape {cropped_pos.shape} and cur_pred_pos shape {cur_pred_pos.shape}")
mask = true_ca_masks[j][cur_residue_index]
rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask.to('cpu') * mask.to('cpu')).bool()
)
if rmsd < best_rmsd:
best_rmsd = rmsd
best_idx = j
print(f"rmds is now {rmsd} and best_idx is {best_idx}")
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment