Commit 77cb4135 authored by Geoffrey Yu's avatar Geoffrey Yu Committed by Jennifer Wei
Browse files

restore to the verison on main

parent d1a32aa3
...@@ -105,12 +105,13 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -105,12 +105,13 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
""" """
entity_2_asym_list = get_entity_2_asym_list(batch) entity_2_asym_list = get_entity_2_asym_list(batch)
unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {} entity_asym_count = {}
entity_length = {} entity_length = {}
for entity_id in unique_entity_ids: for entity_id in unique_entity_ids:
asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id]) asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
# Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction # Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction
asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id] asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id]
if not asym_ids_in_pred: if not asym_ids_in_pred:
...@@ -121,8 +122,10 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -121,8 +122,10 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
# Calculate entity length # Calculate entity length
entity_mask = (batch["entity_id"] == entity_id) entity_mask = (batch["entity_id"] == entity_id)
entity_length[int(entity_id)] = entity_mask.sum().item() entity_length[int(entity_id)] = entity_mask.sum().item()
min_asym_count = min(entity_asym_count.values()) min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the longest length # If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities]) max_length = max([entity_length[entity] for entity in least_asym_entities])
...@@ -137,6 +140,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -137,6 +140,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities]) anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id] anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids return anchor_gt_asym_id, anchor_pred_asym_ids
...@@ -156,7 +160,6 @@ def greedy_align( ...@@ -156,7 +160,6 @@ def greedy_align(
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0] unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1) i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id asym_mask = batch["asym_id"] == cur_asym_id
...@@ -346,7 +349,6 @@ def compute_permutation_alignment(out, features, ground_truth): ...@@ -346,7 +349,6 @@ def compute_permutation_alignment(out, features, ground_truth):
# First select anchors from predicted structures and ground truths # First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth, anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,
features['asym_id']) features['asym_id'])
entity_2_asym_list = get_entity_2_asym_list(ground_truth) entity_2_asym_list = get_entity_2_asym_list(ground_truth)
labels = split_ground_truth_labels(ground_truth) labels = split_ground_truth_labels(ground_truth)
assert isinstance(labels, list) assert isinstance(labels, list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment