"src/vscode:/vscode.git/clone" did not exist on "81780882b8cfb7628a2e09dbaae566ead5d760e8"
Unverified commit ec2456aa, authored by Vasilis Vryniotis and committed by GitHub
Browse files

Multi-pretrained weight support - FasterRCNN ResNet50 (#4613)

* Adding FasterRCNN ResNet50.

* Refactoring to remove duplicate code.

* Adding typing info.

* Setting weights_backbone=None as default value.

* Overwrite eps only for specific weights.
parent 3d7244b5
import warnings
from typing import List, Optional
from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
from .. import mobilenet
from .. import resnet
......@@ -92,7 +93,15 @@ def resnet_fpn_backbone(
default a ``LastLevelMaxPool`` is used.
"""
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
def _resnet_backbone_config(
backbone: resnet.ResNet,
trainable_layers: int,
returned_layers: Optional[List[int]],
extra_blocks: Optional[ExtraFPNBlock],
):
# select layers that won't be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
......
from .resnet import *
from . import detection
......@@ -1006,3 +1006,98 @@ _IMAGENET_CATEGORIES = [
"ear",
"toilet tissue",
]
# To be replaced with torchvision.datasets.find("coco").info.categories
_COCO_CATEGORIES = [
    # Position in the list is the label index; index 0 is "__background__".
    # "N/A" marks positions that carry no category name.
    # Ten entries per row, with the index range noted at the end of each row.
    "__background__", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",  # 0-9
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",  # 10-19
    "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "N/A", "backpack", "umbrella", "N/A",  # 20-29
    "N/A", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",  # 30-39
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "N/A", "wine glass", "cup", "fork", "knife",  # 40-49
    "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",  # 50-59
    "donut", "cake", "chair", "couch", "potted plant", "bed", "N/A", "dining table", "N/A", "N/A",  # 60-69
    "toilet", "N/A", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",  # 70-79
    "toaster", "sink", "refrigerator", "N/A", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",  # 80-89
    "toothbrush",  # 90
]
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
from .. import resnet
def resnet_fpn_backbone(
    backbone_name,
    weights,
    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers=3,
    returned_layers=None,
    extra_blocks=None,
):
    """Build a ResNet by name and configure it through ``_resnet_backbone_config``.

    Args:
        backbone_name: Name of a builder in the ``resnet`` module's namespace.
        weights: Forwarded to the builder as ``weights``.
        norm_layer: Forwarded to the builder as ``norm_layer``.
        trainable_layers: Passed on to ``_resnet_backbone_config``.
        returned_layers: Passed on to ``_resnet_backbone_config``.
        extra_blocks: Passed on to ``_resnet_backbone_config``.

    Returns:
        Whatever ``_resnet_backbone_config`` returns for the constructed ResNet.

    Raises:
        KeyError: If ``backbone_name`` is not present in the ``resnet`` module.
    """
    # Builders are resolved by name from the resnet module's namespace.
    build_resnet = resnet.__dict__[backbone_name]
    resnet_body = build_resnet(weights=weights, norm_layer=norm_layer)
    return _resnet_backbone_config(
        resnet_body,
        trainable_layers,
        returned_layers,
        extra_blocks,
    )
import warnings
from typing import Any, Optional
from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..resnet import ResNet50Weights
from .backbone_utils import resnet_fpn_backbone
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
class FasterRCNNResNet50FPNWeights(Weights):
    """Weight entries selectable for the ``fasterrcnn_resnet50_fpn`` builder."""

    # Single entry: checkpoint URL, the CocoEval preprocessing preset, and metadata.
    Coco_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=CocoEval,
        meta={
            # The builder derives num_classes from the length of this list.
            "categories": _COCO_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            # NOTE(review): presumably box mAP on COCO for this checkpoint; confirm against the recipe.
            "map": 37.0,
        },
    )
def fasterrcnn_resnet50_fpn(
    weights: Optional[FasterRCNNResNet50FPNWeights] = None,
    weights_backbone: Optional[ResNet50Weights] = None,
    progress: bool = True,
    num_classes: int = 91,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Build a Faster R-CNN model on a ResNet-50 FPN backbone.

    Args:
        weights: Detection weights to load into the finished model. When given,
            ``weights_backbone`` is discarded and ``num_classes`` is replaced by
            the number of categories listed in the weights' metadata.
        weights_backbone: Weights handed to the ResNet-50 backbone builder; only
            used when ``weights`` is ``None``.
        progress: Forwarded to ``weights.state_dict``.
        num_classes: Passed to ``FasterRCNN``; overridden when ``weights`` is given.
        trainable_backbone_layers: Run through ``_validate_trainable_layers``
            (together with the constants 5 and 3) before being passed to the
            backbone builder as ``trainable_layers``.
        **kwargs: Forwarded to ``FasterRCNN``. The deprecated keys ``pretrained``
            and ``pretrained_backbone`` are popped first; a truthy value selects
            ``Coco_RefV1`` / ``ImageNet1K_RefV1`` and a falsy one selects ``None``,
            in both cases replacing the corresponding explicit argument.

    Returns:
        The constructed ``FasterRCNN``, with ``weights`` loaded when provided.
    """
    # Legacy boolean flags: stacklevel=2 makes each deprecation warning point at
    # the caller's line instead of at this function.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.", stacklevel=2)
        weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
    weights = FasterRCNNResNet50FPNWeights.verify(weights)
    if "pretrained_backbone" in kwargs:
        warnings.warn(
            "The argument pretrained_backbone is deprecated, please use weights_backbone instead.",
            stacklevel=2,
        )
        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
    weights_backbone = ResNet50Weights.verify(weights_backbone)

    if weights is not None:
        # The full-model checkpoint is loaded below, so separate backbone weights
        # are not requested, and the head size follows the checkpoint's categories.
        weights_backbone = None
        num_classes = len(weights.meta["categories"])

    trainable_backbone_layers = _validate_trainable_layers(
        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
    )

    backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))
        # eps is overwritten only for this specific weight entry.
        if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
            overwrite_eps(model, 0.0)

    return model
from typing import Tuple
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
......@@ -7,22 +7,19 @@ from ... import transforms as T
from ...transforms import functional as F
__all__ = ["ConvertImageDtype", "ImageNetEval"]
__all__ = ["CocoEval", "ImageNetEval"]
# Allows handling of both PIL and Tensor images
class ConvertImageDtype(nn.Module):
    """Module that converts an image to a tensor of a fixed dtype."""

    def __init__(self, dtype: torch.dtype) -> None:
        """Store the target ``dtype`` used by ``forward``."""
        super().__init__()
        self.dtype = dtype

    def forward(self, img: Tensor) -> Tensor:
        """Return ``img`` converted to ``self.dtype``."""
        # Non-tensor inputs (e.g. PIL images) are turned into tensors first.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, self.dtype)
class CocoEval(nn.Module):
    """Preprocessing module that converts the image to a float tensor and passes the target through."""

    def forward(
        self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Convert ``img`` to a ``torch.float`` tensor.

        Args:
            img: Input image; non-tensor inputs are converted with ``F.pil_to_tensor``.
            target: Optional target dictionary, returned unchanged.

        Returns:
            A ``(image, target)`` tuple.
        """
        # Non-tensor inputs (e.g. PIL images) are turned into tensors first.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float), target
class ImageNetEval:
class ImageNetEval(nn.Module):
def __init__(
self,
crop_size: int,
......@@ -31,14 +28,14 @@ class ImageNetEval:
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
) -> None:
self.transforms = T.Compose(
[
T.Resize(resize_size, interpolation=interpolation),
T.CenterCrop(crop_size),
ConvertImageDtype(dtype=torch.float),
T.Normalize(mean=mean, std=std),
]
)
def __call__(self, img: Tensor) -> Tensor:
return self.transforms(img)
super().__init__()
self._resize = T.Resize(resize_size, interpolation=interpolation)
self._crop = T.CenterCrop(crop_size)
self._normalize = T.Normalize(mean=mean, std=std)
def forward(self, img: Tensor) -> Tensor:
img = self._crop(self._resize(img))
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float)
return self._normalize(img)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment