Commit e097da95 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

check if it's a monomer first; add permutate_chains back to forward()

parent cec5a426
......@@ -2146,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False):
def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
......@@ -2155,6 +2155,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
"""
# first check if it is a monomer
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer:
permutate_chains = True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment