Unverified Commit a23778c0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding interpolation in meta for all models and cleaning up unnecessary vars. (#4876)

parent ec6f12d1
import warnings import warnings
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.faster_rcnn import ( from ....models.detection.faster_rcnn import (
_mobilenet_extractor, _mobilenet_extractor,
_resnet_fpn_extractor, _resnet_fpn_extractor,
...@@ -28,7 +30,10 @@ __all__ = [ ...@@ -28,7 +30,10 @@ __all__ = [
] ]
_common_meta = {"categories": _COCO_CATEGORIES} _common_meta = {
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FasterRCNNResNet50FPNWeights(Weights): class FasterRCNNResNet50FPNWeights(Weights):
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.mask_rcnn import ( from ....models.detection.mask_rcnn import (
_resnet_fpn_extractor, _resnet_fpn_extractor,
_validate_trainable_layers, _validate_trainable_layers,
...@@ -27,6 +29,7 @@ class MaskRCNNResNet50FPNWeights(Weights): ...@@ -27,6 +29,7 @@ class MaskRCNNResNet50FPNWeights(Weights):
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
"categories": _COCO_CATEGORIES, "categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
"box_map": 37.9, "box_map": 37.9,
"mask_map": 34.6, "mask_map": 34.6,
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.retinanet import ( from ....models.detection.retinanet import (
_resnet_fpn_extractor, _resnet_fpn_extractor,
_validate_trainable_layers, _validate_trainable_layers,
...@@ -28,6 +30,7 @@ class RetinaNetResNet50FPNWeights(Weights): ...@@ -28,6 +30,7 @@ class RetinaNetResNet50FPNWeights(Weights):
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
"categories": _COCO_CATEGORIES, "categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4, "map": 36.4,
}, },
......
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from ...transforms.presets import VocEval from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
...@@ -22,7 +24,10 @@ __all__ = [ ...@@ -22,7 +24,10 @@ __all__ = [
] ]
_common_meta = {"categories": _VOC_CATEGORIES} _common_meta = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class DeepLabV3ResNet50Weights(Weights): class DeepLabV3ResNet50Weights(Weights):
......
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet from ....models.segmentation.fcn import FCN, _fcn_resnet
from ...transforms.presets import VocEval from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
...@@ -12,7 +14,10 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 ...@@ -12,7 +14,10 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"] __all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
_common_meta = {"categories": _VOC_CATEGORIES} _common_meta = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FCNResNet50Weights(Weights): class FCNResNet50Weights(Weights):
......
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from ...transforms.presets import VocEval from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
...@@ -18,6 +20,7 @@ class LRASPPMobileNetV3LargeWeights(Weights): ...@@ -18,6 +20,7 @@ class LRASPPMobileNetV3LargeWeights(Weights):
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
"categories": _VOC_CATEGORIES, "categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
"mIoU": 57.9, "mIoU": 57.9,
"acc": 91.2, "acc": 91.2,
......
...@@ -31,7 +31,7 @@ __all__ = [ ...@@ -31,7 +31,7 @@ __all__ = [
] ]
def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
if weights is not None: if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"]) kwargs["num_classes"] = len(weights.meta["categories"])
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
...@@ -150,7 +150,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg ...@@ -150,7 +150,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11Weights.verify(weights) weights = VGG11Weights.verify(weights)
return _vgg("vgg11", "A", False, weights, progress, **kwargs) return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -159,7 +159,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, ** ...@@ -159,7 +159,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11BNWeights.verify(weights) weights = VGG11BNWeights.verify(weights)
return _vgg("vgg11_bn", "A", True, weights, progress, **kwargs) return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -168,7 +168,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg ...@@ -168,7 +168,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13Weights.verify(weights) weights = VGG13Weights.verify(weights)
return _vgg("vgg13", "B", False, weights, progress, **kwargs) return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -177,7 +177,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, ** ...@@ -177,7 +177,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13BNWeights.verify(weights) weights = VGG13BNWeights.verify(weights)
return _vgg("vgg13_bn", "B", True, weights, progress, **kwargs) return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -186,7 +186,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg ...@@ -186,7 +186,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16Weights.verify(weights) weights = VGG16Weights.verify(weights)
return _vgg("vgg16", "D", False, weights, progress, **kwargs) return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -195,7 +195,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, ** ...@@ -195,7 +195,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16BNWeights.verify(weights) weights = VGG16BNWeights.verify(weights)
return _vgg("vgg16_bn", "D", True, weights, progress, **kwargs) return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -204,7 +204,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg ...@@ -204,7 +204,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19Weights.verify(weights) weights = VGG19Weights.verify(weights)
return _vgg("vgg19", "E", False, weights, progress, **kwargs) return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -213,4 +213,4 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, ** ...@@ -213,4 +213,4 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19BNWeights.verify(weights) weights = VGG19BNWeights.verify(weights)
return _vgg("vgg19_bn", "E", True, weights, progress, **kwargs) return _vgg("E", True, weights, progress, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment