Commit 4d40ce80 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bug in MSA loss

parent 4c9d372d
......@@ -1359,6 +1359,16 @@ def experimentally_resolved_loss(
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
Args:
logits: [*, N_seq, N_res, 23] predicted residue distribution
true_msa: [*, N_seq, N_res] true MSA
bert_mask: [*, N_seq, N_res] MSA mask
Returns:
Masked MSA loss
"""
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
)
......@@ -1376,6 +1386,8 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
loss = torch.sum(loss, dim=-1)
loss = loss * scale
loss = torch.mean(loss)
return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment