Commit a0f8a057 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

now uses the new way of selecting anchors

parent 54755901
...@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase): ...@@ -40,12 +40,12 @@ class TestPermutation(unittest.TestCase):
In the test case, use PDB ID 1e4k as the label In the test case, use PDB ID 1e4k as the label
""" """
self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data") self.test_data_dir = os.path.join(os.getcwd(),"tests/test_data")
self.label_ids = ['label_1','label_2'] self.label_ids = ['label_1','label_2','label_2']
def test_dry_run(self): def test_dry_run(self):
n_seq = consts.n_seq n_seq = consts.n_seq
n_templ = consts.n_templ n_templ = consts.n_templ
n_res = consts.n_res n_res = consts.n_res +13
n_extra_seq = consts.n_extra n_extra_seq = consts.n_extra
c = model_config(consts.model, train=True) c = model_config(consts.model, train=True)
...@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase): ...@@ -83,12 +83,12 @@ class TestPermutation(unittest.TestCase):
# Modify asym_id, entity_id and sym_id so that it encodes # Modify asym_id, entity_id and sym_id so that it encodes
# 2 chains # 2 chains
# # # #
asym_id = [1]*9+[2]*13 asym_id = [1]*9+[2]*13+[3]*13
batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64) batch["asym_id"] = torch.tensor(asym_id,dtype=torch.float64)
# batch["entity_id"] = torch.randint(0, 1, size=(n_res,)) # batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch['entity_id'] = torch.tensor(asym_id,dtype=torch.float64) batch['entity_id'] = torch.tensor([1]*9+[2]*26,dtype=torch.float64)
batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64) batch["sym_id"] = torch.tensor(asym_id,dtype=torch.float64)
batch["num_sym"] = torch.tensor([2]*22,dtype=torch.int64) # currently there are just 2 chains batch["num_sym"] = torch.tensor([1]*9+[2]*26,dtype=torch.int64) # currently there are just 2 chains
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res)) batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: ( add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters) t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
......
...@@ -3,6 +3,7 @@ from openfold.np import residue_constants as rc ...@@ -3,6 +3,7 @@ from openfold.np import residue_constants as rc
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import sys import sys
import random
def kabsch_rotation(P, Q): def kabsch_rotation(P, Q):
""" """
...@@ -119,10 +120,13 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2): ...@@ -119,10 +120,13 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask] per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_anchor_candidates( # anchor_gt_asym, anchor_pred_asym = get_anchor_candidates(
batch, per_asym_residue_index, true_ca_masks # batch, per_asym_residue_index, true_ca_masks
) # )
print(f"anchor_gt_asym is {anchor_gt_asym}, anchor_pred_asym is {anchor_pred_asym}") anchor_gt_asym, anchor_pred_asym=get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is {anchor_gt_asym}")
import sys
sys.exit()
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
best_rmsd = 1e9 best_rmsd = 1e9
...@@ -134,7 +138,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2): ...@@ -134,7 +138,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
ent_mask = batch["entity_id"] == cur_ent_id ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask]) cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
print(f"entity_2_asym_list is {entity_2_asym_list}")
for cur_asym_id in anchor_pred_asym: for cur_asym_id in anchor_pred_asym:
asym_mask = (batch["asym_id"] == cur_asym_id).bool() asym_mask = (batch["asym_id"] == cur_asym_id).bool()
anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)] anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
...@@ -148,9 +152,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2): ...@@ -148,9 +152,7 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
anchor_pred_pos, anchor_pred_pos,
(anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(), (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
) )
print(f"finished getting optimal transform")
import sys
sys.exit()
aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses] # apply transforms
for _ in range(shuffle_times): for _ in range(shuffle_times):
shuffle_idx = torch.randperm( shuffle_idx = torch.randperm(
...@@ -167,57 +169,61 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2): ...@@ -167,57 +169,61 @@ def multi_chain_perm_align(out, batch, labels, shuffle_times=2):
aligned_true_ca_poses, aligned_true_ca_poses,
true_ca_masks, true_ca_masks,
) )
merged_labels = merge_labels( merged_labels = merge_labels(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
labels, labels,
align, align,
) )
rmsd = kabsch_rmsd( rmsd = kabsch_rmsd(
merged_labels["all_atom_positions"][..., ca_idx, :] @ r + x, merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
pred_ca_pos, pred_ca_pos,
(pred_ca_mask * merged_labels["all_atom_mask"][..., ca_idx]).bool(), (pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
) )
if rmsd < best_rmsd: if rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
best_labels = merged_labels best_labels = merged_labels
print(f"finished kabsh_rmsd")
return best_labels return best_labels
def get_anchor_candidates(batch, per_asym_residue_index, true_masks): def get_least_asym_entity_or_longest_length(batch):
def find_by_num_sym(min_num_sym): """
best_len = -1 First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
best_gt_asym = None one of the A as anchor
asym_ids = batch["asym_id"][batch["num_sym"] == min_num_sym]
asym_ids = torch.unique(batch["asym_id"][batch["num_sym"] == min_num_sym]) If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
for cur_asym_id in asym_ids: then choose one of the corresponding subunits as anchor
assert cur_asym_id > 0 """
cur_residue_index = per_asym_residue_index[int(cur_asym_id)] unique_entity_ids = torch.unique(batch["entity_id"])
j = int(cur_asym_id - 1) entity_asym_count = {}
cur_true_mask = true_masks[j][cur_residue_index] entity_length = {}
cur_len = cur_true_mask.shape[0]
if cur_len > best_len: for entity_id in unique_entity_ids:
best_len = cur_len asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
best_gt_asym = cur_asym_id entity_asym_count[int(entity_id)] = len(asym_ids)
return best_gt_asym, best_len
sorted_num_sym = batch["num_sym"][batch["num_sym"] > 0].sort()[0] # Calculate entity length
best_gt_asym = None entity_mask = (batch["entity_id"] == entity_id)
best_len = -1 entity_length[int(entity_id)] = entity_mask.sum().item()
for cur_num_sym in sorted_num_sym:
if cur_num_sym <= 0: min_asym_count = min(entity_asym_count.values())
continue least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
cur_gt_sym, cur_len = find_by_num_sym(cur_num_sym)
if cur_len > best_len: # If multiple entities have the least asym_id count, return those with the shortest length
best_len = cur_len if len(least_asym_entities) > 1:
best_gt_asym = cur_gt_sym max_length = max([entity_length[entity] for entity in least_asym_entities])
if best_len >= 3: least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
break
best_entity = batch["entity_id"][batch["asym_id"] == best_gt_asym][0] # If still multiple entities, return a random one
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == best_entity]) if len(least_asym_entities) > 1:
return best_gt_asym, best_pred_asym least_asym_entities = random.choice(least_asym_entities)
print(f"line 249 least_asym_entities is {least_asym_entities} and entity_length is {entity_length}")
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
return least_asym_entities[0], best_pred_asym
def greedy_align( def greedy_align(
...@@ -251,27 +257,21 @@ def greedy_align( ...@@ -251,27 +257,21 @@ def greedy_align(
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)] cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)] cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask[:,0],:] # only need the first 1 column of asym_mask cur_pred_pos = pred_ca_pos[asym_mask] # only need the first 1 column of asym_mask
print(f"line 266 pred_ca_pos shape: {pred_ca_pos.shape} cur_pred_pos shape: {cur_pred_pos.shape} and pred_ca_pos is {pred_ca_pos.shape}") cur_pred_mask = pred_ca_mask[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask[:,0]]# only need the first column of asym_mask
for next_asym_id in cur_asym_list: for next_asym_id in cur_asym_list:
if next_asym_id == 0: if next_asym_id == 0:
continue continue
j = int(next_asym_id - 1) j = int(next_asym_id - 1)
if not used[j]: # possible candidate if not used[j]: # possible candidate
print(f"line 265 curr_residue_index is {cur_residue_index} and j is {j}")
print(f"true_ca_poses shape: {true_ca_poses[j].shape}")
cropped_pos = true_ca_poses[j] cropped_pos = true_ca_poses[j]
mask = true_ca_masks[j][cur_residue_index[:,0]] mask = true_ca_masks[j][cur_residue_index]
print(f"line 278 cur_pred_mask shape: {cur_pred_mask.shape}\n mask shape: {mask.shape}")
print(f"cropped_pos shape {cropped_pos.shape} and cur_pred_pos shape {cur_pred_pos.shape}")
rmsd = compute_rmsd( rmsd = compute_rmsd(
cropped_pos, cur_pred_pos, (cur_pred_mask.to('cpu') * mask.to('cpu')).bool() cropped_pos, cur_pred_pos, (cur_pred_mask.to('cpu') * mask.to('cpu')).bool()
) )
if rmsd < best_rmsd: if rmsd < best_rmsd:
best_rmsd = rmsd best_rmsd = rmsd
best_idx = j best_idx = j
print(f"rmds is now {rmsd} and best_idx is {best_idx}")
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
align.append((i, best_idx)) align.append((i, best_idx))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment