Commit 2669287e authored by Geoffrey Yu's avatar Geoffrey Yu Committed by Jennifer Wei
Browse files

restore to the verison on main

parent b5427018
......@@ -105,12 +105,13 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
entity_2_asym_list = get_entity_2_asym_list(batch)
unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding
unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {}
entity_length = {}
for entity_id in unique_entity_ids:
asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
# Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction
asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id]
if not asym_ids_in_pred:
......@@ -121,8 +122,10 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
# Calculate entity length
entity_mask = (batch["entity_id"] == entity_id)
entity_length[int(entity_id)] = entity_mask.sum().item()
min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities])
......@@ -137,6 +140,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids
......@@ -156,7 +160,6 @@ def greedy_align(
used = [False for _ in range(len(true_ca_poses))]
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
......@@ -346,7 +349,6 @@ def compute_permutation_alignment(out, features, ground_truth):
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,
features['asym_id'])
entity_2_asym_list = get_entity_2_asym_list(ground_truth)
labels = split_ground_truth_labels(ground_truth)
assert isinstance(labels, list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment