Commit bbf42cc5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed the fape and backbone loss errors

parent b22bd4e3
......@@ -185,7 +185,13 @@ def backbone_loss(
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
### need to check if the traj belongs to 4*4 matrix or a tensor_7
if traj.shape[-1]==7:
pred_aff = Rigid.from_tensor_7(traj)
elif traj.shape[-1]==4:
pred_aff = Rigid.from_tensor_4x4(traj)
pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
......@@ -304,10 +310,10 @@ def fape_loss(
interface_bb_loss = backbone_loss(
traj=traj,
pair_mask=1. - intra_chain_mask,
**{**batch, **config.interface_backbone},
**{**batch, **config.intra_chain_backbone},
)
weighted_bb_loss = (intra_chain_bb_loss * config.intra_chain_backbone.weight
+ interface_bb_loss * config.interface_backbone.weight)
+ interface_bb_loss * config.intra_chain_backbone.weight)
else:
bb_loss = backbone_loss(
traj=traj,
......@@ -1865,7 +1871,6 @@ def greedy_align(
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
print(f"now best_idx is {best_idx} and rmsd is {rmsd} and j is {j}")
assert best_idx is not None
used[best_idx] = True
align.append((i, best_idx))
......@@ -1920,7 +1925,7 @@ class AlphaFoldLoss(nn.Module):
out["violation"] = find_structural_violations(
batch,
out["sm"]["positions"][-1],
**self.config.loss.violation,
**self.config.violation,
)
if "renamed_atom14_gt_positions" not in out.keys():
......@@ -2110,12 +2115,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
"""
features,labels = batch
features['resolution'] = labels[2]['resolution'] # firstly update the resolution feature
# first remove the recycling dimention of input features
features = tensor_tree_map(lambda t: t[..., -1], features)
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype')
logger.info("finished multi-chain permutation")
logger.info("finished multi-chain permutation ")
features.update(permutated_labels)
move_to_cpu = lambda t: (t.to('cpu'))
features = tensor_tree_map(move_to_cpu,features)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment