Unverified Commit 791c172a authored by os-gabe's avatar os-gabe Committed by GitHub
Browse files

Fixes #1797 by adding an init_weights keyword argument to Inception3 (#1832)

parent f2600c2e
...@@ -65,7 +65,7 @@ def inception_v3(pretrained=False, progress=True, **kwargs): ...@@ -65,7 +65,7 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
class Inception3(nn.Module): class Inception3(nn.Module):
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
inception_blocks=None): inception_blocks=None, init_weights=True):
super(Inception3, self).__init__() super(Inception3, self).__init__()
if inception_blocks is None: if inception_blocks is None:
inception_blocks = [ inception_blocks = [
...@@ -102,7 +102,7 @@ class Inception3(nn.Module): ...@@ -102,7 +102,7 @@ class Inception3(nn.Module):
self.Mixed_7b = inception_e(1280) self.Mixed_7b = inception_e(1280)
self.Mixed_7c = inception_e(2048) self.Mixed_7c = inception_e(2048)
self.fc = nn.Linear(2048, num_classes) self.fc = nn.Linear(2048, num_classes)
if init_weights:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
import scipy.stats as stats import scipy.stats as stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment