Commit b9eede45 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Rename pLDDT loss

parent a83fcc01
......@@ -653,7 +653,9 @@ def compute_tm(
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
......@@ -1557,7 +1559,7 @@ class AlphaFoldLoss(nn.Module):
batch,
self.config.fape,
),
"lddt": lambda: lddt_loss(
"plddt_loss": lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment