Commit 51f3325f authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update anchor selection function logic

parent 986d1f67
......@@ -1789,6 +1789,7 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
"""
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {}
entity_length = {}
......@@ -1813,18 +1814,12 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
# best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# # # If there is more than one chain in the predicted output that has the same sequence
# # # as the chosen ground truth anchor, then randomly picke one
# if len(best_pred_asym) > 1:
# selected_best_pred_asym = random.choice(best_pred_asym)
# while selected_best_pred_asym not in input_asym_id:
# selected_best_pred_asym = random.choice(best_pred_asym)
# else:
# selected_best_pred_asym = best_pred_asym
best_pred_asym = least_asym_entities[0]
return least_asym_entities[0], best_pred_asym
least_asym_entities = least_asym_entities[0]
anchor_gt_asym_id = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
best_pred_asym = anchor_gt_asym_id[0]
anchor_gt_asym_id = anchor_gt_asym_id[0]
return anchor_gt_asym_id, best_pred_asym
def greedy_align(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment