"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "46d558ed71d1e1ee157d7f68dc10653969e0f058"
Commit e097da95 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

check if it's a monomer first; add permutate_chains back to forward()

parent cec5a426
...@@ -2146,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2146,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align, per_asym_residue_index return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False): def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
""" """
Overwrite AlphaFoldLoss forward function so that Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation it first compute multi-chain permutation
...@@ -2155,19 +2155,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2155,19 +2155,23 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward() out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
# first determin which dimension in the tensor to split into individual ground truth labels # first check if it is a monomer
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features) is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer:
# Then permutate ground truth chains before calculating the loss permutate_chains = True
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict, # first determin which dimension in the tensor to split into individual ground truth labels
permutate_chains=permutate_chains) dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict, # Then permutate ground truth chains before calculating the loss
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
# reorder ground truth labels according to permutation results permutate_chains=permutate_chains)
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1]) labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
features.update(labels) REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown): if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown) cum_loss = self.loss(out, features, _return_breakdown)
...@@ -2176,4 +2180,4 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2176,4 +2180,4 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
else: else:
cum_loss, losses = self.loss(out, features, _return_breakdown) cum_loss, losses = self.loss(out, features, _return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}") print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses return cum_loss, losses
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment