"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "37799adc603474c9637aab00f2f5746924a846a3"
Unverified Commit 9bee9cc4 authored by Aditya Oke, committed by GitHub

Add typing annotations to detection/backbone_utils (#4603)



* Start adding types

* add typing

* Type prototype models

* fix optional type bug

* transient import

* Fix weights type

* fix import

* Apply suggestions from code review

Address nits
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Khushi Agrawal <khushiagrawal411@gmail.com>
parent 015eb46b
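The commit annotates the public helpers in detection/backbone_utils (and its prototype counterpart) so mypy can check the module instead of ignoring it. A minimal usage sketch of the newly typed resnet_fpn_backbone, assuming a torchvision build that includes this change; the input shape and backbone choice are illustrative:

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# The helper is now annotated as (backbone_name: str, pretrained: bool, ...) -> BackboneWithFPN.
backbone = resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=3)

x = torch.rand(1, 3, 224, 224)
features = backbone(x)            # forward is typed as Dict[str, Tensor]: one entry per FPN level
print(backbone.out_channels)      # 256 with the default FPN configuration
print(sorted(features.keys()))    # e.g. ['0', '1', '2', '3', 'pool']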
@@ -21,10 +21,6 @@ ignore_errors=True
ignore_errors = True
[mypy-torchvision.models.detection.backbone_utils]
ignore_errors = True
[mypy-torchvision.models.detection.transform]
ignore_errors = True
import warnings
from typing import List, Optional
from typing import Callable, Dict, Optional, List
from torch import nn
from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
@@ -29,7 +29,14 @@ class BackboneWithFPN(nn.Module):
out_channels (int): the number of channels in the FPN
"""
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
def __init__(
self,
backbone: nn.Module,
return_layers: Dict[str, str],
in_channels_list: List[int],
out_channels: int,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> None:
super(BackboneWithFPN, self).__init__()
if extra_blocks is None:
@@ -43,20 +50,20 @@ class BackboneWithFPN(nn.Module):
)
self.out_channels = out_channels
def forward(self, x):
def forward(self, x: Tensor) -> Dict[str, Tensor]:
x = self.body(x)
x = self.fpn(x)
return x
def resnet_fpn_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
backbone_name: str,
pretrained: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 3,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
@@ -80,7 +87,7 @@ def resnet_fpn_backbone(
backbone_name (string): resnet architecture. Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50',
'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
norm_layer (torchvision.ops): it is recommended to use the default value. For details visit:
norm_layer (callable): it is recommended to use the default value. For details visit:
(https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
@@ -101,7 +108,8 @@ def _resnet_backbone_config(
trainable_layers: int,
returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock],
):
) -> BackboneWithFPN:
# select layers that wont be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
@@ -125,8 +133,13 @@ def _resnet_backbone_config(
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
# dont freeze any layers if pretrained model or backbone is not used
def _validate_trainable_layers(
pretrained: bool,
trainable_backbone_layers: Optional[int],
max_value: int,
default_value: int,
) -> int:
# don't freeze any layers if pretrained model or backbone is not used
if not pretrained:
if trainable_backbone_layers is not None:
warnings.warn(
@@ -144,14 +157,15 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
def mobilenet_backbone(
backbone_name,
pretrained,
fpn,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
extra_blocks=None,
):
backbone_name: str,
pretrained: bool,
fpn: bool,
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 2,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -185,5 +199,5 @@ def mobilenet_backbone(
# depthwise linear combination of channels to reduce their size
nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
)
m.out_channels = out_channels
m.out_channels = out_channels # type: ignore[assignment]
return m
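For mobilenet_backbone the declared return type is the broader nn.Module, because the helper returns either a BackboneWithFPN (fpn=True) or a plain nn.Sequential ending in a 1x1 convolution (fpn=False); the type: ignore[assignment] is needed since nn.Sequential does not declare an out_channels attribute. A short sketch of both branches, assuming the same torchvision build; the backbone name and input size are illustrative:

import torch
from torchvision.models.detection.backbone_utils import mobilenet_backbone

with_fpn = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=True)
without_fpn = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=False)

x = torch.rand(1, 3, 320, 320)
print({name: f.shape for name, f in with_fpn(x).items()})   # dict of FPN feature maps
print(without_fpn(x).shape)                                  # single feature map from the Sequential
print(with_fpn.out_channels, without_fpn.out_channels)       # projection width set by the helper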
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
from typing import Callable, Optional, List
from torch import nn
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock
from .. import resnet
from .._api import Weights
def resnet_fpn_backbone(
backbone_name,
weights,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None,
):
backbone_name: str,
weights: Optional[Weights],
norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
trainable_layers: int = 3,
returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
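The final hunks touch the prototype counterpart of the file (the multi-weight API), where the pretrained flag is replaced by an Optional[Weights] argument. A hedged sketch of a call against that variant, assuming a torchvision main build that ships the prototype package and that the file lives at torchvision/prototype/models/detection/backbone_utils.py, as its relative imports suggest; passing weights=None builds a randomly initialized backbone, while a member of the prototype weight registry would load pretrained parameters:

import torch
from torchvision.prototype.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone("resnet50", weights=None)   # weights: Optional[Weights]
features = backbone(torch.rand(1, 3, 224, 224))            # Dict[str, Tensor] of FPN levels
print(backbone.out_channels)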