Unverified commit bf211dac, authored by Vasilis Vryniotis and committed by GitHub

Add MobileNetV3 architecture for Detection (#3253)

* Minor refactoring of a private method to make it reusable.

* Adding FasterRCNN + MobileNetV3 models, with and without FPN.

* Reducing resolution to 320-640 and anchor sizes to 16-256.

* Increase anchor sizes.

* Adding rpn score threshold param on the train script.

* Adding trainable_backbone_layers param on the train script.

* Adding rpn_score_thresh param directly in fasterrcnn_mobilenet_v3_large_fpn.

* Remove fasterrcnn_mobilenet_v3_large prototype and update expected file.

* Update documentation and add weights.

* Use built-in Identity.

* Fix spelling.
parent 0985533e
@@ -358,13 +358,14 @@ models return the predictions of the following classes:
 Here are the summary of the accuracies for the models trained on
 the instances set of COCO train2017 and evaluated on COCO val2017.
 
-================================    =======  ========  ===========
+==================================  =======  ========  ===========
 Network                             box AP   mask AP   keypoint AP
-================================    =======  ========  ===========
+==================================  =======  ========  ===========
 Faster R-CNN ResNet-50 FPN          37.0     -         -
+Faster R-CNN MobileNetV3-Large FPN  23.0     -         -
 RetinaNet ResNet-50 FPN             36.4     -         -
 Mask R-CNN ResNet-50 FPN            37.9     34.6      -
-================================    =======  ========  ===========
+==================================  =======  ========  ===========
 
 For person keypoint detection, the accuracies for the pre-trained
 models are as follows
@@ -414,20 +415,22 @@ For test time, we report the time for the model evaluation and postprocessing
 (including mask pasting in image), but not the time for computing the
 precision-recall.
 
-==============================      ===================  ==================  ===========
+==================================  ===================  ==================  ===========
 Network                             train time (s / it)  test time (s / it)  memory (GB)
-==============================      ===================  ==================  ===========
+==================================  ===================  ==================  ===========
 Faster R-CNN ResNet-50 FPN          0.2288               0.0590              5.2
+Faster R-CNN MobileNetV3-Large FPN  0.0978               0.0376              0.6
 RetinaNet ResNet-50 FPN             0.2514               0.0939              4.1
 Mask R-CNN ResNet-50 FPN            0.2728               0.0903              5.4
 Keypoint R-CNN ResNet-50 FPN        0.3789               0.1242              6.8
-==============================      ===================  ==================  ===========
+==================================  ===================  ==================  ===========
 
 
 Faster R-CNN
 ------------
 
 .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
+.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
 
 RetinaNet
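As a quick orientation for the newly documented model, here is a minimal inference sketch (not part of this commit's diff). The input sizes and the 0.5 score cut-off are arbitrary, and `pretrained=True` downloads the COCO weights referenced above:

```python
import torch
import torchvision

# Load the new Faster R-CNN with a MobileNetV3-Large FPN backbone.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
model.eval()

# Detection models take a list of 3xHxW float tensors with values in [0, 1].
images = [torch.rand(3, 320, 480), torch.rand(3, 480, 640)]
with torch.no_grad():
    outputs = model(images)

# Each output dict holds 'boxes', 'labels' and 'scores'; keep the confident ones.
for out in outputs:
    keep = out["scores"] > 0.5  # arbitrary demo threshold
    print(out["boxes"][keep].shape, out["labels"][keep])
```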
@@ -20,13 +20,20 @@ You must modify the following flags:
 
 Except otherwise noted, all models have been trained on 8x V100 GPUs.
 
-### Faster R-CNN
+### Faster R-CNN ResNet-50 FPN
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
     --lr-steps 16 22 --aspect-ratio-group-factor 3
 ```
 
+### Faster R-CNN MobileNetV3-Large FPN
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
+    --lr-steps 16 22 --aspect-ratio-group-factor 3
+```
+
 ### RetinaNet
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
@@ -92,9 +92,12 @@ def main(args):
         collate_fn=utils.collate_fn)
 
     print("Creating model")
-    kwargs = {}
+    kwargs = {
+        "trainable_backbone_layers": args.trainable_backbone_layers
+    }
     if "rcnn" in args.model:
-        kwargs["rpn_score_thresh"] = 0.0
+        if args.rpn_score_thresh is not None:
+            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
     model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
                                                               **kwargs)
     model.to(device)
@@ -177,6 +180,9 @@ if __name__ == "__main__":
     parser.add_argument('--resume', default='', help='resume from checkpoint')
     parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+    parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
+    parser.add_argument('--trainable-backbone-layers', default=None, type=int,
+                        help='number of trainable layers of backbone')
     parser.add_argument(
         "--test-only",
         dest="test_only",
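To illustrate how the new `--rpn-score-thresh` and `--trainable-backbone-layers` flags reach the model factory, here is a standalone sketch mirroring the change above; the argument values are made up for illustration:

```python
import torchvision

# Hypothetical parsed command-line values (would come from argparse in train.py).
args_model = "fasterrcnn_mobilenet_v3_large_fpn"
args_rpn_score_thresh = 0.05
args_trainable_backbone_layers = 3

# Same forwarding logic as the train.py hunk above.
kwargs = {"trainable_backbone_layers": args_trainable_backbone_layers}
if "rcnn" in args_model:
    if args_rpn_score_thresh is not None:
        kwargs["rpn_score_thresh"] = args_rpn_score_thresh

# pretrained/pretrained_backbone set to False here just to keep the sketch offline.
model = torchvision.models.detection.__dict__[args_model](
    num_classes=91, pretrained=False, pretrained_backbone=False, **kwargs)
print(type(model).__name__)  # FasterRCNN
```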
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -37,6 +37,7 @@ script_model_unwrapper = {
     'googlenet': lambda x: x.logits,
     'inception_v3': lambda x: x.logits,
     "fasterrcnn_resnet50_fpn": lambda x: x[1],
+    "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
     "maskrcnn_resnet50_fpn": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
@@ -105,6 +106,8 @@ class ModelTester(TestCase):
         if "retinanet" in name:
             # Reduce the default threshold to ensure the returned boxes are not empty.
             kwargs["score_thresh"] = 0.01
+        elif "fasterrcnn_mobilenet" in name:
+            kwargs["box_score_thresh"] = 0.02076
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
         model.eval().to(device=dev)
         input_shape = (3, 300, 300)
@@ -97,14 +97,15 @@ class Tester(unittest.TestCase):
         self.assertEqual(labels[0].dtype, torch.int64)
 
     def test_forward_negative_sample_frcnn(self):
-        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100)
+        for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
+            model = torchvision.models.detection.__dict__[name](
+                num_classes=2, min_size=100, max_size=100)
 
-        images, targets = self._make_empty_sample()
-        loss_dict = model(images, targets)
+            images, targets = self._make_empty_sample()
+            loss_dict = model(images, targets)
 
-        self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
-        self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
+            self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
+            self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
 
     def test_forward_negative_sample_mrcnn(self):
         model = torchvision.models.detection.maskrcnn_resnet50_fpn(
@@ -130,7 +131,7 @@ class Tester(unittest.TestCase):
 
     def test_forward_negative_sample_retinanet(self):
         model = torchvision.models.detection.retinanet_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100)
+            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)
 
         images, targets = self._make_empty_sample()
         loss_dict = model(images, targets)
@@ -36,17 +36,17 @@ class Tester(unittest.TestCase):
     def test_validate_resnet_inputs_detection(self):
         # default number of backbone layers to train
-        ret = backbone_utils._validate_resnet_trainable_layers(
-            pretrained=True, trainable_backbone_layers=None)
+        ret = backbone_utils._validate_trainable_layers(
+            pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
         self.assertEqual(ret, 3)
         # can't go beyond 5
         with self.assertRaises(AssertionError):
-            ret = backbone_utils._validate_resnet_trainable_layers(
-                pretrained=True, trainable_backbone_layers=6)
+            ret = backbone_utils._validate_trainable_layers(
+                pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
 
         # if not pretrained, should use all trainable layers and warn
         with self.assertWarns(UserWarning):
-            ret = backbone_utils._validate_resnet_trainable_layers(
-                pretrained=False, trainable_backbone_layers=0)
+            ret = backbone_utils._validate_trainable_layers(
+                pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
         self.assertEqual(ret, 5)
 
     def test_transform_copy_targets(self):
 import warnings
-from collections import OrderedDict
 from torch import nn
 from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
 from torchvision.ops import misc as misc_nn_ops
 from .._utils import IntermediateLayerGetter
+from .. import mobilenet
 from .. import resnet
@@ -108,17 +108,65 @@ def resnet_fpn_backbone(
     return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
 
 
-def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
+def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
     # dont freeze any layers if pretrained model or backbone is not used
     if not pretrained:
         if trainable_backbone_layers is not None:
             warnings.warn(
                 "Changing trainable_backbone_layers has not effect if "
                 "neither pretrained nor pretrained_backbone have been set to True, "
-                "falling back to trainable_backbone_layers=5 so that all layers are trainable")
-        trainable_backbone_layers = 5
-    # by default, freeze first 2 blocks following Faster R-CNN
+                "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
+        trainable_backbone_layers = max_value
+
+    # by default freeze first blocks
     if trainable_backbone_layers is None:
-        trainable_backbone_layers = 3
-    assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
+        trainable_backbone_layers = default_value
+    assert 0 <= trainable_backbone_layers <= max_value
     return trainable_backbone_layers
+
+
+def mobilenet_backbone(
+    backbone_name,
+    pretrained,
+    fpn,
+    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers=2,
+    returned_layers=None,
+    extra_blocks=None
+):
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
+
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we wont freeze
+    assert 0 <= trainable_layers <= num_stages
+    freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    # freeze layers only if pretrained backbone is used
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    out_channels = 256
+    if fpn:
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        if returned_layers is None:
+            returned_layers = [num_stages - 2, num_stages - 1]
+        assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
+        return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}
+
+        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
+        return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
+    else:
+        m = nn.Sequential(
+            backbone,
+            # depthwise linear combination of channels to reduce their size
+            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
+        )
+        m.out_channels = out_channels
+        return m
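A small sanity-check sketch for the new `mobilenet_backbone` helper follows (not part of the commit). The feature-map key names in the comment are what `BackboneWithFPN` with `LastLevelMaxPool` is expected to produce; exact keys and spatial sizes depend on the backbone:

```python
import torch
from torchvision.models.detection.backbone_utils import mobilenet_backbone

# FPN variant: forward returns a dict of multi-scale feature maps, each with 256 channels.
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=True,
                              trainable_layers=2)
feats = backbone(torch.rand(1, 3, 320, 320))
for name, feat in feats.items():
    print(name, feat.shape)  # e.g. keys '0', '1', 'pool', all with 256 channels

# Non-FPN variant: a plain nn.Sequential ending in a 1x1 conv down to 256 channels.
no_fpn = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=False)
print(no_fpn(torch.rand(1, 3, 320, 320)).shape)
```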
-from collections import OrderedDict
-
 import torch
 from torch import nn
 import torch.nn.functional as F
 
-from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops import MultiScaleRoIAlign
 
 from ._utils import overwrite_eps
@@ -15,11 +12,11 @@ from .generalized_rcnn import GeneralizedRCNN
 from .rpn import RPNHead, RegionProposalNetwork
 from .roi_heads import RoIHeads
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
 
 
 __all__ = [
-    "FasterRCNN", "fasterrcnn_resnet50_fpn",
+    "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
 ]
@@ -291,6 +288,8 @@ class FastRCNNPredictor(nn.Module):
 model_urls = {
     'fasterrcnn_resnet50_fpn_coco':
         'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
+    'fasterrcnn_mobilenet_v3_large_fpn_coco':
+        'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
 }
@@ -353,9 +352,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
@@ -368,3 +366,48 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
         model.load_state_dict(state_dict)
         overwrite_eps(model, 0.0)
     return model
+
+
+def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                      trainable_backbone_layers=None, min_size=320, max_size=640,
+                                      rpn_score_thresh=0.05, **kwargs):
+    """
+    Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
+    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+    """
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+
+    if pretrained:
+        pretrained_backbone = False
+    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
+                                  trainable_layers=trainable_backbone_layers)
+
+    anchor_sizes = ((32, 64, 128, 256, 512,),) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+
+    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
+                       min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
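For readers who want to fine-tune the new builder on a custom dataset, here is a hedged sketch using the existing `FastRCNNPredictor` head; the class count and optimizer settings are illustrative and not part of this commit:

```python
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Start from the COCO-pretrained weights, then swap the box predictor for a
# hypothetical dataset with 2 object classes + background.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
num_classes = 3
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Only parameters left trainable (frozen backbone layers excluded) receive updates.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
```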
@@ -7,7 +7,7 @@ from ._utils import overwrite_eps
 from ..utils import load_state_dict_from_url
 
 from .faster_rcnn import FasterRCNN
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 
 
 __all__ = [
@@ -322,9 +322,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
@@ -8,7 +8,7 @@ from ._utils import overwrite_eps
 from ..utils import load_state_dict_from_url
 
 from .faster_rcnn import FasterRCNN
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 
 
 __all__ = [
     "MaskRCNN", "maskrcnn_resnet50_fpn",
@@ -317,9 +317,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
@@ -12,14 +12,14 @@ from ..utils import load_state_dict_from_url
 from . import _utils as det_utils
 from .anchor_utils import AnchorGenerator
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 from ...ops.feature_pyramid_network import LastLevelP6P7
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops
 
 
 __all__ = [
-    "RetinaNet", "retinanet_resnet50_fpn",
+    "RetinaNet", "retinanet_resnet50_fpn"
 ]
@@ -605,9 +605,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
-    # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
@@ -18,16 +18,6 @@ model_urls = {
 }
 
 
-class Identity(nn.Module):
-    def __init__(self, inplace: bool = False):
-        super().__init__()
-        self.inplace = inplace
-
-    def forward(self, input: Tensor) -> Tensor:
-        return input
-
-
 class SqueezeExcitation(nn.Module):
 
     def __init__(self, input_channels: int, squeeze_factor: int = 4):
@@ -88,7 +78,7 @@ class InvertedResidual(nn.Module):
         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
-                                       activation_layer=Identity))
+                                       activation_layer=nn.Identity))
 
         self.block = nn.Sequential(*layers)
         self.out_channels = cnf.out_channels
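The removal above relies on `torch.nn.Identity` being a drop-in no-op module, so the custom `Identity` class is redundant. A tiny check, for illustration only:

```python
import torch
from torch import nn

# nn.Identity returns its input unchanged, exactly like the removed custom class.
x = torch.randn(2, 8, 4, 4)
assert torch.equal(nn.Identity()(x), x)

# Used the same way the projection layer in InvertedResidual now uses it as activation_layer.
proj = nn.Sequential(nn.Conv2d(8, 16, kernel_size=1), nn.BatchNorm2d(16), nn.Identity())
print(proj(x).shape)  # torch.Size([2, 16, 4, 4])
```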