Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
4325efe9
Unverified
Commit
4325efe9
authored
Jan 08, 2021
by
Vasilis Vryniotis
Committed by
GitHub
Jan 08, 2021
Browse files
Fix trainable_layers on RetinaNet + minor doc fixes (#3234)
* Fixing trainable_layers bug. * minor doc fixes
parent
7536e298
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
7 deletions
+16
-7
torchvision/models/detection/faster_rcnn.py
torchvision/models/detection/faster_rcnn.py
+1
-1
torchvision/models/detection/keypoint_rcnn.py
torchvision/models/detection/keypoint_rcnn.py
+2
-1
torchvision/models/detection/mask_rcnn.py
torchvision/models/detection/mask_rcnn.py
+1
-1
torchvision/models/detection/retinanet.py
torchvision/models/detection/retinanet.py
+12
-4
No files found.
torchvision/models/detection/faster_rcnn.py
View file @
4325efe9
...
@@ -344,8 +344,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
...
@@ -344,8 +344,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
"""
...
...
torchvision/models/detection/keypoint_rcnn.py
View file @
4325efe9
...
@@ -312,8 +312,9 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
...
@@ -312,8 +312,9 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
num_classes (int): number of output classes of the model (including the background)
num_keypoints (int): number of keypoints, default 17
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
"""
...
...
torchvision/models/detection/mask_rcnn.py
View file @
4325efe9
...
@@ -308,8 +308,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
...
@@ -308,8 +308,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
"""
...
...
torchvision/models/detection/retinanet.py
View file @
4325efe9
...
@@ -12,7 +12,7 @@ from ..utils import load_state_dict_from_url
...
@@ -12,7 +12,7 @@ from ..utils import load_state_dict_from_url
from
.
import
_utils
as
det_utils
from
.
import
_utils
as
det_utils
from
.anchor_utils
import
AnchorGenerator
from
.anchor_utils
import
AnchorGenerator
from
.transform
import
GeneralizedRCNNTransform
from
.transform
import
GeneralizedRCNNTransform
from
.backbone_utils
import
resnet_fpn_backbone
from
.backbone_utils
import
resnet_fpn_backbone
,
_validate_resnet_trainable_layers
from
...ops.feature_pyramid_network
import
LastLevelP6P7
from
...ops.feature_pyramid_network
import
LastLevelP6P7
from
...ops
import
sigmoid_focal_loss
from
...ops
import
sigmoid_focal_loss
from
...ops
import
boxes
as
box_ops
from
...ops
import
boxes
as
box_ops
...
@@ -564,7 +564,7 @@ model_urls = {
...
@@ -564,7 +564,7 @@ model_urls = {
def
retinanet_resnet50_fpn
(
pretrained
=
False
,
progress
=
True
,
def
retinanet_resnet50_fpn
(
pretrained
=
False
,
progress
=
True
,
num_classes
=
91
,
pretrained_backbone
=
True
,
**
kwargs
):
num_classes
=
91
,
pretrained_backbone
=
True
,
trainable_backbone_layers
=
None
,
**
kwargs
):
"""
"""
Constructs a RetinaNet model with a ResNet-50-FPN backbone.
Constructs a RetinaNet model with a ResNet-50-FPN backbone.
...
@@ -600,13 +600,21 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
...
@@ -600,13 +600,21 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
Args:
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers
=
_validate_resnet_trainable_layers
(
pretrained
or
pretrained_backbone
,
trainable_backbone_layers
)
if
pretrained
:
if
pretrained
:
# no need to download the backbone if pretrained is set
# no need to download the backbone if pretrained is set
pretrained_backbone
=
False
pretrained_backbone
=
False
# skip P2 because it generates too many anchors (according to their paper)
# skip P2 because it generates too many anchors (according to their paper)
backbone
=
resnet_fpn_backbone
(
'resnet50'
,
pretrained_backbone
,
backbone
=
resnet_fpn_backbone
(
'resnet50'
,
pretrained_backbone
,
returned_layers
=
[
2
,
3
,
4
],
returned_layers
=
[
2
,
3
,
4
],
extra_blocks
=
LastLevelP6P7
(
256
,
256
))
extra_blocks
=
LastLevelP6P7
(
256
,
256
)
,
trainable_layers
=
trainable_backbone_layers
)
model
=
RetinaNet
(
backbone
,
num_classes
,
**
kwargs
)
model
=
RetinaNet
(
backbone
,
num_classes
,
**
kwargs
)
if
pretrained
:
if
pretrained
:
state_dict
=
load_state_dict_from_url
(
model_urls
[
'retinanet_resnet50_fpn_coco'
],
state_dict
=
load_state_dict_from_url
(
model_urls
[
'retinanet_resnet50_fpn_coco'
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment