Commit 67c29f8d authored by Michael Kösel's avatar Michael Kösel Committed by Francisco Massa
Browse files

Match TensorFlow's implementation of GoogLeNet more closely (#821)

* Match TensorFlow's implementation of GoogLeNet

* just disable the branch when pretrained is true

* don't use legacy code
parent df9de17d
...@@ -17,12 +17,16 @@ def googlenet(pretrained=False, **kwargs): ...@@ -17,12 +17,16 @@ def googlenet(pretrained=False, **kwargs):
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Automatically set to False if 'pretrained' is True. Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False* was trained on ImageNet. Default: *False*
""" """
if pretrained: if pretrained:
if 'transform_input' not in kwargs: if 'transform_input' not in kwargs:
kwargs['transform_input'] = True kwargs['transform_input'] = True
if 'aux_logits' not in kwargs:
kwargs['aux_logits'] = False
kwargs['init_weights'] = False kwargs['init_weights'] = False
model = GoogLeNet(**kwargs) model = GoogLeNet(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['googlenet'])) model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
...@@ -57,11 +61,12 @@ class GoogLeNet(nn.Module): ...@@ -57,11 +61,12 @@ class GoogLeNet(nn.Module):
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
if aux_logits:
self.aux1 = InceptionAux(512, num_classes) self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes) self.aux2 = InceptionAux(528, num_classes)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.4) self.dropout = nn.Dropout(0.2)
self.fc = nn.Linear(1024, num_classes) self.fc = nn.Linear(1024, num_classes)
if init_weights: if init_weights:
...@@ -69,13 +74,13 @@ class GoogLeNet(nn.Module): ...@@ -69,13 +74,13 @@ class GoogLeNet(nn.Module):
def _initialize_weights(self): def _initialize_weights(self):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight) import scipy.stats as stats
if m.bias is not None: X = stats.truncnorm(-2, 2, scale=0.01)
nn.init.constant_(m.bias, 0.2) values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
elif isinstance(m, nn.Linear): values = values.view(m.weight.size())
nn.init.xavier_uniform_(m.weight) with torch.no_grad():
nn.init.constant_(m.bias, 0) m.weight.copy_(values)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment