Unverified Commit 348dd5a7 authored by Urwa Muaz's avatar Urwa Muaz Committed by GitHub
Browse files

Feat/unfreeze layers fpn backbone (#2160)

* freeze layers only if pretrained backbone is used

If pretrained backbone is not used and one intends to train the entire network from scratch, no layers should be frozen.

* function argument to control the trainable features

Depending on the size of the dataset, one might want to control the number of tunable parameters in the backbone and treat this number as a hyperparameter to optimize for that dataset. It would be nice to have this function support this.

* ensuring tunable layer argument is valid

* backbone freezing in fasterrcnn_resnet50_fpn

Handle backbone freezing in the fasterrcnn_resnet50_fpn function rather than in the resnet_fpn_backbone function that it uses to get the backbone.

* remove layer freezing code

Layer freezing code has been moved to the fasterrcnn_resnet50_fpn function, which consumes the resnet_fpn_backbone function.

* correcting linting errors

* correcting linting errors

* move freezing logic to resnet_fpn_backbone

Moved layer freezing logic to resnet_fpn_backbone with an additional parameter.

* remove layer freezing from fasterrcnn_resnet50_fpn

Layer freezing logic has been moved to resnet_fpn_backbone. This function only ensures that all layers are made trainable if pretrained models are not used.

* update example resnet_fpn_backbone docs

* correct typo in var name

* correct indentation

* adding test case for layer freezing in faster rcnn

This PR adds functionality to specify the number of trainable layers while initializing Faster R-CNN using the fasterrcnn_resnet50_fpn function. This commit adds a test case to test this functionality.

* updating layer freezing condition for clarity

More information in PR

* remove linting errors

* removing linting errors

* removing linting errors
parent a6073f07
import torch
from torchvision.models.detection import _utils
import unittest
from torchvision.models.detection import fasterrcnn_resnet50_fpn
class Tester(unittest.TestCase):
......@@ -17,6 +18,21 @@ class Tester(unittest.TestCase):
self.assertEqual(neg[0].sum(), 3)
self.assertEqual(neg[0][0:6].sum(), 3)
def test_fasterrcnn_resnet50_fpn_frozen_layers(self):
# Check that trainable_backbone_layers freezes the expected leading parameters.
# we know how many initial layers and parameters of the network should
# be frozen for each trainable_backbone_layers parameter value,
# i.e. all 53 params are frozen if trainable_backbone_layers=0
# and the first 24 params are frozen if trainable_backbone_layers=2
expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0}
for train_layers, exp_froz_params in expected_frozen_params.items():
model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False,
num_classes=91, pretrained_backbone=False,
trainable_backbone_layers=train_layers)
# boolean list that is True if the param at that index is frozen
is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
# check that the expected number of leading parameters are frozen
# NOTE(review): only the frozen prefix is asserted; parameters after it
# are not checked to be trainable, so over-freezing would go unnoticed
self.assertTrue(all(is_frozen[:exp_froz_params]))
if __name__ == '__main__':
unittest.main()
......@@ -41,13 +41,44 @@ class BackboneWithFPN(nn.Module):
return x
def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d):
def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3):
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained,
norm_layer=norm_layer)
# freeze layers
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
Examples::
>>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
>>> backbone = resnet_fpn_backbone('resnet50', pretrained=True, trainable_layers=3)
>>> # get some dummy image
>>> x = torch.rand(1,3,64,64)
>>> # compute the output
>>> output = backbone(x)
>>> print([(k, v.shape) for k, v in output.items()])
>>> # returns
>>> [('0', torch.Size([1, 256, 16, 16])),
>>> ('1', torch.Size([1, 256, 8, 8])),
>>> ('2', torch.Size([1, 256, 4, 4])),
>>> ('3', torch.Size([1, 256, 2, 2])),
>>> ('pool', torch.Size([1, 256, 1, 1]))]
Arguments:
backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
norm_layer (torchvision.ops): it is recommended to use the default value. For details visit:
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# select layers that wont be frozen
assert trainable_layers <= 5 and trainable_layers >= 0
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
# freeze layers only if pretrained backbone is used
for name, parameter in backbone.named_parameters():
if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)
return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
......
......@@ -289,7 +289,7 @@ model_urls = {
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
"""
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
......@@ -342,11 +342,19 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
# dont freeze any layers if pretrained model or backbone is not used
if not (pretrained or pretrained_backbone):
trainable_backbone_layers = 5
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment