"vscode:/vscode.git/clone" did not exist on "69247f5b033b8300561f689afab64862bf87b54b"
Unverified Commit bec45cdc authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support for squeezenet prototype model (#4817)

* Change enum name for weights contributed by community.

* Adding multiweight support to squeezenet.
parent c4b5b67d
...@@ -8,6 +8,7 @@ from .mobilenetv3 import * ...@@ -8,6 +8,7 @@ from .mobilenetv3 import *
from .regnet import * from .regnet import *
from .resnet import * from .resnet import *
from .shufflenetv2 import * from .shufflenetv2 import *
from .squeezenet import *
from .vgg import * from .vgg import *
from . import detection from . import detection
from . import quantization from . import quantization
......
...@@ -17,7 +17,7 @@ _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo ...@@ -17,7 +17,7 @@ _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo
class GoogLeNetWeights(Weights): class GoogLeNetWeights(Weights):
ImageNet1K_TheCodezV1 = WeightEntry( ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/googlenet-1378be20.pth", url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -32,7 +32,7 @@ class GoogLeNetWeights(Weights): ...@@ -32,7 +32,7 @@ class GoogLeNetWeights(Weights):
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_TheCodezV1 if kwargs.pop("pretrained") else None weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights) weights = GoogLeNetWeights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False) original_aux_logits = kwargs.get("aux_logits", False)
......
...@@ -27,7 +27,7 @@ _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo ...@@ -27,7 +27,7 @@ _common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpo
class MNASNet0_5Weights(Weights): class MNASNet0_5Weights(Weights):
ImageNet1K_TrainerV1 = WeightEntry( ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -45,7 +45,7 @@ class MNASNet0_75Weights(Weights): ...@@ -45,7 +45,7 @@ class MNASNet0_75Weights(Weights):
class MNASNet1_0Weights(Weights): class MNASNet1_0Weights(Weights):
ImageNet1K_TrainerV1 = WeightEntry( ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
...@@ -77,7 +77,7 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: ...@@ -77,7 +77,7 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs:
def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = MNASNet0_5Weights.ImageNet1K_TrainerV1 if kwargs.pop("pretrained") else None weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet0_5Weights.verify(weights) weights = MNASNet0_5Weights.verify(weights)
...@@ -98,7 +98,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T ...@@ -98,7 +98,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T
def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = MNASNet1_0Weights.ImageNet1K_TrainerV1 if kwargs.pop("pretrained") else None weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = MNASNet1_0Weights.verify(weights) weights = MNASNet1_0Weights.verify(weights)
return _mnasnet(1.0, weights, progress, **kwargs) return _mnasnet(1.0, weights, progress, **kwargs)
......
import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class SqueezeNet1_0Weights(Weights):
ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
"acc@1": 58.092,
"acc@5": 80.420,
},
)
class SqueezeNet1_1Weights(Weights):
ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
"acc@1": 58.178,
"acc@5": 80.624,
},
)
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_0Weights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
model = SqueezeNet("1_0", **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_1Weights.verify(weights)
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
model = SqueezeNet("1_1", **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment