Unverified Commit a23778c0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding interpolation in meta for all models and cleaning up unnecessary vars. (#4876)

parent ec6f12d1
import warnings
from typing import Any, Optional, Union
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.faster_rcnn import (
_mobilenet_extractor,
_resnet_fpn_extractor,
......@@ -28,7 +30,10 @@ __all__ = [
]
_common_meta = {"categories": _COCO_CATEGORIES}
_common_meta = {
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FasterRCNNResNet50FPNWeights(Weights):
......
import warnings
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.mask_rcnn import (
_resnet_fpn_extractor,
_validate_trainable_layers,
......@@ -27,6 +29,7 @@ class MaskRCNNResNet50FPNWeights(Weights):
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
"box_map": 37.9,
"mask_map": 34.6,
......
import warnings
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.retinanet import (
_resnet_fpn_extractor,
_validate_trainable_layers,
......@@ -28,6 +30,7 @@ class RetinaNetResNet50FPNWeights(Weights):
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
......
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
......@@ -22,7 +24,10 @@ __all__ = [
]
_common_meta = {"categories": _VOC_CATEGORIES}
_common_meta = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class DeepLabV3ResNet50Weights(Weights):
......
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
......@@ -12,7 +14,10 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
_common_meta = {"categories": _VOC_CATEGORIES}
_common_meta = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class FCNResNet50Weights(Weights):
......
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
......@@ -18,6 +20,7 @@ class LRASPPMobileNetV3LargeWeights(Weights):
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
"mIoU": 57.9,
"acc": 91.2,
......
......@@ -31,7 +31,7 @@ __all__ = [
]
def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
......@@ -150,7 +150,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11Weights.verify(weights)
return _vgg("vgg11", "A", False, weights, progress, **kwargs)
return _vgg("A", False, weights, progress, **kwargs)
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -159,7 +159,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11BNWeights.verify(weights)
return _vgg("vgg11_bn", "A", True, weights, progress, **kwargs)
return _vgg("A", True, weights, progress, **kwargs)
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -168,7 +168,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13Weights.verify(weights)
return _vgg("vgg13", "B", False, weights, progress, **kwargs)
return _vgg("B", False, weights, progress, **kwargs)
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -177,7 +177,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13BNWeights.verify(weights)
return _vgg("vgg13_bn", "B", True, weights, progress, **kwargs)
return _vgg("B", True, weights, progress, **kwargs)
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -186,7 +186,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16Weights.verify(weights)
return _vgg("vgg16", "D", False, weights, progress, **kwargs)
return _vgg("D", False, weights, progress, **kwargs)
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -195,7 +195,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16BNWeights.verify(weights)
return _vgg("vgg16_bn", "D", True, weights, progress, **kwargs)
return _vgg("D", True, weights, progress, **kwargs)
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -204,7 +204,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19Weights.verify(weights)
return _vgg("vgg19", "E", False, weights, progress, **kwargs)
return _vgg("E", False, weights, progress, **kwargs)
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
......@@ -213,4 +213,4 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19BNWeights.verify(weights)
return _vgg("vgg19_bn", "E", True, weights, progress, **kwargs)
return _vgg("E", True, weights, progress, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment