Unverified Commit f2d8aa06 authored by oahzxl's avatar oahzxl Committed by GitHub
Browse files

Fix OuterProductMean (#81)

Can get stable results now
parent e7234756
...@@ -166,9 +166,9 @@ class OutProductMean(nn.Module): ...@@ -166,9 +166,9 @@ class OutProductMean(nn.Module):
O = rearrange(O, 'b i j d e -> b i j (d e)') O = rearrange(O, 'b i j d e -> b i j (d e)')
O = self.o_linear(O) O = self.o_linear(O)
norm0 = norm[:, ax:ax + chunk_size, :, :] norm0 = norm[:, ax:ax + chunk_size, :, :]
Z[:, ax:ax + chunk_size, :, :] += O / norm0 Z[:, ax:ax + chunk_size, :, :] = O / norm0
return Z return Z + Z_raw
def inplace(self, M, M_mask, Z_raw): def inplace(self, M, M_mask, Z_raw):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment