Commit 71322cba authored by ekka's avatar ekka Committed by Soumith Chintala
Browse files

Add comments regarding downsampling layers of resnet (#794)

In reference to #729 added comments to clarify the naming and action of the layers performing downsampling in resnets.
parent b3a7cf67
...@@ -31,6 +31,7 @@ class BasicBlock(nn.Module): ...@@ -31,6 +31,7 @@ class BasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride) self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
...@@ -63,6 +64,7 @@ class Bottleneck(nn.Module): ...@@ -63,6 +64,7 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, planes) self.conv1 = conv1x1(inplanes, planes)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = conv3x3(planes, planes, stride) self.conv2 = conv3x3(planes, planes, stride)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment