Commit ad2e5c97 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix OPM bug

parent b2d102cb
......@@ -107,7 +107,7 @@ class OuterProductMean(nn.Module):
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
# [*, N_res, N_res, C_z]
outer = outer / self.eps + norm
outer = outer / (self.eps + norm)
return outer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment