Commit 1e5f7a2f authored by Boris Bonev
Browse files

adjusting initialization

parent 9577cc8f
@@ -208,7 +208,8 @@ class DiscreteContinuousConvS2(nn.Module):
         if out_channels % self.groups != 0:
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
-        self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
+        scale = math.sqrt(1.0 / self.groupsize)
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
         if bias:
             self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -299,7 +300,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         if out_channels % self.groups != 0:
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
-        self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
+        scale = math.sqrt(1.0 / self.groupsize)
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
         if bias:
             self.bias = nn.Parameter(torch.zeros(out_channels))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment