Commit fe8869c3 authored by Christina Floristean's avatar Christina Floristean
Browse files

Bug fixes for multi-chain permutation alignment

parent 4f38c826
......@@ -111,6 +111,12 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
for entity_id in unique_entity_ids:
asym_ids = torch.unique(batch["asym_id"][batch["entity_id"] == entity_id])
# Make sure some asym IDs associated with ground truth entity ID exist in cropped prediction
asym_ids_in_pred = [a for a in asym_ids if a in input_asym_id]
if not asym_ids_in_pred:
continue
entity_asym_count[int(entity_id)] = len(asym_ids)
# Calculate entity length
......@@ -127,12 +133,14 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
# If still multiple entities, return a random one
if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities)
least_asym_entities = [random.choice(least_asym_entities)]
assert len(least_asym_entities) == 1
least_asym_entities = least_asym_entities[0]
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
anchor_pred_asym_ids = [asym_id for asym_id in entity_2_asym_list[least_asym_entities] if asym_id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment