Commit 1e5f7a2f authored by Boris Bonev's avatar Boris Bonev
Browse files

adjusting initialization

parent 9577cc8f
......@@ -208,7 +208,8 @@ class DiscreteContinuousConvS2(nn.Module):
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
......@@ -299,7 +300,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment