Unverified commit 08cc9a7f, authored by Vasilis Vryniotis and committed by GitHub

Post-paper Detection Optimizations (#5444)

* Use frozen BN only if pre-trained.

* Add LSJ and the ability to train from scratch.

* Fixing formatter

* Adding `--opt` and `--norm-weight-decay` support in Detection.

* Fix error message

* Make ScaleJitter proportional.

* Adding more norm layers in split_normalization_params.

* Add FixedSizeCrop

* Temporary fix for fill values on PIL

* Fix the bug on fill.

* Add RandomShortestSize.

* Skip resize when an augmentation method is used.

* multiscale in [480, 800]

* Add missing star

* Add new RetinaNet variant.

* Add tests.

* Update expected file for old retina

* Fixing tests

* Add FrozenBN to retinav2

* Fix network initialization issues

* Adding BN support in MaskRCNNHeads and FPN

* Adding support for FasterRCNNHeads

* Introduce norm_layers in backbone utils.

* Bigger RPN head + 2x rcnn v2 models.

* Adding gIoU support to retinanet

* Fix assert

* Add back nesterov momentum

* Rename and extend `FastRCNNConvFCHead` to support arbitrary FCs

* Fix linter
parent 63576c9f
File suppressed by a .gitattributes entry or the file's encoding is unsupported. (3 files)
@@ -195,11 +195,14 @@ script_model_unwrapper = {
"googlenet": lambda x: x.logits,
"inception_v3": lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1],
"fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn": lambda x: x[1], "maskrcnn_resnet50_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn_v2": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1], "keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1], "retinanet_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn_v2": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1], "ssd300_vgg16": lambda x: x[1],
"ssdlite320_mobilenet_v3_large": lambda x: x[1], "ssdlite320_mobilenet_v3_large": lambda x: x[1],
"fcos_resnet50_fpn": lambda x: x[1], "fcos_resnet50_fpn": lambda x: x[1],
...@@ -227,6 +230,7 @@ autocast_flaky_numerics = ( ...@@ -227,6 +230,7 @@ autocast_flaky_numerics = (
"fcn_resnet101", "fcn_resnet101",
"lraspp_mobilenet_v3_large", "lraspp_mobilenet_v3_large",
"maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn",
"maskrcnn_resnet50_fpn_v2",
)
# The tests for the following quantized models are flaky possibly due to inconsistent
@@ -246,6 +250,13 @@ _model_params = {
"max_size": 224,
"input_shape": (3, 224, 224),
},
"retinanet_resnet50_fpn_v2": {
"num_classes": 20,
"score_thresh": 0.01,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"keypointrcnn_resnet50_fpn": { "keypointrcnn_resnet50_fpn": {
"num_classes": 2, "num_classes": 2,
"min_size": 224, "min_size": 224,
...@@ -259,6 +270,12 @@ _model_params = { ...@@ -259,6 +270,12 @@ _model_params = {
"max_size": 224, "max_size": 224,
"input_shape": (3, 224, 224), "input_shape": (3, 224, 224),
}, },
"fasterrcnn_resnet50_fpn_v2": {
"num_classes": 20,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"fcos_resnet50_fpn": { "fcos_resnet50_fpn": {
"num_classes": 2, "num_classes": 2,
"score_thresh": 0.05, "score_thresh": 0.05,
...@@ -272,6 +289,12 @@ _model_params = { ...@@ -272,6 +289,12 @@ _model_params = {
"max_size": 224, "max_size": 224,
"input_shape": (3, 224, 224), "input_shape": (3, 224, 224),
}, },
"maskrcnn_resnet50_fpn_v2": {
"num_classes": 10,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"fasterrcnn_mobilenet_v3_large_fpn": { "fasterrcnn_mobilenet_v3_large_fpn": {
"box_score_thresh": 0.02076, "box_score_thresh": 0.02076,
}, },
...@@ -311,6 +334,10 @@ _model_tests_values = { ...@@ -311,6 +334,10 @@ _model_tests_values = {
"max_trainable": 5, "max_trainable": 5,
"n_trn_params_per_layer": [36, 46, 65, 78, 88, 89], "n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
}, },
"retinanet_resnet50_fpn_v2": {
"max_trainable": 5,
"n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
},
"keypointrcnn_resnet50_fpn": { "keypointrcnn_resnet50_fpn": {
"max_trainable": 5, "max_trainable": 5,
"n_trn_params_per_layer": [48, 58, 77, 90, 100, 101], "n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
...@@ -319,10 +346,18 @@ _model_tests_values = { ...@@ -319,10 +346,18 @@ _model_tests_values = {
"max_trainable": 5, "max_trainable": 5,
"n_trn_params_per_layer": [30, 40, 59, 72, 82, 83], "n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
}, },
"fasterrcnn_resnet50_fpn_v2": {
"max_trainable": 5,
"n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
},
"maskrcnn_resnet50_fpn": { "maskrcnn_resnet50_fpn": {
"max_trainable": 5, "max_trainable": 5,
"n_trn_params_per_layer": [42, 52, 71, 84, 94, 95], "n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
}, },
"maskrcnn_resnet50_fpn_v2": {
"max_trainable": 5,
"n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
},
"fasterrcnn_mobilenet_v3_large_fpn": { "fasterrcnn_mobilenet_v3_large_fpn": {
"max_trainable": 6, "max_trainable": 6,
"n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100], "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
......
import math
from collections import OrderedDict
-from typing import List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor, nn
-from torchvision.ops.misc import FrozenBatchNorm2d
from torch.nn import functional as F
from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss
class BalancedPositiveNegativeSampler:
@@ -507,3 +508,26 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
return _fake_cast_onnx(min_kval)
def _box_loss(
type: str,
box_coder: BoxCoder,
anchors_per_image: Tensor,
matched_gt_boxes_per_image: Tensor,
bbox_regression_per_image: Tensor,
cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")
if type == "l1":
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
elif type == "smooth_l1":
target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
else: # giou
bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
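
For reference, a minimal sketch of exercising the new `_box_loss` helper on its own (toy anchors and boxes invented for illustration; assumes a torchvision build that already contains this change):

```python
import torch
from torchvision.models.detection import _utils as det_utils

# two toy anchors and their matched ground-truth boxes in (x1, y1, x2, y2) format
anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
matched_gt = torch.tensor([[1.0, 1.0, 11.0, 11.0], [4.0, 4.0, 16.0, 16.0]])
box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

# "l1"/"smooth_l1" compare raw regression outputs against encoded targets,
# while "giou" first decodes the outputs back into boxes
bbox_regression = box_coder.encode_single(matched_gt, anchors)  # a "perfect" prediction
print(det_utils._box_loss("l1", box_coder, anchors, matched_gt, bbox_regression))    # ~0
print(det_utils._box_loss("giou", box_coder, anchors, matched_gt, bbox_regression))  # ~0
```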
@@ -25,6 +25,7 @@ class BackboneWithFPN(nn.Module):
in_channels_list (List[int]): number of channels for each feature map
that is returned, in the order they are present in the OrderedDict
out_channels (int): number of channels in the FPN.
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
Attributes:
out_channels (int): the number of channels in the FPN
"""
@@ -36,6 +37,7 @@ class BackboneWithFPN(nn.Module):
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
@@ -47,6 +49,7 @@ class BackboneWithFPN(nn.Module):
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=extra_blocks,
norm_layer=norm_layer,
)
self.out_channels = out_channels
@@ -115,6 +118,7 @@ def _resnet_fpn_extractor(
trainable_layers: int,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
# select layers that wont be frozen
@@ -139,7 +143,9 @@ def _resnet_fpn_extractor(
in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
out_channels = 256
-return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
return BackboneWithFPN(
backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
)
def _validate_trainable_layers(
@@ -194,6 +200,7 @@ def _mobilenet_extractor(
trainable_layers: int,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -222,7 +229,9 @@ def _mobilenet_extractor(
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
-return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
return BackboneWithFPN(
backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
)
else:
m = nn.Sequential(
backbone,
...
-from typing import Any, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -24,14 +24,22 @@ from .transform import GeneralizedRCNNTransform
__all__ = [
"FasterRCNN",
"FasterRCNN_ResNet50_FPN_Weights",
"FasterRCNN_ResNet50_FPN_V2_Weights",
"FasterRCNN_MobileNet_V3_Large_FPN_Weights", "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
"FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
"fasterrcnn_resnet50_fpn", "fasterrcnn_resnet50_fpn",
"fasterrcnn_resnet50_fpn_v2",
"fasterrcnn_mobilenet_v3_large_fpn", "fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn",
] ]
def _default_anchorgen():
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
return AnchorGenerator(anchor_sizes, aspect_ratios)
class FasterRCNN(GeneralizedRCNN):
"""
Implements Faster R-CNN.
@@ -216,9 +224,7 @@ class FasterRCNN(GeneralizedRCNN):
out_channels = backbone.out_channels
if rpn_anchor_generator is None:
-anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
-aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
rpn_anchor_generator = _default_anchorgen()
if rpn_head is None:
rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
@@ -298,6 +304,43 @@ class TwoMLPHead(nn.Module):
return x
class FastRCNNConvFCHead(nn.Sequential):
def __init__(
self,
input_size: Tuple[int, int, int],
conv_layers: List[int],
fc_layers: List[int],
norm_layer: Optional[Callable[..., nn.Module]] = None,
):
"""
Args:
input_size (Tuple[int, int, int]): the input size in CHW format.
conv_layers (list): feature dimensions of each Convolution layer
fc_layers (list): feature dimensions of each FCN layer
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
"""
in_channels, in_height, in_width = input_size
blocks = []
previous_channels = in_channels
for current_channels in conv_layers:
blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
previous_channels = current_channels
blocks.append(nn.Flatten())
previous_channels = previous_channels * in_height * in_width
for current_channels in fc_layers:
blocks.append(nn.Linear(previous_channels, current_channels))
blocks.append(nn.ReLU(inplace=True))
previous_channels = current_channels
super().__init__(*blocks)
for layer in self.modules():
if isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
if layer.bias is not None:
nn.init.zeros_(layer.bias)
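
For reference, a quick shape check of the new `FastRCNNConvFCHead` (configuration mirrors the one used by the v2 builders below; purely illustrative):

```python
import torch
from torch import nn
from torchvision.models.detection.faster_rcnn import FastRCNNConvFCHead

# four 3x3 convs with BatchNorm followed by a single 1024-wide fully connected layer
box_head = FastRCNNConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
pooled = torch.rand(8, 256, 7, 7)  # 8 RoIs with 256-channel 7x7 pooled features
print(box_head(pooled).shape)      # torch.Size([8, 1024])
```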
class FastRCNNPredictor(nn.Module):
"""
Standard classification + bounding box regression layers
@@ -349,6 +392,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1
class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
pass
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
@@ -481,6 +528,66 @@ def fasterrcnn_resnet50_fpn(
return model
def fasterrcnn_resnet50_fpn_v2(
*,
weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
"""
Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone.
Reference: `"Benchmarking Detection Transfer Learning with Vision Transformers"
<https://arxiv.org/abs/2111.11429>`_.
See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details.
Args:
weights (FasterRCNN_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
backbone = resnet50(weights=weights_backbone, progress=progress)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
rpn_anchor_generator = _default_anchorgen()
rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
box_head = FastRCNNConvFCHead(
(backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
)
model = FasterRCNN(
backbone,
num_classes=num_classes,
rpn_anchor_generator=rpn_anchor_generator,
rpn_head=rpn_head,
box_head=box_head,
**kwargs,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
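
For reference, a hedged usage sketch of the new builder. Note that `FasterRCNN_ResNet50_FPN_V2_Weights` is still empty in this commit, so the model is built either from scratch or on top of an ImageNet backbone:

```python
import torch
from torchvision.models.resnet import ResNet50_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2

model = fasterrcnn_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V1, num_classes=91)
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 480, 640)])
print(predictions[0]["boxes"].shape)  # (num_detections, 4)
```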
def _fasterrcnn_mobilenet_v3_large_fpn(
*,
weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
...
from collections import OrderedDict
-from typing import Any, Optional
from typing import Any, Callable, Optional
from torch import nn
from torchvision.ops import MultiScaleRoIAlign
@@ -12,13 +12,15 @@ from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
from ._utils import overwrite_eps
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from .faster_rcnn import FasterRCNN
from .faster_rcnn import FasterRCNN, FastRCNNConvFCHead, RPNHead, _default_anchorgen
__all__ = [
"MaskRCNN",
"MaskRCNN_ResNet50_FPN_Weights",
"MaskRCNN_ResNet50_FPN_V2_Weights",
"maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn",
"maskrcnn_resnet50_fpn_v2",
] ]
...@@ -264,28 +266,68 @@ class MaskRCNN(FasterRCNN): ...@@ -264,28 +266,68 @@ class MaskRCNN(FasterRCNN):
class MaskRCNNHeads(nn.Sequential): class MaskRCNNHeads(nn.Sequential):
def __init__(self, in_channels, layers, dilation): _version = 2
def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None):
""" """
Args: Args:
in_channels (int): number of input channels in_channels (int): number of input channels
layers (list): feature dimensions of each FCN layer layers (list): feature dimensions of each FCN layer
dilation (int): dilation rate of kernel dilation (int): dilation rate of kernel
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
""" """
d = OrderedDict() blocks = []
next_feature = in_channels next_feature = in_channels
for layer_idx, layer_features in enumerate(layers, 1): for layer_features in layers:
d[f"mask_fcn{layer_idx}"] = nn.Conv2d( blocks.append(
next_feature, layer_features, kernel_size=3, stride=1, padding=dilation, dilation=dilation misc_nn_ops.Conv2dNormActivation(
next_feature,
layer_features,
kernel_size=3,
stride=1,
padding=dilation,
dilation=dilation,
norm_layer=norm_layer,
)
) )
d[f"relu{layer_idx}"] = nn.ReLU(inplace=True)
next_feature = layer_features next_feature = layer_features
super().__init__(d) super().__init__(*blocks)
for name, param in self.named_parameters(): for layer in self.modules():
if "weight" in name: if isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
# elif "bias" in name: if layer.bias is not None:
# nn.init.constant_(param, 0) nn.init.zeros_(layer.bias)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
num_blocks = len(self)
for i in range(num_blocks):
for type in ["weight", "bias"]:
old_key = f"{prefix}mask_fcn{i+1}.{type}"
new_key = f"{prefix}{i}.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
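
For reference, the reworked `MaskRCNNHeads` is now a plain `nn.Sequential` of `Conv2dNormActivation` blocks, so it can be smoke-tested in isolation (shapes are illustrative only):

```python
import torch
from torch import nn
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads

# four 3x3 convs with BatchNorm, matching the maskrcnn_resnet50_fpn_v2 configuration below
mask_head = MaskRCNNHeads(256, [256, 256, 256, 256], dilation=1, norm_layer=nn.BatchNorm2d)
features = torch.rand(4, 256, 14, 14)  # 4 RoIs with 256-channel 14x14 mask features
print(mask_head(features).shape)       # torch.Size([4, 256, 14, 14])
```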
class MaskRCNNPredictor(nn.Sequential):
@@ -326,6 +368,10 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1
class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
pass
@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
@@ -418,3 +464,65 @@ def maskrcnn_resnet50_fpn(
overwrite_eps(model, 0.0)
return model
def maskrcnn_resnet50_fpn_v2(
*,
weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> MaskRCNN:
"""
Constructs an improved MaskRCNN model with a ResNet-50-FPN backbone.
Reference: `"Benchmarking Detection Transfer Learning with Vision Transformers"
<https://arxiv.org/abs/2111.11429>`_.
See :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details.
Args:
weights (MaskRCNN_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
backbone = resnet50(weights=weights_backbone, progress=progress)
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
rpn_anchor_generator = _default_anchorgen()
rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
box_head = FastRCNNConvFCHead(
(backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
)
mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
model = MaskRCNN(
backbone,
num_classes=num_classes,
rpn_anchor_generator=rpn_anchor_generator,
rpn_head=rpn_head,
box_head=box_head,
mask_head=mask_head,
**kwargs,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
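
For reference, a usage sketch mirroring the Faster R-CNN one above; the v2 weights enum is likewise still empty in this commit, so only the backbone can be pretrained:

```python
import torch
from torchvision.models.resnet import ResNet50_Weights
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2

model = maskrcnn_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V1, num_classes=91)
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 480, 640)])
print(predictions[0]["masks"].shape)  # (num_detections, 1, 480, 640)
```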
import math
import warnings
from collections import OrderedDict
-from typing import Any, Dict, List, Tuple, Optional
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Optional
import torch
from torch import nn, Tensor
@@ -17,7 +18,7 @@ from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
from . import _utils as det_utils
-from ._utils import overwrite_eps
from ._utils import overwrite_eps, _box_loss
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform
@@ -26,7 +27,9 @@ from .transform import GeneralizedRCNNTransform
__all__ = [
"RetinaNet",
"RetinaNet_ResNet50_FPN_Weights",
"RetinaNet_ResNet50_FPN_V2_Weights",
"retinanet_resnet50_fpn", "retinanet_resnet50_fpn",
"retinanet_resnet50_fpn_v2",
]
@@ -37,6 +40,21 @@ def _sum(x: List[Tensor]) -> Tensor:
return res
def _v1_to_v2_weights(state_dict, prefix):
for i in range(4):
for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{2*i}.{type}"
new_key = f"{prefix}conv.{i}.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
def _default_anchorgen():
anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
return anchor_generator
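
For reference, the factored-out anchor generator keeps RetinaNet's usual 3 scales x 3 aspect ratios per pyramid level; a tiny check that replicates the private helper:

```python
from torchvision.models.detection.anchor_utils import AnchorGenerator

anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
print(AnchorGenerator(anchor_sizes, aspect_ratios).num_anchors_per_location())  # [9, 9, 9, 9, 9]
```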
class RetinaNetHead(nn.Module):
"""
A regression and classification head for use in RetinaNet.
@@ -45,12 +63,15 @@ class RetinaNetHead(nn.Module):
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
num_classes (int): number of classes to be predicted
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
"""
-def __init__(self, in_channels, num_anchors, num_classes):
def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
super().__init__()
-self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
-self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)
self.classification_head = RetinaNetClassificationHead(
in_channels, num_anchors, num_classes, norm_layer=norm_layer
)
self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)
def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
@@ -72,20 +93,30 @@ class RetinaNetClassificationHead(nn.Module):
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
num_classes (int): number of classes to be predicted
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
"""
-def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
_version = 2
def __init__(
self,
in_channels,
num_anchors,
num_classes,
prior_probability=0.01,
norm_layer: Optional[Callable[..., nn.Module]] = None,
):
super().__init__()
conv = []
for _ in range(4):
-conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
-conv.append(nn.ReLU())
conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
self.conv = nn.Sequential(*conv)
-for layer in self.conv.children():
for layer in self.conv.modules():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
if layer.bias is not None:
torch.nn.init.constant_(layer.bias, 0)
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
@@ -100,6 +131,31 @@ class RetinaNetClassificationHead(nn.Module):
# https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
_v1_to_v2_weights(state_dict, prefix)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def compute_loss(self, targets, head_outputs, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
losses = []
@@ -159,31 +215,60 @@ class RetinaNetRegressionHead(nn.Module):
Args:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
"""
_version = 2
__annotations__ = {
"box_coder": det_utils.BoxCoder,
}
-def __init__(self, in_channels, num_anchors):
def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
super().__init__()
conv = []
for _ in range(4):
-conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
-conv.append(nn.ReLU())
conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
self.conv = nn.Sequential(*conv)
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
torch.nn.init.zeros_(self.bbox_reg.bias)
-for layer in self.conv.children():
for layer in self.conv.modules():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
if layer.bias is not None:
torch.nn.init.zeros_(layer.bias)
self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
self._loss_type = "l1"
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
_v1_to_v2_weights(state_dict, prefix)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
@@ -203,12 +288,15 @@ class RetinaNetRegressionHead(nn.Module):
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
-# compute the regression targets
-target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
# compute the loss
losses.append(
-torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
_box_loss(
self._loss_type,
self.box_coder,
anchors_per_image,
matched_gt_boxes_per_image,
bbox_regression_per_image,
)
/ max(1, num_foreground)
)
@@ -361,9 +449,7 @@ class RetinaNet(nn.Module):
)
if anchor_generator is None:
-anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
-aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
anchor_generator = _default_anchorgen()
self.anchor_generator = anchor_generator
if head is None:
@@ -604,6 +690,10 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
DEFAULT = COCO_V1
class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
pass
@handle_legacy_interface(
weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
@@ -690,3 +780,61 @@ def retinanet_resnet50_fpn(
overwrite_eps(model, 0.0)
return model
def retinanet_resnet50_fpn_v2(
*,
weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
progress: bool = True,
num_classes: Optional[int] = None,
weights_backbone: Optional[ResNet50_Weights] = None,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> RetinaNet:
"""
Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
Reference: `"Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection"
<https://arxiv.org/abs/1912.02424>`_.
See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
Args:
weights (RetinaNet_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int, optional): number of output classes of the model (including the background)
weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
passed (the default) this value is set to 3.
"""
weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
weights_backbone = ResNet50_Weights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91
is_trained = weights is not None or weights_backbone is not None
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
backbone = resnet50(weights=weights_backbone, progress=progress)
backbone = _resnet_fpn_extractor(
backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
)
anchor_generator = _default_anchorgen()
head = RetinaNetHead(
backbone.out_channels,
anchor_generator.num_anchors_per_location()[0],
num_classes,
norm_layer=partial(nn.GroupNorm, 32),
)
head.regression_head._loss_type = "giou"
model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
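
For reference, a hedged usage sketch of the v2 RetinaNet builder, which wires GroupNorm into the heads and switches the regression loss to gIoU via `_loss_type` (no detection weights ship with this commit yet):

```python
import torch
from torchvision.models.resnet import ResNet50_Weights
from torchvision.models.detection import retinanet_resnet50_fpn_v2

model = retinanet_resnet50_fpn_v2(weights_backbone=ResNet50_Weights.IMAGENET1K_V1, num_classes=91)
print(model.head.regression_head._loss_type)  # "giou"
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 480, 640)])
print(predictions[0]["boxes"].shape)  # (num_detections, 4)
```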
@@ -3,6 +3,7 @@ from typing import List, Optional, Dict, Tuple
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import Conv2dNormActivation
from torchvision.ops import boxes as box_ops
from . import _utils as det_utils
@@ -19,23 +20,59 @@ class RPNHead(nn.Module):
Args:
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
conv_depth (int, optional): number of convolutions
"""
-def __init__(self, in_channels: int, num_anchors: int) -> None:
_version = 2
def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
super().__init__()
-self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
convs = []
for _ in range(conv_depth):
convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
self.conv = nn.Sequential(*convs)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
-for layer in self.children():
for layer in self.modules():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
if layer.bias is not None:
torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
for type in ["weight", "bias"]:
old_key = f"{prefix}conv.{type}"
new_key = f"{prefix}conv.0.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
logits = []
bbox_reg = []
for feature in x:
-t = F.relu(self.conv(feature))
t = self.conv(feature)
logits.append(self.cls_logits(t))
bbox_reg.append(self.bbox_pred(t))
return logits, bbox_reg
...
from collections import OrderedDict
-from typing import Tuple, List, Dict, Optional
from typing import Tuple, List, Dict, Callable, Optional
import torch.nn.functional as F
from torch import nn, Tensor
from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once
@@ -51,6 +52,7 @@ class FeaturePyramidNetwork(nn.Module):
be performed. It is expected to take the fpn features, the original
features and the names of the original features as input, and returns
a new list of feature maps and their corresponding names
norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
Examples::
@@ -70,11 +72,14 @@ class FeaturePyramidNetwork(nn.Module):
"""
_version = 2
def __init__(
self,
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
):
super().__init__()
_log_api_usage_once(self)
@@ -83,8 +88,12 @@ class FeaturePyramidNetwork(nn.Module):
for in_channels in in_channels_list:
if in_channels == 0:
raise ValueError("in_channels=0 is currently not supported")
-inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
-layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
inner_block_module = Conv2dNormActivation(
in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
)
layer_block_module = Conv2dNormActivation(
out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
)
self.inner_blocks.append(inner_block_module)
self.layer_blocks.append(layer_block_module)
@@ -92,6 +101,7 @@ class FeaturePyramidNetwork(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, a=1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if extra_blocks is not None:
@@ -99,6 +109,37 @@ class FeaturePyramidNetwork(nn.Module):
raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
self.extra_blocks = extra_blocks
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
num_blocks = len(self.inner_blocks)
for block in ["inner_blocks", "layer_blocks"]:
for i in range(num_blocks):
for type in ["weight", "bias"]:
old_key = f"{prefix}{block}.{i}.{type}"
new_key = f"{prefix}{block}.{i}.0.{type}"
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.inner_blocks[idx](x),
...