Unverified Commit ec2456aa authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Multi-pretrained weight support - FasterRCNN ResNet50 (#4613)

* Adding FasterRCNN ResNet50.

* Refactoring to remove duplicate code.

* Adding typing info.

* Setting weights_backbone=None as default value.

* Overwrite eps only for specific weights.
parent 3d7244b5
import warnings import warnings
from typing import List, Optional
from torch import nn from torch import nn
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
from .. import mobilenet from .. import mobilenet
from .. import resnet from .. import resnet
...@@ -92,7 +93,15 @@ def resnet_fpn_backbone( ...@@ -92,7 +93,15 @@ def resnet_fpn_backbone(
default a ``LastLevelMaxPool`` is used. default a ``LastLevelMaxPool`` is used.
""" """
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
def _resnet_backbone_config(
backbone: resnet.ResNet,
trainable_layers: int,
returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock],
):
# select layers that wont be frozen # select layers that wont be frozen
assert 0 <= trainable_layers <= 5 assert 0 <= trainable_layers <= 5
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
......
from .resnet import * from .resnet import *
from . import detection
...@@ -1006,3 +1006,98 @@ _IMAGENET_CATEGORIES = [ ...@@ -1006,3 +1006,98 @@ _IMAGENET_CATEGORIES = [
"ear", "ear",
"toilet tissue", "toilet tissue",
] ]
# To be replaced with torchvision.datasets.find("coco").info.categories
# 91-entry COCO detection label space, indexed by category id (0 is the
# background class). "N/A" entries are ids with no category in this list —
# presumably gaps in the original COCO label numbering; verify against the
# COCO dataset annotations before relying on them.
_COCO_CATEGORIES = [
    "__background__",
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "N/A",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "N/A",
    "backpack",
    "umbrella",
    "N/A",
    "N/A",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "N/A",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "N/A",
    "dining table",
    "N/A",
    "N/A",
    "toilet",
    "N/A",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "N/A",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
from .. import resnet
def resnet_fpn_backbone(
    backbone_name,
    weights,
    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers=3,
    returned_layers=None,
    extra_blocks=None,
):
    """Build an FPN-wrapped ResNet backbone from a named ResNet variant.

    Args:
        backbone_name: name of a constructor in ``resnet`` (e.g. ``"resnet50"``).
        weights: pretrained weights to load into the ResNet, or ``None``.
        norm_layer: normalization layer passed to the ResNet constructor;
            defaults to ``FrozenBatchNorm2d``.
        trainable_layers: number of trailing ResNet stages left unfrozen.
        returned_layers: which stages feed the FPN (``None`` uses the default).
        extra_blocks: optional extra FPN block (``None`` uses the default).
    """
    # Look up the requested ResNet constructor by name, instantiate it with
    # the multi-weight API, then hand off FPN wiring to the shared helper.
    resnet_model = getattr(resnet, backbone_name)(weights=weights, norm_layer=norm_layer)
    return _resnet_backbone_config(resnet_model, trainable_layers, returned_layers, extra_blocks)
import warnings
from typing import Any, Optional
from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights
from .backbone_utils import resnet_fpn_backbone
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
class FasterRCNNResNet50FPNWeights(Weights):
    """Registry of pretrained weights for Faster R-CNN ResNet-50 FPN."""

    # Checkpoint from the torchvision detection reference training recipe.
    Coco_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        # Inference-time preprocessing applied to inputs for this checkpoint.
        transforms=CocoEval,
        meta={
            "categories": _COCO_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "map": 37.0,  # presumably COCO box mAP of this checkpoint — confirm against the recipe page
        },
    )
def fasterrcnn_resnet50_fpn(
    weights: Optional[FasterRCNNResNet50FPNWeights] = None,
    weights_backbone: Optional[ResNet50Weights] = None,
    progress: bool = True,
    num_classes: int = 91,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Construct a Faster R-CNN model with a ResNet-50 FPN backbone.

    Args:
        weights: full-model pretrained weights, or ``None``.
        weights_backbone: backbone-only pretrained weights, or ``None``.
            Ignored when ``weights`` is given.
        progress: show a progress bar while downloading weights.
        num_classes: number of output classes; overridden by the metadata of
            ``weights`` when full-model weights are provided.
        trainable_backbone_layers: number of trainable ResNet stages, or
            ``None`` to pick a default via ``_validate_trainable_layers``.
        **kwargs: forwarded to the ``FasterRCNN`` constructor.

    Returns:
        The constructed ``FasterRCNN`` model, with weights loaded if requested.
    """
    # Translate the deprecated boolean flags into the new weight enums.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        legacy_pretrained = kwargs.pop("pretrained")
        weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if legacy_pretrained else None
    weights = FasterRCNNResNet50FPNWeights.verify(weights)
    if "pretrained_backbone" in kwargs:
        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
        legacy_backbone = kwargs.pop("pretrained_backbone")
        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if legacy_backbone else None
    weights_backbone = ResNet50Weights.verify(weights_backbone)

    # Full-model weights win: they already contain the backbone and pin the
    # class count to their own category list.
    if weights is not None:
        weights_backbone = None
        num_classes = len(weights.meta["categories"])

    is_pretrained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_pretrained, trainable_backbone_layers, 5, 3)

    fpn_backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
    model = FasterRCNN(fpn_backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))
        # eps is overwritten only for this specific checkpoint (see commit
        # note: "Overwrite eps only for specific weights").
        if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
            overwrite_eps(model, 0.0)
    return model
from typing import Tuple from typing import Dict, Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
...@@ -7,22 +7,19 @@ from ... import transforms as T ...@@ -7,22 +7,19 @@ from ... import transforms as T
from ...transforms import functional as F from ...transforms import functional as F
__all__ = ["ConvertImageDtype", "ImageNetEval"] __all__ = ["CocoEval", "ImageNetEval"]
# Allows handling of both PIL and Tensor images class CocoEval(nn.Module):
class ConvertImageDtype(nn.Module): def forward(
def __init__(self, dtype: torch.dtype) -> None: self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
super().__init__() ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
self.dtype = dtype
def forward(self, img: Tensor) -> Tensor:
if not isinstance(img, Tensor): if not isinstance(img, Tensor):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
return F.convert_image_dtype(img, self.dtype) return F.convert_image_dtype(img, torch.float), target
class ImageNetEval: class ImageNetEval(nn.Module):
def __init__( def __init__(
self, self,
crop_size: int, crop_size: int,
...@@ -31,14 +28,14 @@ class ImageNetEval: ...@@ -31,14 +28,14 @@ class ImageNetEval:
std: Tuple[float, ...] = (0.229, 0.224, 0.225), std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
) -> None: ) -> None:
self.transforms = T.Compose( super().__init__()
[ self._resize = T.Resize(resize_size, interpolation=interpolation)
T.Resize(resize_size, interpolation=interpolation), self._crop = T.CenterCrop(crop_size)
T.CenterCrop(crop_size), self._normalize = T.Normalize(mean=mean, std=std)
ConvertImageDtype(dtype=torch.float),
T.Normalize(mean=mean, std=std), def forward(self, img: Tensor) -> Tensor:
] img = self._crop(self._resize(img))
) if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
def __call__(self, img: Tensor) -> Tensor: img = F.convert_image_dtype(img, torch.float)
return self.transforms(img) return self._normalize(img)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment