Unverified commit d18c4872, authored by Vasilis Vryniotis, committed by GitHub

Refactor the backbone builders of detection (#4656)

* Refactoring resnet_fpn backbone building.

* Passing the change to *_rcnn and retinanet.

* Applying for faster_rcnn + mobilenetv3

* Applying for ssdlite + mobilenetv3

* Applying for ssd + vgg16

* Update the expected file of retinanet_resnet50_fpn to fix order of initialization.

* Adding full model weights for the VGG16 features.
parent 5e84bab4
One file in this commit is suppressed by a .gitattributes entry or has an unsupported encoding (likely the updated binary expected-output file for retinanet_resnet50_fpn mentioned above).
torchvision/models/detection/backbone_utils.py

 import warnings
-from typing import Callable, Dict, Optional, List
+from typing import Callable, Dict, Optional, List, Union

 from torch import nn, Tensor
 from torchvision.ops import misc as misc_nn_ops

@@ -100,14 +100,14 @@ def resnet_fpn_backbone(
         default a ``LastLevelMaxPool`` is used.
     """
     backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
-    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
+    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)


-def _resnet_backbone_config(
+def _resnet_fpn_extractor(
     backbone: resnet.ResNet,
     trainable_layers: int,
-    returned_layers: Optional[List[int]],
-    extra_blocks: Optional[ExtraFPNBlock],
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> BackboneWithFPN:

     # select layers that wont be frozen
@@ -165,9 +165,18 @@ def mobilenet_backbone(
     returned_layers: Optional[List[int]] = None,
     extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> nn.Module:
-    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
+    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
+
+
+def _mobilenet_extractor(
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+    fpn: bool,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+    backbone = backbone.features

     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
...
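The net effect of this hunk is that callers now build the classification model themselves and hand the instance to the new private extractor, instead of passing a backbone name string. A minimal sketch of the new call pattern, assuming the torchvision tree at this commit (the extractor is a private helper, so its import path is an implementation detail that may move):

from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.ops import misc as misc_nn_ops

# Build the plain classification backbone first, then wrap it with an FPN.
backbone = resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
fpn_backbone = _resnet_fpn_extractor(backbone, trainable_layers=3)
print(fpn_backbone.out_channels)  # 256: the FPN width the detection heads expect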
torchvision/models/detection/faster_rcnn.py

@@ -3,9 +3,12 @@ from torch import nn
 from torchvision.ops import MultiScaleRoIAlign

 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..mobilenetv3 import mobilenet_v3_large
+from ..resnet import resnet50
 from ._utils import overwrite_eps
 from .anchor_utils import AnchorGenerator
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor
 from .generalized_rcnn import GeneralizedRCNN
 from .roi_heads import RoIHeads
 from .rpn import RPNHead, RegionProposalNetwork

@@ -385,7 +388,9 @@ def fasterrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = FasterRCNN(backbone, num_classes, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)

@@ -409,9 +414,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
     if pretrained:
         pretrained_backbone = False
-    backbone = mobilenet_backbone(
-        "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers
-    )
+
+    backbone = mobilenet_v3_large(
+        pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
+    )
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)

     anchor_sizes = (
         (
...
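For end users nothing changes here: the public builders keep their signatures, and only their internals route through the private extractors. A quick sanity check, as a sketch to run against a build that contains this refactor:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, num_classes=5)
model.eval()
with torch.no_grad():
    outputs = model([torch.rand(3, 320, 320)])  # one dict per image with boxes, labels, scores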
torchvision/models/detection/keypoint_rcnn.py

@@ -3,8 +3,10 @@ from torch import nn
 from torchvision.ops import MultiScaleRoIAlign

 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..resnet import resnet50
 from ._utils import overwrite_eps
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .faster_rcnn import FasterRCNN

@@ -367,7 +369,9 @@ def keypointrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
     if pretrained:
         key = "keypointrcnn_resnet50_fpn_coco"
...
torchvision/models/detection/mask_rcnn.py

@@ -4,8 +4,10 @@ from torch import nn
 from torchvision.ops import MultiScaleRoIAlign

 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..resnet import resnet50
 from ._utils import overwrite_eps
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .faster_rcnn import FasterRCNN

 __all__ = [

@@ -364,7 +366,9 @@ def maskrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = MaskRCNN(backbone, num_classes, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress)
...
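The Keypoint R-CNN and Mask R-CNN builders above follow exactly the same pattern as Faster R-CNN. One practical benefit of the split is that the backbone can be configured explicitly before extraction. A hedged sketch of that composition (private helpers; import paths assumed from this diff):

from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.ops import misc as misc_nn_ops

# Prepare the backbone explicitly (e.g. with frozen batch norm), then wrap it and plug it in.
backbone = resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(backbone, trainable_layers=3)
model = MaskRCNN(backbone, num_classes=91)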
torchvision/models/detection/retinanet.py

@@ -9,11 +9,13 @@ from torch import nn, Tensor
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops
+from ...ops import misc as misc_nn_ops
 from ...ops.feature_pyramid_network import LastLevelP6P7
+from ..resnet import resnet50
 from . import _utils as det_utils
 from ._utils import overwrite_eps
 from .anchor_utils import AnchorGenerator
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .transform import GeneralizedRCNNTransform

@@ -630,13 +632,11 @@ def retinanet_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
     # skip P2 because it generates too many anchors (according to their paper)
-    backbone = resnet_fpn_backbone(
-        "resnet50",
-        pretrained_backbone,
-        returned_layers=[2, 3, 4],
-        extra_blocks=LastLevelP6P7(256, 256),
-        trainable_layers=trainable_backbone_layers,
-    )
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
+    )
     model = RetinaNet(backbone, num_classes, **kwargs)
     if pretrained:
...
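The RetinaNet hunk keeps its non-default FPN setup but now expresses it through the new extractor: returned_layers=[2, 3, 4] drops P2 (which would generate too many anchors, per the paper) and LastLevelP6P7 appends the extra pyramid levels. A sketch of the same wiring in isolation, under the same private-API assumptions as above:

from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import LastLevelP6P7

backbone = resnet50(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(
    backbone, trainable_layers=3, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
)
# The returned feature maps correspond to P3-P5 plus the extra P6/P7 levels.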
torchvision/models/detection/ssd.py

@@ -23,7 +23,8 @@ model_urls = {
 backbone_urls = {
     # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
     # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
-    "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth"
+    # Only the `features` weights have proper values, those on the `classifier` module are filled with nans.
+    "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
 }

@@ -519,18 +520,8 @@ class SSDFeatureExtractorVGG(nn.Module):
         return OrderedDict([(str(i), v) for i, v in enumerate(output)])


-def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
-    if backbone_name in backbone_urls:
-        # Use custom backbones more appropriate for SSD
-        arch = backbone_name.split("_")[0]
-        backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features
-        if pretrained:
-            state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress)
-            backbone.load_state_dict(state_dict)
-    else:
-        # Use standard backbones from TorchVision
-        backbone = vgg.__dict__[backbone_name](pretrained=pretrained, progress=progress).features
+def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int):
+    backbone = backbone.features
     # Gather the indices of maxpools. These are the locations of output blocks.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
     num_stages = len(stage_indices)

@@ -609,7 +600,13 @@ def ssd300_vgg16(
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
+
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg.vgg16(pretrained=False, progress=progress)
+    if pretrained_backbone:
+        state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress)
+        backbone.load_state_dict(state_dict)
+
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
     anchor_generator = DefaultBoxGenerator(
         [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
         scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
...
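The SSD hunk moves checkpoint handling out of _vgg_extractor and into the builder, and the new weight file carries the full VGG16 state dict (with the classifier tensors intentionally NaN-filled) so it can be loaded into an unmodified vgg16 module. A sketch of that loading path, with torch.hub.load_state_dict_from_url standing in for the internal _internally_replaced_utils helper:

from torch.hub import load_state_dict_from_url
from torchvision.models import vgg16

url = "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
backbone = vgg16(pretrained=False, progress=True)
state_dict = load_state_dict_from_url(url, progress=True)
backbone.load_state_dict(state_dict)
# Only backbone.features holds meaningful weights; the classifier entries are NaN on purpose.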
torchvision/models/detection/ssdlite.py

 import warnings
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch import nn, Tensor

@@ -117,7 +117,6 @@ class SSDLiteFeatureExtractorMobileNet(nn.Module):
         norm_layer: Callable[..., nn.Module],
         width_mult: float = 1.0,
         min_depth: int = 16,
-        **kwargs: Any,
     ):
         super().__init__()

@@ -156,20 +155,11 @@ class SSDLiteFeatureExtractorMobileNet(nn.Module):
 def _mobilenet_extractor(
-    backbone_name: str,
-    progress: bool,
-    pretrained: bool,
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
     trainable_layers: int,
     norm_layer: Callable[..., nn.Module],
-    **kwargs: Any,
 ):
-    backbone = mobilenet.__dict__[backbone_name](
-        pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs
-    ).features
-    if not pretrained:
-        # Change the default initialization scheme if not pretrained
-        _normal_init(backbone)
+    backbone = backbone.features
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]

@@ -183,7 +173,7 @@ def _mobilenet_extractor(
         for parameter in b.parameters():
             parameter.requires_grad_(False)

-    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
+    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)


 def ssdlite320_mobilenet_v3_large(

@@ -235,14 +225,16 @@ def ssdlite320_mobilenet_v3_large(
     if norm_layer is None:
         norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

+    backbone = mobilenet.mobilenet_v3_large(
+        pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
+    )
+    if not pretrained_backbone:
+        # Change the default initialization scheme if not pretrained
+        _normal_init(backbone)
     backbone = _mobilenet_extractor(
-        "mobilenet_v3_large",
-        progress,
-        pretrained_backbone,
+        backbone,
         trainable_backbone_layers,
         norm_layer,
-        reduced_tail=reduce_tail,
-        **kwargs,
     )
     size = (320, 320)
...
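The SSDlite change mirrors the rest: the builder constructs the MobileNetV3 backbone (and re-initializes it when not pretrained) before calling the extractor, and the **kwargs plumbing through SSDLiteFeatureExtractorMobileNet is dropped. A sketch of the new composition, using the private helpers named in this diff (reduce_tail handling omitted for brevity):

from functools import partial
from torch import nn
from torchvision.models import mobilenet_v3_large
from torchvision.models.detection.ssdlite import _mobilenet_extractor, _normal_init

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
backbone = mobilenet_v3_large(pretrained=False, norm_layer=norm_layer)
_normal_init(backbone)  # change the default initialization scheme when the backbone is not pretrained
backbone = _mobilenet_extractor(backbone, trainable_layers=6, norm_layer=norm_layer)  # 6 keeps all stages trainable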
torchvision/prototype/models/detection/backbone_utils.py (deleted; the prototype builders now import the shared extractors directly)

-from typing import Callable, Optional, List
-
-from torch import nn
-
-from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
-from .. import resnet
-from .._api import Weights
-
-
-def resnet_fpn_backbone(
-    backbone_name: str,
-    weights: Optional[Weights],
-    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
-    trainable_layers: int = 3,
-    returned_layers: Optional[List[int]] = None,
-    extra_blocks: Optional[ExtraFPNBlock] = None,
-) -> BackboneWithFPN:
-    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
-    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
torchvision/prototype/models/detection/faster_rcnn.py

 import warnings
 from typing import Any, Optional

-from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
+from ....models.detection.faster_rcnn import (
+    _validate_trainable_layers,
+    _resnet_fpn_extractor,
+    FasterRCNN,
+    misc_nn_ops,
+    overwrite_eps,
+)
 from ...transforms.presets import CocoEval
 from .._api import Weights, WeightEntry
 from .._meta import _COCO_CATEGORIES
-from ..resnet import ResNet50Weights
-from .backbone_utils import resnet_fpn_backbone
+from ..resnet import ResNet50Weights, resnet50

 __all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]

@@ -49,7 +54,8 @@ def fasterrcnn_resnet50_fpn(
         weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
     )

-    backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

     if weights is not None:
...
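The prototype builder now mirrors the stable one but drives pretrained behaviour through weight enums instead of booleans. The prototype namespace is experimental, so the sketch below relies only on what this diff shows; the module path is inferred from the relative imports above, and passing None for both weight arguments should yield a randomly initialized model with no downloads:

from torchvision.prototype.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn

# weights / weights_backbone replace pretrained / pretrained_backbone in the prototype API.
model = fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)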