Unverified Commit 8b588042 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Add interpolation model information to all classification proto models (#4688)



* adding interpolation model information to all classification prototype models

* fixing lint errors
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent d605d7d4
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet from ...models.alexnet import AlexNet
from ..transforms.presets import ImageNetEval from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
...@@ -11,10 +13,7 @@ from ._meta import _IMAGENET_CATEGORIES ...@@ -11,10 +13,7 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"] __all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
_common_meta = { _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
class AlexNetWeights(Weights): class AlexNetWeights(Weights):
......
...@@ -4,6 +4,7 @@ from functools import partial ...@@ -4,6 +4,7 @@ from functools import partial
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import torch.nn as nn import torch.nn as nn
from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet from ...models.densenet import DenseNet
from ..transforms.presets import ImageNetEval from ..transforms.presets import ImageNetEval
...@@ -62,10 +63,7 @@ def _densenet( ...@@ -62,10 +63,7 @@ def _densenet(
return model return model
_common_meta = { _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
class DenseNet121Weights(Weights): class DenseNet121Weights(Weights):
......
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ..transforms.presets import ImageNetEval from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
...@@ -49,10 +51,7 @@ def _resnet( ...@@ -49,10 +51,7 @@ def _resnet(
return model return model
_common_meta = { _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
class ResNet18Weights(Weights): class ResNet18Weights(Weights):
......
...@@ -2,6 +2,8 @@ import warnings ...@@ -2,6 +2,8 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs from ...models.vgg import VGG, make_layers, cfgs
from ..transforms.presets import ImageNetEval from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
...@@ -38,10 +40,7 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], prog ...@@ -38,10 +40,7 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], prog
return model return model
_common_meta = { _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
}
class VGG11Weights(Weights): class VGG11Weights(Weights):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment