"runtime/rust/vscode:/vscode.git/clone" did not exist on "ccd153afb2b530d01945857095bfc7c5dc9ef1f0"
Commit 20586e4e authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

check if the strucutre is a monomer first before applying multi-chain permutation

parent 2a1028f0
...@@ -2095,7 +2095,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2095,7 +2095,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask) per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains: if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
...@@ -2155,12 +2154,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2155,12 +2154,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
_is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not _is_monomer:
# first determin which dimension in the tensor to split into individual ground truth labels # first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features) dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss # Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict, align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains) permutate_chains=True)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment