"docs/vscode:/vscode.git/clone" did not exist on "7c2e0af2a61745a36b26fae2c817f608be757a4c"
Commit 67c29f8d authored by Michael Kösel, committed by Francisco Massa
Browse files

Match TensorFlow's implementation of GoogLeNet more closely (#821)

* Match TensorFlow's implementation of GoogLeNet

* just disable the branch when pretrained is true

* don't use legacy code
parent df9de17d
......@@ -17,12 +17,16 @@ def googlenet(pretrained=False, **kwargs):
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Automatically set to False if 'pretrained' is True. Default: *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' not in kwargs:
kwargs['aux_logits'] = False
kwargs['init_weights'] = False
model = GoogLeNet(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
......@@ -57,11 +61,12 @@ class GoogLeNet(nn.Module):
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
if aux_logits:
self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes)
self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.4)
self.dropout = nn.Dropout(0.2)
self.fc = nn.Linear(1024, num_classes)
if init_weights:
......@@ -69,13 +74,13 @@ class GoogLeNet(nn.Module):
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0.2)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
import scipy.stats as stats
X = stats.truncnorm(-2, 2, scale=0.01)
values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
values = values.view(m.weight.size())
with torch.no_grad():
m.weight.copy_(values)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment