Commit bbf42cc5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed the fape and backbone loss errors

parent b22bd4e3
...@@ -185,7 +185,13 @@ def backbone_loss( ...@@ -185,7 +185,13 @@ def backbone_loss(
eps: float = 1e-4, eps: float = 1e-4,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if traj.shape[-1]==7:
pred_aff = Rigid.from_tensor_7(traj) pred_aff = Rigid.from_tensor_7(traj)
elif traj.shape[-1]==4:
pred_aff = Rigid.from_tensor_4x4(traj)
pred_aff = Rigid( pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None), Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(), pred_aff.get_trans(),
...@@ -304,10 +310,10 @@ def fape_loss( ...@@ -304,10 +310,10 @@ def fape_loss(
interface_bb_loss = backbone_loss( interface_bb_loss = backbone_loss(
traj=traj, traj=traj,
pair_mask=1. - intra_chain_mask, pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface_backbone}, **{**batch, **config.intra_chain_backbone},
) )
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface_backbone.weight) + interface_bb_loss * config.intra_chain_backbone.weight)
else: else:
bb_loss = backbone_loss( bb_loss = backbone_loss(
traj=traj, traj=traj,
...@@ -1865,7 +1871,6 @@ def greedy_align( ...@@ -1865,7 +1871,6 @@ def greedy_align(
if (rmsd is not None) and (rmsd < best_rmsd): if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd best_rmsd = rmsd
best_idx = j best_idx = j
print(f"now best_idx is {best_idx} and rmsd is {rmsd} and j is {j}")
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
align.append((i, best_idx)) align.append((i, best_idx))
...@@ -1920,7 +1925,7 @@ class AlphaFoldLoss(nn.Module): ...@@ -1920,7 +1925,7 @@ class AlphaFoldLoss(nn.Module):
out["violation"] = find_structural_violations( out["violation"] = find_structural_violations(
batch, batch,
out["sm"]["positions"][-1], out["sm"]["positions"][-1],
**self.config.loss.violation, **self.config.violation,
) )
if "renamed_atom14_gt_positions" not in out.keys(): if "renamed_atom14_gt_positions" not in out.keys():
...@@ -2110,12 +2115,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2110,12 +2115,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
features,labels = batch features,labels = batch
features['resolution'] = labels[2]['resolution'] # firstly update the resolution feature
# first remove the recycling dimention of input features # first remove the recycling dimention of input features
features = tensor_tree_map(lambda t: t[..., -1], features) features = tensor_tree_map(lambda t: t[..., -1], features)
# then permutate ground truth chains before calculating the loss # then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels) permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype') permutated_labels.pop('aatype')
logger.info("finished multi-chain permutation") logger.info("finished multi-chain permutation ")
features.update(permutated_labels) features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu')) move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features) features = tensor_tree_map(move_to_cpu,features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment