Commit 13a0493b authored by Soumith Chintala's avatar Soumith Chintala
Browse files

Container -> Module

parent 11821d62
...@@ -10,7 +10,7 @@ model_urls = { ...@@ -10,7 +10,7 @@ model_urls = {
} }
class AlexNet(nn.Container): class AlexNet(nn.Module):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(AlexNet, self).__init__() super(AlexNet, self).__init__()
self.features = nn.Sequential( self.features = nn.Sequential(
......
...@@ -21,7 +21,7 @@ def conv3x3(in_planes, out_planes, stride=1): ...@@ -21,7 +21,7 @@ def conv3x3(in_planes, out_planes, stride=1):
padding=1, bias=False) padding=1, bias=False)
class BasicBlock(nn.Container): class BasicBlock(nn.Module):
expansion = 1 expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None):
...@@ -53,7 +53,7 @@ class BasicBlock(nn.Container): ...@@ -53,7 +53,7 @@ class BasicBlock(nn.Container):
return out return out
class Bottleneck(nn.Container): class Bottleneck(nn.Module):
expansion = 4 expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None):
...@@ -92,7 +92,7 @@ class Bottleneck(nn.Container): ...@@ -92,7 +92,7 @@ class Bottleneck(nn.Container):
return out return out
class ResNet(nn.Container): class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000): def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64 self.inplanes = 64
super(ResNet, self).__init__() super(ResNet, self).__init__()
......
...@@ -7,7 +7,7 @@ __all__ = [ ...@@ -7,7 +7,7 @@ __all__ = [
] ]
class VGG(nn.Container): class VGG(nn.Module):
def __init__(self, features): def __init__(self, features):
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment