You need to sign in or sign up before continuing.
Commit b9eede45 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Rename pLDDT loss

parent a83fcc01
@@ -653,7 +653,9 @@ def compute_tm(
     normed_residue_mask = residue_weights / (eps + residue_weights.sum())
     per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
     weighted = per_alignment * residue_weights
     argmax = (weighted == torch.max(weighted)).nonzero()[0]
     return per_alignment[tuple(argmax)]
@@ -1557,7 +1559,7 @@ class AlphaFoldLoss(nn.Module):
                 batch,
                 self.config.fape,
             ),
-            "lddt": lambda: lddt_loss(
+            "plddt_loss": lambda: lddt_loss(
                 logits=out["lddt_logits"],
                 all_atom_pred_pos=out["final_atom_positions"],
                 **{**batch, **self.config.lddt},
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment