Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6334466e
Unverified
Commit
6334466e
authored
Mar 26, 2019
by
Francisco Massa
Committed by
GitHub
Mar 26, 2019
Browse files
Add support for other normalizations (i.e., GroupNorm) in ResNet (#813)
parent
8c33bd78
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
18 deletions
+26
-18
torchvision/models/resnet.py
torchvision/models/resnet.py
+26
-18
No files found.
torchvision/models/resnet.py
View file @
6334466e
...
...
@@ -29,14 +29,16 @@ def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions with an identity shortcut.

    Args:
        inplanes (int): number of input channels.
        planes (int): number of output channels of each convolution.
        stride (int): stride of the first convolution; when != 1 both
            ``self.conv1`` and ``downsample`` reduce the spatial size.
        downsample (nn.Module, optional): module applied to the input to
            match the residual's shape before the addition.
        norm_layer (callable, optional): normalization-layer constructor
            taking a channel count (e.g. ``nn.BatchNorm2d`` or a
            ``nn.GroupNorm`` partial); defaults to ``nn.BatchNorm2d``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
        super(BasicBlock, self).__init__()
        # Default to batch norm so existing callers keep their behavior.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
...
...
@@ -62,15 +64,17 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand (x4).

    Args:
        inplanes (int): number of input channels.
        planes (int): bottleneck width; the block outputs
            ``planes * expansion`` channels.
        stride (int): stride of the 3x3 convolution; when != 1 both
            ``self.conv2`` and ``downsample`` reduce the spatial size.
        downsample (nn.Module, optional): module applied to the input to
            match the residual's shape before the addition.
        norm_layer (callable, optional): normalization-layer constructor
            taking a channel count; defaults to ``nn.BatchNorm2d``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
        super(Bottleneck, self).__init__()
        # Default to batch norm so existing callers keep their behavior.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
...
...
@@ -100,25 +104,27 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
    """ResNet backbone built from a block class and per-stage block counts.

    Args:
        block (type): block class to instantiate (``BasicBlock`` or
            ``Bottleneck``).
        layers (list[int]): number of blocks in each of the four stages.
        num_classes (int): size of the final fully connected layer.
        zero_init_residual (bool): if True, zero-initialize the last norm
            layer in each residual branch so each block starts as identity.
        norm_layer (callable, optional): normalization-layer constructor
            taking a channel count; defaults to ``nn.BatchNorm2d``. Passing
            e.g. a ``nn.GroupNorm`` partial enables other normalizations.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Thread norm_layer through every stage so all blocks share it.
        self.layer1 = self._make_layer(block, 64, layers[0],
                                       norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            # GroupNorm added here so custom norm layers get the same
            # weight=1 / bias=0 initialization as batch norm.
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # NOTE(review): the lines guarding this loop were elided in the diff
        # context; reconstructed from the visible `elif BasicBlock` tail —
        # confirm against the full file.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
):
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
,
norm_layer
=
None
):
if
norm_layer
is
None
:
norm_layer
=
nn
.
BatchNorm2d
downsample
=
None
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
conv1x1
(
self
.
inplanes
,
planes
*
block
.
expansion
,
stride
),
n
n
.
BatchNorm2d
(
planes
*
block
.
expansion
),
n
orm_layer
(
planes
*
block
.
expansion
),
)
layers
=
[]
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
stride
,
downsample
))
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
stride
,
downsample
,
norm_layer
))
self
.
inplanes
=
planes
*
block
.
expansion
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
inplanes
,
planes
))
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
norm_layer
=
norm_layer
))
return
nn
.
Sequential
(
*
layers
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment