Unverified Commit 15366c4d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Multi-weight support for FCN prototype segmentation models (#4726)

* Add multi-weight support for FCN.

* Rename base_size to resize_size.
parent 3783f3f4
...@@ -7,3 +7,4 @@ from .mobilenetv3 import * ...@@ -7,3 +7,4 @@ from .mobilenetv3 import *
from .mnasnet import * from .mnasnet import *
from . import detection from . import detection
from . import quantization from . import quantization
from . import segmentation
...@@ -1101,3 +1101,28 @@ _COCO_CATEGORIES = [ ...@@ -1101,3 +1101,28 @@ _COCO_CATEGORIES = [
"hair drier", "hair drier",
"toothbrush", "toothbrush",
] ]
# To be replaced with torchvision.datasets.find("voc").info.categories
_VOC_CATEGORIES = [
"__background__",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
import warnings
from functools import partial
from typing import Any, Optional
from ....models.segmentation.fcn import FCN, _fcn_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES
from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
class FCNResNet50Weights(Weights):
    # Weights for FCN with a ResNet-50 backbone. Per the enum name, these were
    # trained on COCO images using only the Pascal VOC label set (see
    # `categories`); the linked recipe documents the training command.
    # `transforms` builds the matching eval-time preprocessing (resize to 520).
    CocoWithVocLabels_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
        transforms=partial(VocEval, resize_size=520),
        meta={
            "categories": _VOC_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
            # Reported eval metrics for this checkpoint (mean IoU / pixel accuracy).
            "mIoU": 60.5,
            "acc": 91.4,
        },
    )
class FCNResNet101Weights(Weights):
    # Weights for FCN with a ResNet-101 backbone, trained on COCO images using
    # only the Pascal VOC label set (see `categories`). `transforms` builds the
    # matching eval-time preprocessing (resize to 520).
    CocoWithVocLabels_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
        transforms=partial(VocEval, resize_size=520),
        meta={
            "categories": _VOC_CATEGORIES,
            # Fixed copy-paste bug: the recipe anchor previously pointed to
            # "#deeplabv3_resnet101"; this entry is the FCN-ResNet101 model, so
            # it must reference the fcn_resnet101 training recipe (consistent
            # with FCNResNet50Weights above).
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
            # Reported eval metrics for this checkpoint (mean IoU / pixel accuracy).
            "mIoU": 63.7,
            "acc": 91.9,
        },
    )
def fcn_resnet50(
    weights: Optional[FCNResNet50Weights] = None,
    weights_backbone: Optional[ResNet50Weights] = None,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> FCN:
    """Build an FCN segmentation model on a dilated ResNet-50 backbone.

    Args:
        weights: pretrained weights for the full model; when given, they
            override ``weights_backbone``, ``num_classes`` and ``aux_loss``.
        weights_backbone: pretrained weights for the backbone only.
        progress: show a download progress bar when fetching weights.
        num_classes: number of output classes (ignored when ``weights`` is set).
        aux_loss: attach the auxiliary classifier head.
        **kwargs: accepts the deprecated ``pretrained``/``pretrained_backbone``
            boolean flags for backward compatibility.
    """
    # Map the deprecated boolean flags onto the new weight enums.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        use_pretrained = kwargs.pop("pretrained")
        weights = FCNResNet50Weights.CocoWithVocLabels_RefV1 if use_pretrained else None
    weights = FCNResNet50Weights.verify(weights)

    if "pretrained_backbone" in kwargs:
        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
        use_pretrained_backbone = kwargs.pop("pretrained_backbone")
        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if use_pretrained_backbone else None
    weights_backbone = ResNet50Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights fix the head configuration and make separate
        # backbone initialization redundant.
        weights_backbone = None
        aux_loss = True
        num_classes = len(weights.meta["categories"])

    backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
def fcn_resnet101(
    weights: Optional[FCNResNet101Weights] = None,
    weights_backbone: Optional[ResNet101Weights] = None,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> FCN:
    """Build an FCN segmentation model on a dilated ResNet-101 backbone.

    Args:
        weights: pretrained weights for the full model; when given, they
            override ``weights_backbone``, ``num_classes`` and ``aux_loss``.
        weights_backbone: pretrained weights for the backbone only.
        progress: show a download progress bar when fetching weights.
        num_classes: number of output classes (ignored when ``weights`` is set).
        aux_loss: attach the auxiliary classifier head.
        **kwargs: accepts the deprecated ``pretrained``/``pretrained_backbone``
            boolean flags for backward compatibility.
    """
    # Map the deprecated boolean flags onto the new weight enums.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        use_pretrained = kwargs.pop("pretrained")
        weights = FCNResNet101Weights.CocoWithVocLabels_RefV1 if use_pretrained else None
    weights = FCNResNet101Weights.verify(weights)

    if "pretrained_backbone" in kwargs:
        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
        use_pretrained_backbone = kwargs.pop("pretrained_backbone")
        weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if use_pretrained_backbone else None
    weights_backbone = ResNet101Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights fix the head configuration and make separate
        # backbone initialization redundant.
        weights_backbone = None
        aux_loss = True
        num_classes = len(weights.meta["categories"])

    backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True])
    model = _fcn_resnet(backbone, num_classes, aux_loss)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
...@@ -7,7 +7,7 @@ from ... import transforms as T ...@@ -7,7 +7,7 @@ from ... import transforms as T
from ...transforms import functional as F from ...transforms import functional as F
__all__ = ["CocoEval", "ImageNetEval"] __all__ = ["CocoEval", "ImageNetEval", "VocEval"]
class CocoEval(nn.Module): class CocoEval(nn.Module):
...@@ -39,3 +39,33 @@ class ImageNetEval(nn.Module): ...@@ -39,3 +39,33 @@ class ImageNetEval(nn.Module):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float) img = F.convert_image_dtype(img, torch.float)
return self._normalize(img) return self._normalize(img)
class VocEval(nn.Module):
    """Evaluation-time preprocessing for VOC-style semantic segmentation.

    Resizes the image (and, if present, the target mask), converts both to
    tensors, and normalizes the image with the given mean/std. The target mask
    is resized with nearest-neighbor interpolation so class ids are preserved.
    """

    def __init__(
        self,
        resize_size: int,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
        interpolation_target: T.InterpolationMode = T.InterpolationMode.NEAREST,
    ) -> None:
        super().__init__()
        # A single-int list makes F.resize scale the smaller edge to resize_size.
        self._size = [resize_size]
        self._mean = list(mean)
        self._std = list(std)
        self._interpolation = interpolation
        self._interpolation_target = interpolation_target

    def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        """Preprocess an image (PIL or tensor) and optional segmentation mask.

        Returns the normalized float image and, when a target was supplied,
        the resized int64 class-id mask (otherwise ``None``).
        """
        img = F.resize(img, self._size, interpolation=self._interpolation)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self._mean, std=self._std)
        # Bug fix: the original `if target:` evaluates tensor truthiness, which
        # raises RuntimeError for any multi-element Tensor target. Compare
        # against None explicitly instead.
        if target is not None:
            target = F.resize(target, self._size, interpolation=self._interpolation_target)
            if not isinstance(target, Tensor):
                target = F.pil_to_tensor(target)
            # Drop the channel dim and cast to int64 class ids for loss/metrics.
            target = target.squeeze(0).to(torch.int64)
        return img, target
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment