"applications/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "057f8f470041364365090a93894b40296ee0bcb3"
Commit 20586e4e authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

check if the strucutre is a monomer first before applying multi-chain permutation

parent 2a1028f0
...@@ -2095,7 +2095,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2095,7 +2095,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask) per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains: if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
...@@ -2155,19 +2154,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2155,19 +2154,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
# first determin which dimension in the tensor to split into individual ground truth labels _is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict, if not _is_monomer:
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) # first determin which dimension in the tensor to split into individual ground truth labels
# reorder ground truth labels according to permutation results dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1]) # Then permutate ground truth chains before calculating the loss
features.update(labels) align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=True)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown) cum_loss = self.loss(out, features, _return_breakdown)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment