Commit 39830684 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

start working on multimer loss

parent 56d5e39c
......@@ -34,7 +34,8 @@ from openfold.utils.tensor_utils import (
permute_final_dims,
batched_gather,
)
import logging
logger = logging.getLogger(__name__)
def softmax_cross_entropy(logits, labels):
loss = -1 * torch.sum(
......@@ -1675,7 +1676,11 @@ class AlphaFoldLoss(nn.Module):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch, _return_breakdown=False):
def loss(self, out, batch, _return_breakdown=False):
"""
Rename previous forward() as loss
so that can be reused in the subclass
"""
if "violation" not in out.keys():
out["violation"] = find_structural_violations(
batch,
......@@ -1766,3 +1771,23 @@ class AlphaFoldLoss(nn.Module):
return cum_loss
return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False):
cum_loss,losses = self.loss(out,batch,_return_breakdown)
return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
Add multi-chain permutation on top of
AlphaFoldLoss
"""
def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__()
self.config = config
def forward(self,out,batch,_return_breakdown=False):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
"""
logger.info(f"out is {type(out)} and batch is {type(batch)}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment