Unverified Commit 8b588042 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Add interpolation model information to all classification proto models (#4688)



* adding interpolation model information to all classification prototype models

* fixing lint errors
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent d605d7d4
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
......@@ -11,10 +13,7 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class AlexNetWeights(Weights):
......
......@@ -4,6 +4,7 @@ from functools import partial
from typing import Any, Optional, Tuple
import torch.nn as nn
from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet
from ..transforms.presets import ImageNetEval
......@@ -62,10 +63,7 @@ def _densenet(
return model
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class DenseNet121Weights(Weights):
......
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, List, Optional, Type, Union
from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
......@@ -49,10 +51,7 @@ def _resnet(
return model
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class ResNet18Weights(Weights):
......
......@@ -2,6 +2,8 @@ import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
......@@ -38,10 +40,7 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], prog
return model
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class VGG11Weights(Weights):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment