"vscode:/vscode.git/clone" did not exist on "6d0a5cd24aadc90255d99f3c4f27951cea735da5"
Unverified commit e1a30427, authored by Bisakh Mondal and committed by GitHub
Browse files

Updated with weight initialization warnings (#2170)

parent 6f849df0
...@@ -62,11 +62,16 @@ def googlenet(pretrained=False, progress=True, **kwargs): ...@@ -62,11 +62,16 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module): class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input'] __constants__ = ['aux_logits', 'transform_input']
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True, def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=None,
blocks=None): blocks=None):
super(GoogLeNet, self).__init__() super(GoogLeNet, self).__init__()
if blocks is None: if blocks is None:
blocks = [BasicConv2d, Inception, InceptionAux] blocks = [BasicConv2d, Inception, InceptionAux]
if init_weights is None:
warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '
'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
init_weights = True
assert len(blocks) == 3 assert len(blocks) == 3
conv_block = blocks[0] conv_block = blocks[0]
inception_block = blocks[1] inception_block = blocks[1]
......
...@@ -63,13 +63,18 @@ def inception_v3(pretrained=False, progress=True, **kwargs): ...@@ -63,13 +63,18 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
class Inception3(nn.Module): class Inception3(nn.Module):
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
inception_blocks=None, init_weights=True): inception_blocks=None, init_weights=None):
super(Inception3, self).__init__() super(Inception3, self).__init__()
if inception_blocks is None: if inception_blocks is None:
inception_blocks = [ inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC, BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux InceptionD, InceptionE, InceptionAux
] ]
if init_weights is None:
warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
init_weights = True
assert len(inception_blocks) == 7 assert len(inception_blocks) == 7
conv_block = inception_blocks[0] conv_block = inception_blocks[0]
inception_a = inception_blocks[1] inception_a = inception_blocks[1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment