Unverified Commit 4325efe9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix trainable_layers on RetinaNet + minor doc fixes (#3234)

* Fixing trainable_layers bug.

* minor doc fixes
parent 7536e298
......@@ -344,8 +344,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
......
......@@ -312,8 +312,9 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
num_keypoints (int): number of keypoints, default 17
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
......
......@@ -308,8 +308,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
......
......@@ -12,7 +12,7 @@ from ..utils import load_state_dict_from_url
from . import _utils as det_utils
from .anchor_utils import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
......@@ -564,7 +564,7 @@ model_urls = {
def retinanet_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
"""
Constructs a RetinaNet model with a ResNet-50-FPN backbone.
......@@ -600,13 +600,21 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
# skip P2 because it generates too many anchors (according to their paper)
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone,
returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256))
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, returned_layers=[2, 3, 4],
extra_blocks=LastLevelP6P7(256, 256), trainable_layers=trainable_backbone_layers)
model = RetinaNet(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment