Unverified Commit f2d8aa06 authored by oahzxl's avatar oahzxl Committed by GitHub
Browse files

Fix OuterProductMean (#81)

Can get stable results now
parent e7234756
......@@ -166,9 +166,9 @@ class OutProductMean(nn.Module):
O = rearrange(O, 'b i j d e -> b i j (d e)')
O = self.o_linear(O)
norm0 = norm[:, ax:ax + chunk_size, :, :]
Z[:, ax:ax + chunk_size, :, :] += O / norm0
Z[:, ax:ax + chunk_size, :, :] = O / norm0
return Z
return Z + Z_raw
def inplace(self, M, M_mask, Z_raw):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment