Unverified Commit 7e987bfd authored by Vincent Moens, committed by GitHub

Use torch instead of scipy for random initialization of inception and googlenet weights (#4256)



Uses nn.init.trunc_normal_ instead of scipy.stats.truncnorm.
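As a rough standalone sketch of what the swap amounts to (the weight tensor shape below is made up for illustration; the commented-out block mirrors the scipy code removed in this diff):

import torch
import torch.nn as nn

# Illustrative tensor standing in for a Conv2d/Linear weight.
weight = torch.empty(64, 3, 3, 3)

# Removed approach: sample a truncated normal with scipy, then copy into the tensor.
#   import scipy.stats as stats
#   X = stats.truncnorm(-2, 2, scale=0.01)
#   values = torch.as_tensor(X.rvs(weight.numel()), dtype=weight.dtype).view(weight.size())
#   with torch.no_grad():
#       weight.copy_(values)

# New approach: fill the tensor in place from a normal distribution,
# redrawing samples that fall outside the interval [a, b].
nn.init.trunc_normal_(weight, mean=0.0, std=0.01, a=-2, b=2)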
Co-authored-by: Vincent Moens <vmoens@fb.com>
parent 876117b5
torchvision/models/googlenet.py
@@ -124,12 +124,7 @@ class GoogLeNet(nn.Module):
     def _initialize_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
-                import scipy.stats as stats
-                X = stats.truncnorm(-2, 2, scale=0.01)
-                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
-                values = values.view(m.weight.size())
-                with torch.no_grad():
-                    m.weight.copy_(values)
+                torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
...
torchvision/models/inception.py
@@ -120,13 +120,8 @@ class Inception3(nn.Module):
         if init_weights:
             for m in self.modules():
                 if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
-                    import scipy.stats as stats
-                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
-                    X = stats.truncnorm(-2, 2, scale=stddev)
-                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
-                    values = values.view(m.weight.size())
-                    with torch.no_grad():
-                        m.weight.copy_(values)
+                    stddev = float(m.stddev) if hasattr(m, 'stddev') else 0.1  # type: ignore
+                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                 elif isinstance(m, nn.BatchNorm2d):
                     nn.init.constant_(m.weight, 1)
                     nn.init.constant_(m.bias, 0)
...
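A minimal smoke test of the new initialization path, assuming the standard torchvision model factories and that init_weights is accepted as a keyword argument (input sizes are the usual ones for these models):

import torch
from torchvision.models import googlenet, inception_v3

# Both constructors now initialize Conv2d/Linear weights via
# torch.nn.init.trunc_normal_, with no scipy import at init time.
g = googlenet(init_weights=True).eval()
i3 = inception_v3(init_weights=True).eval()

with torch.no_grad():
    print(g(torch.randn(1, 3, 224, 224)).shape)   # expected: torch.Size([1, 1000])
    print(i3(torch.randn(1, 3, 299, 299)).shape)  # expected: torch.Size([1, 1000])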