Commit f87a896f authored by Ryuichiro Hataya, committed by Francisco Massa

fix models for PyTorch v0.4 (remove .data and add _ for the initializations … (#481)

* fix for PyTorch v0.4 (remove .data and add _ for the initializations in nn.init)

* fix m.**.**() style to nn.init.**(**) style

* remove .idea

* fix lines and indents

* fix lines and indents

* change to use `kaiming_normal_`

* add `.data` for safety

* add nonlinearity='relu' to be explicit

* fix indents
parent 1d0a3b11
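The change is the same across all five models: PyTorch v0.4 merged Variable into Tensor and deprecated the non-in-place nn.init functions, so calls like `nn.init.kaiming_normal(m.weight.data)` become `nn.init.kaiming_normal_(m.weight)`, and raw `.data.fill_(1)` / `.data.zero_()` calls become `nn.init.constant_(...)`. A minimal sketch of the new-style init loop follows; `TinyNet` is a hypothetical module used only for illustration, not part of this commit:

import torch.nn as nn

class TinyNet(nn.Module):
    """Toy module demonstrating the v0.4-style weight init."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # in-place init with trailing underscore; no .data needed
                # now that Tensor and Variable are merged in v0.4
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)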
torchvision/models/densenet.py
@@ -175,6 +175,7 @@ class DenseNet(nn.Module):
         drop_rate (float) - dropout rate after each dense layer
         num_classes (int) - number of classification classes
     """
+
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
@@ -209,12 +210,12 @@ class DenseNet(nn.Module):
         # Official init from torch repo.
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal(m.weight.data)
+                nn.init.kaiming_normal_(m.weight)
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
             elif isinstance(m, nn.Linear):
-                m.bias.data.zero_()
+                nn.init.constant_(m.bias, 0)

     def forward(self, x):
         features = self.features(x)
torchvision/models/inception.py
@@ -61,12 +61,12 @@ class Inception3(nn.Module):
                 import scipy.stats as stats
                 stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                 X = stats.truncnorm(-2, 2, scale=stddev)
-                values = torch.Tensor(X.rvs(m.weight.data.numel()))
-                values = values.view(m.weight.data.size())
+                values = torch.Tensor(X.rvs(m.weight.numel()))
+                values = values.view(m.weight.size())
                 m.weight.data.copy_(values)
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)

     def forward(self, x):
         if self.transform_input:
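Note that `m.weight.data.copy_(values)` deliberately keeps its `.data` (the "add `.data` for safety" commit above): copying into `.data` keeps the in-place write out of autograd's tracking. The same truncated-normal fill as a standalone helper, for illustration only (`truncated_normal_` is a name introduced here, not a torch or torchvision API):

import scipy.stats as stats
import torch

def truncated_normal_(tensor, stddev=0.1):
    """Fill `tensor` in place with draws from a normal distribution
    truncated to +/- 2 standard deviations, as in Inception3 above."""
    values = torch.Tensor(stats.truncnorm(-2, 2, scale=stddev).rvs(tensor.numel()))
    # write through .data so autograd does not track the copy
    tensor.data.copy_(values.view(tensor.size()))
    return tensor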
torchvision/models/resnet.py
@@ -112,11 +112,10 @@ class ResNet(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)

     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
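The ResNet change swaps the hand-rolled He initialization for the built-in one. For a Conv2d, fan_out is kernel_size[0] * kernel_size[1] * out_channels, exactly the `n` the deleted lines computed, and `kaiming_normal_` with nonlinearity='relu' samples from N(0, 2/fan_out), so the two are equivalent; passing nonlinearity='relu' explicitly (the "add nonlinearity='relu'" commit) documents the intent, since the default is 'leaky_relu'. A quick numerical check, using an arbitrary layer shape chosen for illustration:

import math
import torch.nn as nn

conv = nn.Conv2d(64, 128, kernel_size=3)  # arbitrary example shape
# fan_out for a Conv2d is kernel_h * kernel_w * out_channels --
# the same `n` the deleted lines computed by hand
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
# old and new code both target std = sqrt(2/n); the sample std
# should land close to it
print(conv.weight.std().item(), math.sqrt(2. / n))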
torchvision/models/squeezenet.py
@@ -89,11 +89,11 @@ class SqueezeNet(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 if m is final_conv:
-                    init.normal(m.weight.data, mean=0.0, std=0.01)
+                    init.normal_(m.weight, mean=0.0, std=0.01)
                 else:
-                    init.kaiming_uniform(m.weight.data)
+                    init.kaiming_uniform_(m.weight)
                 if m.bias is not None:
-                    m.bias.data.zero_()
+                    init.constant_(m.bias, 0)

     def forward(self, x):
         x = self.features(x)
torchvision/models/vgg.py
@@ -47,16 +47,15 @@ class VGG(nn.Module):
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                 if m.bias is not None:
-                    m.bias.data.zero_()
+                    nn.init.constant_(m.bias, 0)
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
             elif isinstance(m, nn.Linear):
-                m.weight.data.normal_(0, 0.01)
-                m.bias.data.zero_()
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)


 def make_layers(cfg, batch_norm=False):
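A short sanity check that the rewritten init loop behaves like the old one; this snippet is illustrative and not part of the commit:

import torch.nn as nn
from torchvision.models import vgg16

model = vgg16()  # runs the _initialize_weights shown above
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # biases were zeroed via nn.init.constant_(m.bias, 0)
        assert m.bias is None or float(m.bias.abs().max()) == 0.0
print('all conv/linear biases zero-initialized')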