Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
350a3e8e
Unverified
Commit
350a3e8e
authored
Mar 07, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Mar 07, 2022
Browse files
Use frozen BN only if pre-trained. (#5443)
parent
b4cb352c
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
53 additions
and
50 deletions
+53
-50
torchvision/models/detection/faster_rcnn.py
torchvision/models/detection/faster_rcnn.py
+8
-10
torchvision/models/detection/fcos.py
torchvision/models/detection/fcos.py
+4
-4
torchvision/models/detection/keypoint_rcnn.py
torchvision/models/detection/keypoint_rcnn.py
+4
-4
torchvision/models/detection/mask_rcnn.py
torchvision/models/detection/mask_rcnn.py
+4
-4
torchvision/models/detection/retinanet.py
torchvision/models/detection/retinanet.py
+4
-4
torchvision/prototype/models/detection/faster_rcnn.py
torchvision/prototype/models/detection/faster_rcnn.py
+9
-8
torchvision/prototype/models/detection/fcos.py
torchvision/prototype/models/detection/fcos.py
+5
-4
torchvision/prototype/models/detection/keypoint_rcnn.py
torchvision/prototype/models/detection/keypoint_rcnn.py
+5
-4
torchvision/prototype/models/detection/mask_rcnn.py
torchvision/prototype/models/detection/mask_rcnn.py
+5
-4
torchvision/prototype/models/detection/retinanet.py
torchvision/prototype/models/detection/retinanet.py
+5
-4
No files found.
torchvision/models/detection/faster_rcnn.py
View file @
350a3e8e
...
...
@@ -383,15 +383,15 @@ def fasterrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
train
able_backbone_layers
=
_validate_trainable_layers
(
pretrained
or
pretrained_backbo
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
pretrained
or
pretrained_backbone
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
if
pretrained
:
# no need to download the backbone if pretrained is set
pretrained_backbone
=
False
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
)
model
=
FasterRCNN
(
backbone
,
num_classes
,
**
kwargs
)
if
pretrained
:
...
...
@@ -410,16 +410,14 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
trainable_backbone_layers
=
None
,
**
kwargs
,
):
train
able_backbone_layers
=
_validate_trainable_layers
(
pretrained
or
pretrained_backbo
ne
,
trainable_backbone_layers
,
6
,
3
)
is_
train
ed
=
pretrained
or
pretrained_backbone
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
6
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
if
pretrained
:
pretrained_backbone
=
False
backbone
=
mobilenet_v3_large
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
mobilenet_v3_large
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_mobilenet_extractor
(
backbone
,
True
,
trainable_backbone_layers
)
anchor_sizes
=
(
...
...
torchvision/models/detection/fcos.py
View file @
350a3e8e
...
...
@@ -686,15 +686,15 @@ def fcos_resnet50_fpn(
from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
"""
train
able_backbone_layers
=
_validate_trainable_layers
(
pretrained
or
pretrained_backbo
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
pretrained
or
pretrained_backbone
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
if
pretrained
:
# no need to download the backbone if pretrained is set
pretrained_backbone
=
False
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
,
returned_layers
=
[
2
,
3
,
4
],
extra_blocks
=
LastLevelP6P7
(
256
,
256
)
)
...
...
torchvision/models/detection/keypoint_rcnn.py
View file @
350a3e8e
...
...
@@ -365,15 +365,15 @@ def keypointrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
train
able_backbone_layers
=
_validate_trainable_layers
(
pretrained
or
pretrained_backbo
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
pretrained
or
pretrained_backbone
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
if
pretrained
:
# no need to download the backbone if pretrained is set
pretrained_backbone
=
False
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
)
model
=
KeypointRCNN
(
backbone
,
num_classes
,
num_keypoints
=
num_keypoints
,
**
kwargs
)
if
pretrained
:
...
...
torchvision/models/detection/mask_rcnn.py
View file @
350a3e8e
...
...
@@ -360,15 +360,15 @@ def maskrcnn_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
train
able_backbone_layers
=
_validate_trainable_layers
(
pretrained
or
pretrained_backbo
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
pretrained
or
pretrained_backbone
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
if
pretrained
:
# no need to download the backbone if pretrained is set
pretrained_backbone
=
False
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
)
model
=
MaskRCNN
(
backbone
,
num_classes
,
**
kwargs
)
if
pretrained
:
...
...
torchvision/models/detection/retinanet.py
View file @
350a3e8e
...
...
@@ -626,15 +626,15 @@ def retinanet_resnet50_fpn(
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
train
able_backbone_layers
=
_validate_trainable_layers
(
pretrained
or
pretrained_backbo
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
pretrained
or
pretrained_backbone
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
if
pretrained
:
# no need to download the backbone if pretrained is set
pretrained_backbone
=
False
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
pretrained
=
pretrained_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
# skip P2 because it generates too many anchors (according to their paper)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
,
returned_layers
=
[
2
,
3
,
4
],
extra_blocks
=
LastLevelP6P7
(
256
,
256
)
...
...
torchvision/prototype/models/detection/faster_rcnn.py
View file @
350a3e8e
from
typing
import
Any
,
Optional
,
Union
from
torch
import
nn
from
torchvision.prototype.transforms
import
CocoEval
from
torchvision.transforms.functional
import
InterpolationMode
...
...
@@ -103,11 +104,11 @@ def fasterrcnn_resnet50_fpn(
elif
num_classes
is
None
:
num_classes
=
91
train
able_backbone_layers
=
_validate_trainable_layers
(
weights
is
not
None
or
weights_backbone
is
not
No
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
weights
is
not
None
or
weights_backbone
is
not
None
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
)
model
=
FasterRCNN
(
backbone
,
num_classes
=
num_classes
,
**
kwargs
)
...
...
@@ -134,11 +135,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
elif
num_classes
is
None
:
num_classes
=
91
train
able_backbone_layers
=
_validate_trainable_layers
(
weights
is
not
None
or
weights_backbone
is
not
No
ne
,
trainable_backbone_layers
,
6
,
3
)
is_
train
ed
=
weights
is
not
None
or
weights_backbone
is
not
None
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
6
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
backbone
=
mobilenet_v3_large
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
mobilenet_v3_large
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_mobilenet_extractor
(
backbone
,
True
,
trainable_backbone_layers
)
anchor_sizes
=
(
(
...
...
torchvision/prototype/models/detection/fcos.py
View file @
350a3e8e
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
CocoEval
from
torchvision.transforms.functional
import
InterpolationMode
...
...
@@ -63,11 +64,11 @@ def fcos_resnet50_fpn(
elif
num_classes
is
None
:
num_classes
=
91
train
able_backbone_layers
=
_validate_trainable_layers
(
weights
is
not
None
or
weights_backbone
is
not
No
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
weights
is
not
None
or
weights_backbone
is
not
None
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
,
returned_layers
=
[
2
,
3
,
4
],
extra_blocks
=
LastLevelP6P7
(
256
,
256
)
)
...
...
torchvision/prototype/models/detection/keypoint_rcnn.py
View file @
350a3e8e
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
CocoEval
from
torchvision.transforms.functional
import
InterpolationMode
...
...
@@ -91,11 +92,11 @@ def keypointrcnn_resnet50_fpn(
if
num_keypoints
is
None
:
num_keypoints
=
17
train
able_backbone_layers
=
_validate_trainable_layers
(
weights
is
not
None
or
weights_backbone
is
not
No
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
weights
is
not
None
or
weights_backbone
is
not
None
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
)
model
=
KeypointRCNN
(
backbone
,
num_classes
,
num_keypoints
=
num_keypoints
,
**
kwargs
)
...
...
torchvision/prototype/models/detection/mask_rcnn.py
View file @
350a3e8e
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
CocoEval
from
torchvision.transforms.functional
import
InterpolationMode
...
...
@@ -64,11 +65,11 @@ def maskrcnn_resnet50_fpn(
elif
num_classes
is
None
:
num_classes
=
91
train
able_backbone_layers
=
_validate_trainable_layers
(
weights
is
not
None
or
weights_backbone
is
not
No
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
weights
is
not
None
or
weights_backbone
is
not
None
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
)
model
=
MaskRCNN
(
backbone
,
num_classes
=
num_classes
,
**
kwargs
)
...
...
torchvision/prototype/models/detection/retinanet.py
View file @
350a3e8e
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
CocoEval
from
torchvision.transforms.functional
import
InterpolationMode
...
...
@@ -64,11 +65,11 @@ def retinanet_resnet50_fpn(
elif
num_classes
is
None
:
num_classes
=
91
train
able_backbone_layers
=
_validate_trainable_layers
(
weights
is
not
None
or
weights_backbone
is
not
No
ne
,
trainable_backbone_layers
,
5
,
3
)
is_
train
ed
=
weights
is
not
None
or
weights_backbone
is
not
None
trainable_backbone_layers
=
_validate_trainable_layers
(
is_trai
ne
d
,
trainable_backbone_layers
,
5
,
3
)
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
if
is_trained
else
nn
.
BatchNorm2d
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
misc_nn_ops
.
FrozenBatchNorm2d
)
backbone
=
resnet50
(
weights
=
weights_backbone
,
progress
=
progress
,
norm_layer
=
norm_layer
)
# skip P2 because it generates too many anchors (according to their paper)
backbone
=
_resnet_fpn_extractor
(
backbone
,
trainable_backbone_layers
,
returned_layers
=
[
2
,
3
,
4
],
extra_blocks
=
LastLevelP6P7
(
256
,
256
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment