Commit e097da95 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

check if it's a monomer first; add permutate_chains back to forward()

parent cec5a426
...@@ -2146,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2146,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align, per_asym_residue_index return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False): def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
""" """
Overwrite AlphaFoldLoss forward function so that Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation it first compute multi-chain permutation
...@@ -2155,6 +2155,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2155,6 +2155,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
# first check if it is a monomer
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer:
permutate_chains = True
# first determin which dimension in the tensor to split into individual ground truth labels # first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features) dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment