Unverified Commit f8e2291d authored by Urwa Muaz's avatar Urwa Muaz Committed by GitHub
Browse files

Feature/layer freezing maskrcnn keypointrcnn (#2242)

* add layer freezing param to maskrcnn_resnet50_fpn

* freeze ayer param in keypointrcnn_resnet50_fpn

* layer freeze tests for mask and keypoint rcnn

* correct linting errors

* correct linting errors.

* correct linting errors
parent c558be6b
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from torchvision.models.detection import _utils from torchvision.models.detection import _utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection.transform import GeneralizedRCNNTransform
import unittest import unittest
from torchvision.models.detection import fasterrcnn_resnet50_fpn from torchvision.models.detection import fasterrcnn_resnet50_fpn, maskrcnn_resnet50_fpn, keypointrcnn_resnet50_fpn
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -35,6 +35,36 @@ class Tester(unittest.TestCase): ...@@ -35,6 +35,36 @@ class Tester(unittest.TestCase):
# check that expected initial number of layers are frozen # check that expected initial number of layers are frozen
self.assertTrue(all(is_frozen[:exp_froz_params])) self.assertTrue(all(is_frozen[:exp_froz_params]))
def test_maskrcnn_resnet50_fpn_frozen_layers(self):
# we know how many initial layers and parameters of the maskrcnn should
# be frozen for each trainable_backbone_layers paramter value
# i.e all 53 params are frozen if trainable_backbone_layers=0
# ad first 24 params are frozen if trainable_backbone_layers=2
expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0}
for train_layers, exp_froz_params in expected_frozen_params.items():
model = maskrcnn_resnet50_fpn(pretrained=True, progress=False,
num_classes=91, pretrained_backbone=False,
trainable_backbone_layers=train_layers)
# boolean list that is true if the parameter at that index is frozen
is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
# check that expected initial number of layers in maskrcnn are frozen
self.assertTrue(all(is_frozen[:exp_froz_params]))
def test_keypointrcnn_resnet50_fpn_frozen_layers(self):
# we know how many initial layers and parameters of the keypointrcnn should
# be frozen for each trainable_backbone_layers paramter value
# i.e all 53 params are frozen if trainable_backbone_layers=0
# ad first 24 params are frozen if trainable_backbone_layers=2
expected_frozen_params = {0: 53, 1: 43, 2: 24, 3: 11, 4: 1, 5: 0}
for train_layers, exp_froz_params in expected_frozen_params.items():
model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False,
num_classes=2, pretrained_backbone=False,
trainable_backbone_layers=train_layers)
# boolean list that is true if the parameter at that index is frozen
is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
# check that expected initial number of layers in keypointrcnn are frozen
self.assertTrue(all(is_frozen[:exp_froz_params]))
def test_transform_copy_targets(self): def test_transform_copy_targets(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)] image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)]
......
...@@ -270,7 +270,7 @@ model_urls = { ...@@ -270,7 +270,7 @@ model_urls = {
def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=2, num_keypoints=17, num_classes=2, num_keypoints=17,
pretrained_backbone=True, **kwargs): pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
""" """
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
...@@ -314,11 +314,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, ...@@ -314,11 +314,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
Arguments: Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
""" """
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
# dont freeze any layers if pretrained model or backbone is not used
if not (pretrained or pretrained_backbone):
trainable_backbone_layers = 5
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained: if pretrained:
key = 'keypointrcnn_resnet50_fpn_coco' key = 'keypointrcnn_resnet50_fpn_coco'
......
...@@ -265,7 +265,7 @@ model_urls = { ...@@ -265,7 +265,7 @@ model_urls = {
def maskrcnn_resnet50_fpn(pretrained=False, progress=True, def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs): num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs):
""" """
Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.
...@@ -310,11 +310,19 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, ...@@ -310,11 +310,19 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
Arguments: Arguments:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
num_classes (int): number of output classes of the model (including the background)
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
""" """
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
# dont freeze any layers if pretrained model or backbone is not used
if not (pretrained or pretrained_backbone):
trainable_backbone_layers = 5
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs) model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment