Commit 6eb1afe7 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

finished constructing AlphaFoldMultimerLoss and other necessary changes in the losses calculations

parent 30d50a18
......@@ -737,7 +737,11 @@ def tm_loss(
eps=1e-8,
**kwargs,
):
# first check whether this is a tensor_7 or tensor_4*4
if final_affine_tensor.shape[-1]==7:
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
elif final_affine_tensor.shape[-1]==4:
pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine):
......@@ -1635,7 +1639,7 @@ def chain_center_of_mass_loss(
asym_id: torch.Tensor,
clamp_distance: float = -4.0,
weight: float = 0.05,
eps: float = 1e-10
eps: float = 1e-10, **kwargs
) -> torch.Tensor:
"""
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
......@@ -1662,9 +1666,9 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64),
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float()
......@@ -2012,6 +2016,10 @@ class AlphaFoldLoss(nn.Module):
return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False):
if(not _return_breakdown):
cum_loss = self.loss(out,batch,_return_breakdown)
return cum_loss
else:
cum_loss,losses = self.loss(out,batch,_return_breakdown)
return cum_loss, losses
......@@ -2120,11 +2128,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype')
logger.info("finished multi-chain permutation ")
features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features)
self.loss(out,features)
return permutated_labels
## TODO next need to check how the ground truth label is used
# in loss calculation.
\ No newline at end of file
if (not _return_breakdown):
cum_loss = self.loss(out,features,_return_breakdown)
print(f"cum_loss: {cum_loss}")
return cum_loss
else:
cum_loss,losses = self.loss(out,features,_return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment