Unverified Commit 9bee9cc4 authored by Aditya Oke, committed by GitHub
Browse files

Add typing annotations to detection/backbone_utils (#4603)



* Start adding types

* add typing

* Type prototype models

* fix optional type bug

* transient import

* Fix weights type

* fix import

* Apply suggestions from code review

Address nits
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Khushi Agrawal <khushiagrawal411@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 015eb46b
...@@ -21,10 +21,6 @@ ignore_errors=True ...@@ -21,10 +21,6 @@ ignore_errors=True
ignore_errors = True ignore_errors = True
[mypy-torchvision.models.detection.backbone_utils]
ignore_errors = True
[mypy-torchvision.models.detection.transform] [mypy-torchvision.models.detection.transform]
ignore_errors = True ignore_errors = True
......
import warnings import warnings
from typing import List, Optional from typing import Callable, Dict, Optional, List
from torch import nn from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
...@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module): ...@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
out_channels (int): the number of channels in the FPN out_channels (int): the number of channels in the FPN
""" """
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): def __init__(
self,
backbone: nn.Module,
return_layers: Dict[str, str],
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super(BackboneWithFPN, self).__init__() super(BackboneWithFPN, self).__init__()
if extra_blocks is None: if extra_blocks is None:
...@@ -43,20 +50,20 @@ class BackboneWithFPN(nn.Module): ...@@ -43,20 +50,20 @@ class BackboneWithFPN(nn.Module):
) )
self.out_channels = out_channels self.out_channels = out_channels
def forward(self, x: Tensor) -> Dict[str, Tensor]:
    """Run the backbone and feed its intermediate feature maps through the FPN.

    Args:
        x (Tensor): input image batch.

    Returns:
        Dict[str, Tensor]: FPN output feature maps, keyed by level name.
    """
    # self.body is an IntermediateLayerGetter over the backbone: it returns
    # an ordered dict of the requested intermediate feature maps.
    x = self.body(x)
    # The FPN consumes that dict and produces one refined map per level.
    x = self.fpn(x)
    return x
def resnet_fpn_backbone( def resnet_fpn_backbone(
backbone_name, backbone_name: str,
pretrained, pretrained: bool,
norm_layer=misc_nn_ops.FrozenBatchNorm2d, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3, trainable_layers: int = 3,
returned_layers=None, returned_layers: Optional[List[int]] = None,
extra_blocks=None, extra_blocks: Optional[ExtraFPNBlock] = None,
): ) -> BackboneWithFPN:
""" """
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
...@@ -80,7 +87,7 @@ def resnet_fpn_backbone( ...@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50', backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2' 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
norm_layer (torchvision.ops): it is recommended to use the default value. For details visit: norm_layer (callable): it is recommended to use the default value. For details visit:
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block. trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
...@@ -101,7 +108,8 @@ def _resnet_backbone_config( ...@@ -101,7 +108,8 @@ def _resnet_backbone_config(
trainable_layers: int, trainable_layers: int,
returned_layers: Optional[List[int]], returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock], extra_blocks: Optional[ExtraFPNBlock],
): ) -> BackboneWithFPN:
# select layers that wont be frozen # select layers that wont be frozen
assert 0 <= trainable_layers <= 5 assert 0 <= trainable_layers <= 5
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
...@@ -125,8 +133,13 @@ def _resnet_backbone_config( ...@@ -125,8 +133,13 @@ def _resnet_backbone_config(
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value): def _validate_trainable_layers(
# dont freeze any layers if pretrained model or backbone is not used pretrained: bool,
trainable_backbone_layers: Optional[int],
max_value: int,
default_value: int,
) -> int:
# don't freeze any layers if pretrained model or backbone is not used
if not pretrained: if not pretrained:
if trainable_backbone_layers is not None: if trainable_backbone_layers is not None:
warnings.warn( warnings.warn(
...@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, ...@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
def mobilenet_backbone( def mobilenet_backbone(
backbone_name, backbone_name: str,
pretrained, pretrained: bool,
fpn, fpn: bool,
norm_layer=misc_nn_ops.FrozenBatchNorm2d, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2, trainable_layers: int = 2,
returned_layers=None, returned_layers: Optional[List[int]] = None,
extra_blocks=None, extra_blocks: Optional[ExtraFPNBlock] = None,
): ) -> nn.Module:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
...@@ -185,5 +199,5 @@ def mobilenet_backbone( ...@@ -185,5 +199,5 @@ def mobilenet_backbone(
# depthwise linear combination of channels to reduce their size # depthwise linear combination of channels to reduce their size
nn.Conv2d(backbone[-1].out_channels, out_channels, 1), nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
) )
m.out_channels = out_channels m.out_channels = out_channels # type: ignore[assignment]
return m return m
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config from typing import Callable, Optional, List
from torch import nn
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
from .. import resnet from .. import resnet
from .._api import Weights
def resnet_fpn_backbone(
    backbone_name: str,
    weights: Optional[Weights],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """Construct a ResNet backbone with an FPN on top (prototype weights API).

    Args:
        backbone_name (str): name of a torchvision ResNet architecture;
            looked up in ``resnet.__dict__`` (e.g. "resnet50").
        weights (Optional[Weights]): pretrained weights to load into the
            backbone, or None for random initialization.
        norm_layer (callable): normalization layer factory; the default
            ``misc_nn_ops.FrozenBatchNorm2d`` is recommended for detection.
        trainable_layers (int): number of trainable (not frozen) ResNet
            stages counted from the final block; valid values are 0 to 5.
        returned_layers (Optional[List[int]]): which ResNet stages feed the
            FPN; None selects the library default.
        extra_blocks (Optional[ExtraFPNBlock]): extra operations appended to
            the FPN output; None selects the library default.

    Returns:
        BackboneWithFPN: the ResNet wrapped with a feature pyramid network.
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    # Freezing / layer selection / FPN wiring are shared with the legacy
    # entry point via this helper.
    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment