"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "3a41e8304e1ec5ff1688d1967fea5376581c5a5c"
Commit fd748a0d authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update loss to accomodate new input data pipeline

parent 02ce77c5
......@@ -1813,14 +1813,17 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one
if len(best_pred_asym) > 1:
while best_pred_asym not in input_asym_id:
best_pred_asym = random.choice(best_pred_asym)
# # # If there is more than one chain in the predicted output that has the same sequence
# # # as the chosen ground truth anchor, then randomly picke one
# if len(best_pred_asym) > 1:
# selected_best_pred_asym = random.choice(best_pred_asym)
# while selected_best_pred_asym not in input_asym_id:
# selected_best_pred_asym = random.choice(best_pred_asym)
# else:
# selected_best_pred_asym = best_pred_asym
best_pred_asym = least_asym_entities[0]
return least_asym_entities[0], best_pred_asym
......@@ -2100,10 +2103,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0)
print(f"##### line 2102 asym_mask is {asym_mask} and shape: {asym_mask.shape}")
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
print(f"##### line 2104 anchor_pred_mask:{anchor_pred_mask.shape} and anchor_true_mask : {anchor_true_mask.shape}")
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
......@@ -2139,12 +2140,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
feature, ground_truth = batch
print(f"###### line 2140 feature asym_id is :{feature['asym_id']}")
del batch
if permutate_chains:
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
print(f"###### anchor_gt_asym:{anchor_gt_asym} and anchor_pred_asym: {anchor_pred_asym}")
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
......@@ -2189,7 +2188,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask
gc.collect()
print(f"finished permutation align. Align is {align}")
else:
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment