Commit 894905b9 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add loss scaling from subsection 1.9

parent afd2d839
...@@ -1526,4 +1526,10 @@ class AlphaFoldLoss(nn.Module): ...@@ -1526,4 +1526,10 @@ class AlphaFoldLoss(nn.Module):
loss = loss.new_tensor(0., requires_grad=True) loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss cum_loss = cum_loss + weight * loss
# Scale the loss by the square root of the minimum of the crop size and
# the (average) sequence length. See subsection 1.9.
seq_len = torch.mean(batch["seq_length"].float())
crop_len = batch["aatype"].shape[-1]
cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
return cum_loss return cum_loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment