Commit 4cc6f45a authored by Yuxin Wu's avatar Yuxin Wu Committed by Francisco Massa
Browse files

Zero-init the residual branch in resnet (#498)

* Zero-init the residual branch in resnet

* Add zero_init_residual as an option
parent 5123ded4
...@@ -98,7 +98,7 @@ class Bottleneck(nn.Module): ...@@ -98,7 +98,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module): class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.inplanes = 64 self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
...@@ -120,6 +120,16 @@ class ResNet(nn.Module): ...@@ -120,6 +120,16 @@ class ResNet(nn.Module):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1): def _make_layer(self, block, planes, blocks, stride=1):
downsample = None downsample = None
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment