Unverified Commit 1deb2ec2 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup Models prototype implementation (#4940)

* Disable WeightEntry to pass-through `Weights.verify()`

* Rename `Weights.state_dict()` to `Weights.get_state_dict()`

* Add TODO for missing doc.

* Moving warning messages for googlenet.

* Upper-case global `_COMMON_META` var

* Replace argument with parameter in all warnings.
parent 30f4d108
...@@ -44,12 +44,12 @@ def _shufflenetv2( ...@@ -44,12 +44,12 @@ def _shufflenetv2(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -64,7 +64,7 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights): ...@@ -64,7 +64,7 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community, "unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community,
"acc@1": 57.972, "acc@1": 57.972,
"acc@5": 79.780, "acc@5": 79.780,
...@@ -77,7 +77,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights): ...@@ -77,7 +77,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community, "unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community,
"acc@1": 68.360, "acc@1": 68.360,
"acc@5": 87.582, "acc@5": 87.582,
...@@ -92,7 +92,7 @@ def shufflenet_v2_x0_5( ...@@ -92,7 +92,7 @@ def shufflenet_v2_x0_5(
**kwargs: Any, **kwargs: Any,
) -> QuantizableShuffleNetV2: ) -> QuantizableShuffleNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = ( weights = (
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
...@@ -117,7 +117,7 @@ def shufflenet_v2_x1_0( ...@@ -117,7 +117,7 @@ def shufflenet_v2_x1_0(
**kwargs: Any, **kwargs: Any,
) -> QuantizableShuffleNetV2: ) -> QuantizableShuffleNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
weights = ( weights = (
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
......
...@@ -43,7 +43,7 @@ __all__ = [ ...@@ -43,7 +43,7 @@ __all__ = [
"regnet_x_32gf", "regnet_x_32gf",
] ]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
def _regnet( def _regnet(
...@@ -59,7 +59,7 @@ def _regnet( ...@@ -59,7 +59,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs) model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -69,7 +69,7 @@ class RegNet_y_400mfWeights(Weights): ...@@ -69,7 +69,7 @@ class RegNet_y_400mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 74.046, "acc@1": 74.046,
"acc@5": 91.716, "acc@5": 91.716,
...@@ -82,7 +82,7 @@ class RegNet_y_800mfWeights(Weights): ...@@ -82,7 +82,7 @@ class RegNet_y_800mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 76.420, "acc@1": 76.420,
"acc@5": 93.136, "acc@5": 93.136,
...@@ -95,7 +95,7 @@ class RegNet_y_1_6gfWeights(Weights): ...@@ -95,7 +95,7 @@ class RegNet_y_1_6gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.950, "acc@1": 77.950,
"acc@5": 93.966, "acc@5": 93.966,
...@@ -108,7 +108,7 @@ class RegNet_y_3_2gfWeights(Weights): ...@@ -108,7 +108,7 @@ class RegNet_y_3_2gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.948, "acc@1": 78.948,
"acc@5": 94.576, "acc@5": 94.576,
...@@ -121,7 +121,7 @@ class RegNet_y_8gfWeights(Weights): ...@@ -121,7 +121,7 @@ class RegNet_y_8gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.032, "acc@1": 80.032,
"acc@5": 95.048, "acc@5": 95.048,
...@@ -134,7 +134,7 @@ class RegNet_y_16gfWeights(Weights): ...@@ -134,7 +134,7 @@ class RegNet_y_16gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.424, "acc@1": 80.424,
"acc@5": 95.240, "acc@5": 95.240,
...@@ -147,7 +147,7 @@ class RegNet_y_32gfWeights(Weights): ...@@ -147,7 +147,7 @@ class RegNet_y_32gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.878, "acc@1": 80.878,
"acc@5": 95.340, "acc@5": 95.340,
...@@ -160,7 +160,7 @@ class RegNet_x_400mfWeights(Weights): ...@@ -160,7 +160,7 @@ class RegNet_x_400mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 72.834, "acc@1": 72.834,
"acc@5": 90.950, "acc@5": 90.950,
...@@ -173,7 +173,7 @@ class RegNet_x_800mfWeights(Weights): ...@@ -173,7 +173,7 @@ class RegNet_x_800mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 75.212, "acc@1": 75.212,
"acc@5": 92.348, "acc@5": 92.348,
...@@ -186,7 +186,7 @@ class RegNet_x_1_6gfWeights(Weights): ...@@ -186,7 +186,7 @@ class RegNet_x_1_6gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.040, "acc@1": 77.040,
"acc@5": 93.440, "acc@5": 93.440,
...@@ -199,7 +199,7 @@ class RegNet_x_3_2gfWeights(Weights): ...@@ -199,7 +199,7 @@ class RegNet_x_3_2gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.364, "acc@1": 78.364,
"acc@5": 93.992, "acc@5": 93.992,
...@@ -212,7 +212,7 @@ class RegNet_x_8gfWeights(Weights): ...@@ -212,7 +212,7 @@ class RegNet_x_8gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 79.344, "acc@1": 79.344,
"acc@5": 94.686, "acc@5": 94.686,
...@@ -225,7 +225,7 @@ class RegNet_x_16gfWeights(Weights): ...@@ -225,7 +225,7 @@ class RegNet_x_16gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.058, "acc@1": 80.058,
"acc@5": 94.944, "acc@5": 94.944,
...@@ -238,7 +238,7 @@ class RegNet_x_32gfWeights(Weights): ...@@ -238,7 +238,7 @@ class RegNet_x_32gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.622, "acc@1": 80.622,
"acc@5": 95.248, "acc@5": 95.248,
...@@ -248,7 +248,7 @@ class RegNet_x_32gfWeights(Weights): ...@@ -248,7 +248,7 @@ class RegNet_x_32gfWeights(Weights):
def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_400mfWeights.verify(weights) weights = RegNet_y_400mfWeights.verify(weights)
...@@ -258,7 +258,7 @@ def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bo ...@@ -258,7 +258,7 @@ def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bo
def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_800mfWeights.verify(weights) weights = RegNet_y_800mfWeights.verify(weights)
...@@ -268,7 +268,7 @@ def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bo ...@@ -268,7 +268,7 @@ def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bo
def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_1_6gfWeights.verify(weights) weights = RegNet_y_1_6gfWeights.verify(weights)
...@@ -280,7 +280,7 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo ...@@ -280,7 +280,7 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo
def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_3_2gfWeights.verify(weights) weights = RegNet_y_3_2gfWeights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -291,7 +291,7 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo ...@@ -291,7 +291,7 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo
def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_8gfWeights.verify(weights) weights = RegNet_y_8gfWeights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -302,7 +302,7 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = ...@@ -302,7 +302,7 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool =
def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_16gfWeights.verify(weights) weights = RegNet_y_16gfWeights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -313,7 +313,7 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool ...@@ -313,7 +313,7 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool
def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_y_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_32gfWeights.verify(weights) weights = RegNet_y_32gfWeights.verify(weights)
params = BlockParams.from_init_params( params = BlockParams.from_init_params(
...@@ -324,7 +324,7 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool ...@@ -324,7 +324,7 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool
def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_400mfWeights.verify(weights) weights = RegNet_x_400mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
...@@ -334,7 +334,7 @@ def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bo ...@@ -334,7 +334,7 @@ def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bo
def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_800mfWeights.verify(weights) weights = RegNet_x_800mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
...@@ -344,7 +344,7 @@ def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bo ...@@ -344,7 +344,7 @@ def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bo
def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_1_6gfWeights.verify(weights) weights = RegNet_x_1_6gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
...@@ -354,7 +354,7 @@ def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bo ...@@ -354,7 +354,7 @@ def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bo
def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_3_2gfWeights.verify(weights) weights = RegNet_x_3_2gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
...@@ -364,7 +364,7 @@ def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bo ...@@ -364,7 +364,7 @@ def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bo
def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_8gfWeights.verify(weights) weights = RegNet_x_8gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
...@@ -374,7 +374,7 @@ def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = ...@@ -374,7 +374,7 @@ def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool =
def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_16gfWeights.verify(weights) weights = RegNet_x_16gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
...@@ -384,7 +384,7 @@ def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool ...@@ -384,7 +384,7 @@ def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool
def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = RegNet_x_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_32gfWeights.verify(weights) weights = RegNet_x_32gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
......
...@@ -46,12 +46,12 @@ def _resnet( ...@@ -46,12 +46,12 @@ def _resnet(
model = ResNet(block, layers, **kwargs) model = ResNet(block, layers, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class ResNet18Weights(Weights): class ResNet18Weights(Weights):
...@@ -59,7 +59,7 @@ class ResNet18Weights(Weights): ...@@ -59,7 +59,7 @@ class ResNet18Weights(Weights):
url="https://download.pytorch.org/models/resnet18-f37072fd.pth", url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 69.758, "acc@1": 69.758,
"acc@5": 89.078, "acc@5": 89.078,
...@@ -72,7 +72,7 @@ class ResNet34Weights(Weights): ...@@ -72,7 +72,7 @@ class ResNet34Weights(Weights):
url="https://download.pytorch.org/models/resnet34-b627a593.pth", url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 73.314, "acc@1": 73.314,
"acc@5": 91.420, "acc@5": 91.420,
...@@ -85,7 +85,7 @@ class ResNet50Weights(Weights): ...@@ -85,7 +85,7 @@ class ResNet50Weights(Weights):
url="https://download.pytorch.org/models/resnet50-0676ba61.pth", url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 76.130, "acc@1": 76.130,
"acc@5": 92.862, "acc@5": 92.862,
...@@ -95,7 +95,7 @@ class ResNet50Weights(Weights): ...@@ -95,7 +95,7 @@ class ResNet50Weights(Weights):
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 80.674, "acc@1": 80.674,
"acc@5": 95.166, "acc@5": 95.166,
...@@ -108,7 +108,7 @@ class ResNet101Weights(Weights): ...@@ -108,7 +108,7 @@ class ResNet101Weights(Weights):
url="https://download.pytorch.org/models/resnet101-63fe2227.pth", url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 77.374, "acc@1": 77.374,
"acc@5": 93.546, "acc@5": 93.546,
...@@ -118,7 +118,7 @@ class ResNet101Weights(Weights): ...@@ -118,7 +118,7 @@ class ResNet101Weights(Weights):
url="https://download.pytorch.org/models/resnet101-b641f3a9.pth", url="https://download.pytorch.org/models/resnet101-b641f3a9.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 81.728, "acc@1": 81.728,
"acc@5": 95.670, "acc@5": 95.670,
...@@ -131,7 +131,7 @@ class ResNet152Weights(Weights): ...@@ -131,7 +131,7 @@ class ResNet152Weights(Weights):
url="https://download.pytorch.org/models/resnet152-394f9c45.pth", url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 78.312, "acc@1": 78.312,
"acc@5": 94.046, "acc@5": 94.046,
...@@ -141,7 +141,7 @@ class ResNet152Weights(Weights): ...@@ -141,7 +141,7 @@ class ResNet152Weights(Weights):
url="https://download.pytorch.org/models/resnet152-089c0848.pth", url="https://download.pytorch.org/models/resnet152-089c0848.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 82.042, "acc@1": 82.042,
"acc@5": 95.926, "acc@5": 95.926,
...@@ -154,7 +154,7 @@ class ResNeXt50_32x4dWeights(Weights): ...@@ -154,7 +154,7 @@ class ResNeXt50_32x4dWeights(Weights):
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 77.618, "acc@1": 77.618,
"acc@5": 93.698, "acc@5": 93.698,
...@@ -164,7 +164,7 @@ class ResNeXt50_32x4dWeights(Weights): ...@@ -164,7 +164,7 @@ class ResNeXt50_32x4dWeights(Weights):
url="https://download.pytorch.org/models/resnext50_32x4d-b260af35.pth", url="https://download.pytorch.org/models/resnext50_32x4d-b260af35.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 81.116, "acc@1": 81.116,
"acc@5": 95.478, "acc@5": 95.478,
...@@ -177,7 +177,7 @@ class ResNeXt101_32x8dWeights(Weights): ...@@ -177,7 +177,7 @@ class ResNeXt101_32x8dWeights(Weights):
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 79.312, "acc@1": 79.312,
"acc@5": 94.526, "acc@5": 94.526,
...@@ -190,7 +190,7 @@ class WideResNet50_2Weights(Weights): ...@@ -190,7 +190,7 @@ class WideResNet50_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.468, "acc@1": 78.468,
"acc@5": 94.086, "acc@5": 94.086,
...@@ -200,7 +200,7 @@ class WideResNet50_2Weights(Weights): ...@@ -200,7 +200,7 @@ class WideResNet50_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 81.602, "acc@1": 81.602,
"acc@5": 95.758, "acc@5": 95.758,
...@@ -213,7 +213,7 @@ class WideResNet101_2Weights(Weights): ...@@ -213,7 +213,7 @@ class WideResNet101_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.848, "acc@1": 78.848,
"acc@5": 94.284, "acc@5": 94.284,
...@@ -223,7 +223,7 @@ class WideResNet101_2Weights(Weights): ...@@ -223,7 +223,7 @@ class WideResNet101_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet101_2-b8680a8c.pth", url="https://download.pytorch.org/models/wide_resnet101_2-b8680a8c.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 82.492, "acc@1": 82.492,
"acc@5": 96.110, "acc@5": 96.110,
...@@ -233,7 +233,7 @@ class WideResNet101_2Weights(Weights): ...@@ -233,7 +233,7 @@ class WideResNet101_2Weights(Weights):
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet18Weights.verify(weights) weights = ResNet18Weights.verify(weights)
...@@ -243,7 +243,7 @@ def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, * ...@@ -243,7 +243,7 @@ def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, *
def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet34Weights.verify(weights) weights = ResNet34Weights.verify(weights)
...@@ -253,7 +253,7 @@ def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, * ...@@ -253,7 +253,7 @@ def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, *
def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet50Weights.verify(weights) weights = ResNet50Weights.verify(weights)
...@@ -262,7 +262,7 @@ def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, * ...@@ -262,7 +262,7 @@ def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, *
def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet101Weights.verify(weights) weights = ResNet101Weights.verify(weights)
...@@ -272,7 +272,7 @@ def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, ...@@ -272,7 +272,7 @@ def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True,
def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet152Weights.verify(weights) weights = ResNet152Weights.verify(weights)
...@@ -282,7 +282,7 @@ def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, ...@@ -282,7 +282,7 @@ def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True,
def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNeXt50_32x4dWeights.verify(weights) weights = ResNeXt50_32x4dWeights.verify(weights)
...@@ -293,7 +293,7 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: ...@@ -293,7 +293,7 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress:
def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNeXt101_32x8dWeights.verify(weights) weights = ResNeXt101_32x8dWeights.verify(weights)
...@@ -304,7 +304,7 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress ...@@ -304,7 +304,7 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = WideResNet50_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = WideResNet50_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = WideResNet50_2Weights.verify(weights) weights = WideResNet50_2Weights.verify(weights)
...@@ -314,7 +314,7 @@ def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: b ...@@ -314,7 +314,7 @@ def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: b
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = WideResNet101_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = WideResNet101_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = WideResNet101_2Weights.verify(weights) weights = WideResNet101_2Weights.verify(weights)
......
...@@ -24,7 +24,7 @@ __all__ = [ ...@@ -24,7 +24,7 @@ __all__ = [
] ]
_common_meta = { _COMMON_META = {
"categories": _VOC_CATEGORIES, "categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
} }
...@@ -35,7 +35,7 @@ class DeepLabV3ResNet50Weights(Weights): ...@@ -35,7 +35,7 @@ class DeepLabV3ResNet50Weights(Weights):
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
"mIoU": 66.4, "mIoU": 66.4,
"acc": 92.4, "acc": 92.4,
...@@ -48,7 +48,7 @@ class DeepLabV3ResNet101Weights(Weights): ...@@ -48,7 +48,7 @@ class DeepLabV3ResNet101Weights(Weights):
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
"mIoU": 67.4, "mIoU": 67.4,
"acc": 92.4, "acc": 92.4,
...@@ -61,7 +61,7 @@ class DeepLabV3MobileNetV3LargeWeights(Weights): ...@@ -61,7 +61,7 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
"mIoU": 60.3, "mIoU": 60.3,
"acc": 91.2, "acc": 91.2,
...@@ -78,12 +78,12 @@ def deeplabv3_resnet50( ...@@ -78,12 +78,12 @@ def deeplabv3_resnet50(
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3ResNet50Weights.verify(weights) weights = DeepLabV3ResNet50Weights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
...@@ -96,7 +96,7 @@ def deeplabv3_resnet50( ...@@ -96,7 +96,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -110,12 +110,12 @@ def deeplabv3_resnet101( ...@@ -110,12 +110,12 @@ def deeplabv3_resnet101(
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3ResNet101Weights.verify(weights) weights = DeepLabV3ResNet101Weights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet101Weights.verify(weights_backbone) weights_backbone = ResNet101Weights.verify(weights_backbone)
...@@ -128,7 +128,7 @@ def deeplabv3_resnet101( ...@@ -128,7 +128,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -142,12 +142,12 @@ def deeplabv3_mobilenet_v3_large( ...@@ -142,12 +142,12 @@ def deeplabv3_mobilenet_v3_large(
**kwargs: Any, **kwargs: Any,
) -> DeepLabV3: ) -> DeepLabV3:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
...@@ -160,6 +160,6 @@ def deeplabv3_mobilenet_v3_large( ...@@ -160,6 +160,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -14,7 +14,7 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 ...@@ -14,7 +14,7 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"] __all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
_common_meta = { _COMMON_META = {
"categories": _VOC_CATEGORIES, "categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
} }
...@@ -25,7 +25,7 @@ class FCNResNet50Weights(Weights): ...@@ -25,7 +25,7 @@ class FCNResNet50Weights(Weights):
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
"mIoU": 60.5, "mIoU": 60.5,
"acc": 91.4, "acc": 91.4,
...@@ -38,7 +38,7 @@ class FCNResNet101Weights(Weights): ...@@ -38,7 +38,7 @@ class FCNResNet101Weights(Weights):
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(VocEval, resize_size=520), transforms=partial(VocEval, resize_size=520),
meta={ meta={
**_common_meta, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
"mIoU": 63.7, "mIoU": 63.7,
"acc": 91.9, "acc": 91.9,
...@@ -55,11 +55,11 @@ def fcn_resnet50( ...@@ -55,11 +55,11 @@ def fcn_resnet50(
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FCNResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None weights = FCNResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = FCNResNet50Weights.verify(weights) weights = FCNResNet50Weights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone) weights_backbone = ResNet50Weights.verify(weights_backbone)
...@@ -72,7 +72,7 @@ def fcn_resnet50( ...@@ -72,7 +72,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -86,11 +86,11 @@ def fcn_resnet101( ...@@ -86,11 +86,11 @@ def fcn_resnet101(
**kwargs: Any, **kwargs: Any,
) -> FCN: ) -> FCN:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FCNResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None weights = FCNResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = FCNResNet101Weights.verify(weights) weights = FCNResNet101Weights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet101Weights.verify(weights_backbone) weights_backbone = ResNet101Weights.verify(weights_backbone)
...@@ -103,6 +103,6 @@ def fcn_resnet101( ...@@ -103,6 +103,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -39,11 +39,11 @@ def lraspp_mobilenet_v3_large( ...@@ -39,11 +39,11 @@ def lraspp_mobilenet_v3_large(
raise NotImplementedError("This model does not use auxiliary loss") raise NotImplementedError("This model does not use auxiliary loss")
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None weights = LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = LRASPPMobileNetV3LargeWeights.verify(weights) weights = LRASPPMobileNetV3LargeWeights.verify(weights)
if "pretrained_backbone" in kwargs: if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
...@@ -55,6 +55,6 @@ def lraspp_mobilenet_v3_large( ...@@ -55,6 +55,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes) model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -35,12 +35,12 @@ def _shufflenetv2( ...@@ -35,12 +35,12 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs) model = ShuffleNetV2(*args, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -53,7 +53,7 @@ class ShuffleNetV2_x0_5Weights(Weights): ...@@ -53,7 +53,7 @@ class ShuffleNetV2_x0_5Weights(Weights):
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 69.362, "acc@1": 69.362,
"acc@5": 88.316, "acc@5": 88.316,
}, },
...@@ -65,7 +65,7 @@ class ShuffleNetV2_x1_0Weights(Weights): ...@@ -65,7 +65,7 @@ class ShuffleNetV2_x1_0Weights(Weights):
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 60.552, "acc@1": 60.552,
"acc@5": 81.746, "acc@5": 81.746,
}, },
...@@ -84,7 +84,7 @@ def shufflenet_v2_x0_5( ...@@ -84,7 +84,7 @@ def shufflenet_v2_x0_5(
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ShuffleNetV2_x0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = ShuffleNetV2_x0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x0_5Weights.verify(weights) weights = ShuffleNetV2_x0_5Weights.verify(weights)
...@@ -95,7 +95,7 @@ def shufflenet_v2_x1_0( ...@@ -95,7 +95,7 @@ def shufflenet_v2_x1_0(
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ShuffleNetV2_x1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = ShuffleNetV2_x1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x1_0Weights.verify(weights) weights = ShuffleNetV2_x1_0Weights.verify(weights)
...@@ -106,7 +106,7 @@ def shufflenet_v2_x1_5( ...@@ -106,7 +106,7 @@ def shufflenet_v2_x1_5(
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x1_5") raise ValueError("No checkpoint is available for model type shufflenet_v2_x1_5")
weights = ShuffleNetV2_x1_5Weights.verify(weights) weights = ShuffleNetV2_x1_5Weights.verify(weights)
...@@ -118,7 +118,7 @@ def shufflenet_v2_x2_0( ...@@ -118,7 +118,7 @@ def shufflenet_v2_x2_0(
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2: ) -> ShuffleNetV2:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"): if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x2_0") raise ValueError("No checkpoint is available for model type shufflenet_v2_x2_0")
weights = ShuffleNetV2_x2_0Weights.verify(weights) weights = ShuffleNetV2_x2_0Weights.verify(weights)
......
...@@ -13,7 +13,7 @@ from ._meta import _IMAGENET_CATEGORIES ...@@ -13,7 +13,7 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"] __all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -26,7 +26,7 @@ class SqueezeNet1_0Weights(Weights): ...@@ -26,7 +26,7 @@ class SqueezeNet1_0Weights(Weights):
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 58.092, "acc@1": 58.092,
"acc@5": 80.420, "acc@5": 80.420,
}, },
...@@ -38,7 +38,7 @@ class SqueezeNet1_1Weights(Weights): ...@@ -38,7 +38,7 @@ class SqueezeNet1_1Weights(Weights):
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 58.178, "acc@1": 58.178,
"acc@5": 80.624, "acc@5": 80.624,
}, },
...@@ -47,7 +47,7 @@ class SqueezeNet1_1Weights(Weights): ...@@ -47,7 +47,7 @@ class SqueezeNet1_1Weights(Weights):
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_0Weights.verify(weights) weights = SqueezeNet1_0Weights.verify(weights)
if weights is not None: if weights is not None:
...@@ -56,14 +56,14 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool ...@@ -56,14 +56,14 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool
model = SqueezeNet("1_0", **kwargs) model = SqueezeNet("1_0", **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_1Weights.verify(weights) weights = SqueezeNet1_1Weights.verify(weights)
if weights is not None: if weights is not None:
...@@ -72,6 +72,6 @@ def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool ...@@ -72,6 +72,6 @@ def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool
model = SqueezeNet("1_1", **kwargs) model = SqueezeNet("1_1", **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
...@@ -36,11 +36,11 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, ...@@ -36,11 +36,11 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool,
kwargs["num_classes"] = len(weights.meta["categories"]) kwargs["num_classes"] = len(weights.meta["categories"])
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"size": (224, 224), "size": (224, 224),
"categories": _IMAGENET_CATEGORIES, "categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -53,7 +53,7 @@ class VGG11Weights(Weights): ...@@ -53,7 +53,7 @@ class VGG11Weights(Weights):
url="https://download.pytorch.org/models/vgg11-8a719046.pth", url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 69.020, "acc@1": 69.020,
"acc@5": 88.628, "acc@5": 88.628,
}, },
...@@ -65,7 +65,7 @@ class VGG11BNWeights(Weights): ...@@ -65,7 +65,7 @@ class VGG11BNWeights(Weights):
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 70.370, "acc@1": 70.370,
"acc@5": 89.810, "acc@5": 89.810,
}, },
...@@ -77,7 +77,7 @@ class VGG13Weights(Weights): ...@@ -77,7 +77,7 @@ class VGG13Weights(Weights):
url="https://download.pytorch.org/models/vgg13-19584684.pth", url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 69.928, "acc@1": 69.928,
"acc@5": 89.246, "acc@5": 89.246,
}, },
...@@ -89,7 +89,7 @@ class VGG13BNWeights(Weights): ...@@ -89,7 +89,7 @@ class VGG13BNWeights(Weights):
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 71.586, "acc@1": 71.586,
"acc@5": 90.374, "acc@5": 90.374,
}, },
...@@ -101,7 +101,7 @@ class VGG16Weights(Weights): ...@@ -101,7 +101,7 @@ class VGG16Weights(Weights):
url="https://download.pytorch.org/models/vgg16-397923af.pth", url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 71.592, "acc@1": 71.592,
"acc@5": 90.382, "acc@5": 90.382,
}, },
...@@ -130,7 +130,7 @@ class VGG16BNWeights(Weights): ...@@ -130,7 +130,7 @@ class VGG16BNWeights(Weights):
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 73.360, "acc@1": 73.360,
"acc@5": 91.516, "acc@5": 91.516,
}, },
...@@ -142,7 +142,7 @@ class VGG19Weights(Weights): ...@@ -142,7 +142,7 @@ class VGG19Weights(Weights):
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 72.376, "acc@1": 72.376,
"acc@5": 90.876, "acc@5": 90.876,
}, },
...@@ -154,7 +154,7 @@ class VGG19BNWeights(Weights): ...@@ -154,7 +154,7 @@ class VGG19BNWeights(Weights):
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 74.218, "acc@1": 74.218,
"acc@5": 91.842, "acc@5": 91.842,
}, },
...@@ -163,7 +163,7 @@ class VGG19BNWeights(Weights): ...@@ -163,7 +163,7 @@ class VGG19BNWeights(Weights):
def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11Weights.verify(weights) weights = VGG11Weights.verify(weights)
...@@ -172,7 +172,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg ...@@ -172,7 +172,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11BNWeights.verify(weights) weights = VGG11BNWeights.verify(weights)
...@@ -181,7 +181,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, ** ...@@ -181,7 +181,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13Weights.verify(weights) weights = VGG13Weights.verify(weights)
...@@ -190,7 +190,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg ...@@ -190,7 +190,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13BNWeights.verify(weights) weights = VGG13BNWeights.verify(weights)
...@@ -199,7 +199,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, ** ...@@ -199,7 +199,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16Weights.verify(weights) weights = VGG16Weights.verify(weights)
...@@ -208,7 +208,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg ...@@ -208,7 +208,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16BNWeights.verify(weights) weights = VGG16BNWeights.verify(weights)
...@@ -217,7 +217,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, ** ...@@ -217,7 +217,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19Weights.verify(weights) weights = VGG19Weights.verify(weights)
...@@ -226,7 +226,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg ...@@ -226,7 +226,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19BNWeights.verify(weights) weights = VGG19BNWeights.verify(weights)
......
...@@ -46,12 +46,12 @@ def _video_resnet( ...@@ -46,12 +46,12 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs) model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress))
return model return model
_common_meta = { _COMMON_META = {
"size": (112, 112), "size": (112, 112),
"categories": _KINETICS400_CATEGORIES, "categories": _KINETICS400_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
...@@ -64,7 +64,7 @@ class R3D_18Weights(Weights): ...@@ -64,7 +64,7 @@ class R3D_18Weights(Weights):
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)), transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 52.75, "acc@1": 52.75,
"acc@5": 75.45, "acc@5": 75.45,
}, },
...@@ -76,7 +76,7 @@ class MC3_18Weights(Weights): ...@@ -76,7 +76,7 @@ class MC3_18Weights(Weights):
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)), transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 53.90, "acc@1": 53.90,
"acc@5": 76.29, "acc@5": 76.29,
}, },
...@@ -88,7 +88,7 @@ class R2Plus1D_18Weights(Weights): ...@@ -88,7 +88,7 @@ class R2Plus1D_18Weights(Weights):
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)), transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={ meta={
**_common_meta, **_COMMON_META,
"acc@1": 57.50, "acc@1": 57.50,
"acc@5": 78.81, "acc@5": 78.81,
}, },
...@@ -97,7 +97,7 @@ class R2Plus1D_18Weights(Weights): ...@@ -97,7 +97,7 @@ class R2Plus1D_18Weights(Weights):
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = R3D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None weights = R3D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = R3D_18Weights.verify(weights) weights = R3D_18Weights.verify(weights)
...@@ -114,7 +114,7 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa ...@@ -114,7 +114,7 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MC3_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None weights = MC3_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = MC3_18Weights.verify(weights) weights = MC3_18Weights.verify(weights)
...@@ -131,7 +131,7 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa ...@@ -131,7 +131,7 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if "pretrained" in kwargs: if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.") warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = R2Plus1D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None weights = R2Plus1D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = R2Plus1D_18Weights.verify(weights) weights = R2Plus1D_18Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment