Commit ad2e5c97 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix OPM bug

parent b2d102cb
...@@ -50,7 +50,7 @@ class OuterProductMean(nn.Module): ...@@ -50,7 +50,7 @@ class OuterProductMean(nn.Module):
def _opm(self, a, b): def _opm(self, a, b):
# [*, N_res, N_res, C, C] # [*, N_res, N_res, C, C]
outer = torch.einsum("...bac,...dae->...bdce", a, b) outer = torch.einsum("...bac,...dae->...bdce", a, b)
# [*, N_res, N_res, C * C] # [*, N_res, N_res, C * C]
outer = outer.reshape(*outer.shape[:-2], -1) outer = outer.reshape(*outer.shape[:-2], -1)
...@@ -107,7 +107,7 @@ class OuterProductMean(nn.Module): ...@@ -107,7 +107,7 @@ class OuterProductMean(nn.Module):
norm = torch.einsum("...abc,...adc->...bdc", mask, mask) norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
outer = outer / self.eps + norm outer = outer / (self.eps + norm)
return outer return outer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment