Commit 6f2f9213 authored by Michael Kösel's avatar Michael Kösel Committed by Soumith Chintala
Browse files

update code (#824)

parent 67c29f8d
......@@ -67,9 +67,10 @@ class Inception3(nn.Module):
import scipy.stats as stats
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
X = stats.truncnorm(-2, 2, scale=stddev)
values = torch.Tensor(X.rvs(m.weight.numel()))
values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
values = values.view(m.weight.size())
m.weight.data.copy_(values)
with torch.no_grad():
m.weight.copy_(values)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment