Commit 6eb1afe7 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

finished constructing AlphaFoldMultimerLoss and other necessary changes in the losses calculations

parent 30d50a18
...@@ -737,7 +737,11 @@ def tm_loss( ...@@ -737,7 +737,11 @@ def tm_loss(
eps=1e-8, eps=1e-8,
**kwargs, **kwargs,
): ):
pred_affine = Rigid.from_tensor_7(final_affine_tensor) # first check whether this is a tensor_7 or tensor_4*4
if final_affine_tensor.shape[-1]==7:
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
elif final_affine_tensor.shape[-1]==4:
pred_affine = Rigid.from_tensor_4x4(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor) backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine): def _points(affine):
...@@ -1635,7 +1639,7 @@ def chain_center_of_mass_loss( ...@@ -1635,7 +1639,7 @@ def chain_center_of_mass_loss(
asym_id: torch.Tensor, asym_id: torch.Tensor,
clamp_distance: float = -4.0, clamp_distance: float = -4.0,
weight: float = 0.05, weight: float = 0.05,
eps: float = 1e-10 eps: float = 1e-10, **kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper. Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
...@@ -1662,9 +1666,9 @@ def chain_center_of_mass_loss( ...@@ -1662,9 +1666,9 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :] all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :] all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains, _ = asym_id.unique(return_counts=True) chains, _ = asym_id.unique(return_counts=True)
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) one_hot = torch.nn.functional.one_hot(asym_id.to(torch.int64),
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) # make sure asym_id dtype is int
one_hot = one_hot * all_atom_mask one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1) chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float() chain_exists = torch.any(chain_pos_mask, dim=-1).float()
...@@ -2012,8 +2016,12 @@ class AlphaFoldLoss(nn.Module): ...@@ -2012,8 +2016,12 @@ class AlphaFoldLoss(nn.Module):
return cum_loss, losses return cum_loss, losses
def forward(self, out, batch, _return_breakdown=False): def forward(self, out, batch, _return_breakdown=False):
cum_loss,losses = self.loss(out,batch,_return_breakdown) if(not _return_breakdown):
return cum_loss, losses cum_loss = self.loss(out,batch,_return_breakdown)
return cum_loss
else:
cum_loss,losses = self.loss(out,batch,_return_breakdown)
return cum_loss, losses
class AlphaFoldMultimerLoss(AlphaFoldLoss): class AlphaFoldMultimerLoss(AlphaFoldLoss):
""" """
...@@ -2120,11 +2128,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2120,11 +2128,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss # then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype') permutated_labels.pop('aatype')
logger.info("finished multi-chain permutation ")
features.update(permutated_labels) features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu')) move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features) features = tensor_tree_map(move_to_cpu,features)
self.loss(out,features) if (not _return_breakdown):
return permutated_labels cum_loss = self.loss(out,features,_return_breakdown)
## TODO next need to check how the ground truth label is used print(f"cum_loss: {cum_loss}")
# in loss calculation. return cum_loss
\ No newline at end of file else:
cum_loss,losses = self.loss(out,features,_return_breakdown)
print(f"cum_loss: {cum_loss} losses: {losses}")
return cum_loss, losses
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment