Unverified Commit 1deb2ec2 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Cleanup Models prototype implementation (#4940)

* Disable WeightEntry to pass-through `Weights.verify()`

* Rename `Weights.state_dict()` to `Weights.get_state_dict()`

* Add TODO for missing doc.

* Moving warning messages for googlenet.

* Upper-case global `_COMMON_META` var

* Replace argument with parameter in all warnings.
parent 30f4d108
......@@ -44,12 +44,12 @@ def _shufflenetv2(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -64,7 +64,7 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community,
"acc@1": 57.972,
"acc@5": 79.780,
......@@ -77,7 +77,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community,
"acc@1": 68.360,
"acc@5": 87.582,
......@@ -92,7 +92,7 @@ def shufflenet_v2_x0_5(
**kwargs: Any,
) -> QuantizableShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community
......@@ -117,7 +117,7 @@ def shufflenet_v2_x1_0(
**kwargs: Any,
) -> QuantizableShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = (
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community
......
......@@ -43,7 +43,7 @@ __all__ = [
"regnet_x_32gf",
]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
def _regnet(
......@@ -59,7 +59,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -69,7 +69,7 @@ class RegNet_y_400mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 74.046,
"acc@5": 91.716,
......@@ -82,7 +82,7 @@ class RegNet_y_800mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 76.420,
"acc@5": 93.136,
......@@ -95,7 +95,7 @@ class RegNet_y_1_6gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.950,
"acc@5": 93.966,
......@@ -108,7 +108,7 @@ class RegNet_y_3_2gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.948,
"acc@5": 94.576,
......@@ -121,7 +121,7 @@ class RegNet_y_8gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.032,
"acc@5": 95.048,
......@@ -134,7 +134,7 @@ class RegNet_y_16gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.424,
"acc@5": 95.240,
......@@ -147,7 +147,7 @@ class RegNet_y_32gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.878,
"acc@5": 95.340,
......@@ -160,7 +160,7 @@ class RegNet_x_400mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 72.834,
"acc@5": 90.950,
......@@ -173,7 +173,7 @@ class RegNet_x_800mfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 75.212,
"acc@5": 92.348,
......@@ -186,7 +186,7 @@ class RegNet_x_1_6gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
"acc@1": 77.040,
"acc@5": 93.440,
......@@ -199,7 +199,7 @@ class RegNet_x_3_2gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 78.364,
"acc@5": 93.992,
......@@ -212,7 +212,7 @@ class RegNet_x_8gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 79.344,
"acc@5": 94.686,
......@@ -225,7 +225,7 @@ class RegNet_x_16gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
"acc@1": 80.058,
"acc@5": 94.944,
......@@ -238,7 +238,7 @@ class RegNet_x_32gfWeights(Weights):
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
"acc@1": 80.622,
"acc@5": 95.248,
......@@ -248,7 +248,7 @@ class RegNet_x_32gfWeights(Weights):
def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_400mfWeights.verify(weights)
......@@ -258,7 +258,7 @@ def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bo
def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_800mfWeights.verify(weights)
......@@ -268,7 +268,7 @@ def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bo
def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_1_6gfWeights.verify(weights)
......@@ -280,7 +280,7 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo
def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_3_2gfWeights.verify(weights)
params = BlockParams.from_init_params(
......@@ -291,7 +291,7 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo
def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_8gfWeights.verify(weights)
params = BlockParams.from_init_params(
......@@ -302,7 +302,7 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool =
def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_16gfWeights.verify(weights)
params = BlockParams.from_init_params(
......@@ -313,7 +313,7 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool
def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_y_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_y_32gfWeights.verify(weights)
params = BlockParams.from_init_params(
......@@ -324,7 +324,7 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool
def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_400mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_400mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
......@@ -334,7 +334,7 @@ def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bo
def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_800mfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_800mfWeights.verify(weights)
params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
......@@ -344,7 +344,7 @@ def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bo
def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_1_6gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_1_6gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
......@@ -354,7 +354,7 @@ def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bo
def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_3_2gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_3_2gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
......@@ -364,7 +364,7 @@ def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bo
def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_8gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_8gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
......@@ -374,7 +374,7 @@ def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool =
def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_16gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_16gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
......@@ -384,7 +384,7 @@ def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool
def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RegNet_x_32gfWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = RegNet_x_32gfWeights.verify(weights)
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
......
......@@ -46,12 +46,12 @@ def _resnet(
model = ResNet(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class ResNet18Weights(Weights):
......@@ -59,7 +59,7 @@ class ResNet18Weights(Weights):
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 69.758,
"acc@5": 89.078,
......@@ -72,7 +72,7 @@ class ResNet34Weights(Weights):
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 73.314,
"acc@5": 91.420,
......@@ -85,7 +85,7 @@ class ResNet50Weights(Weights):
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 76.130,
"acc@5": 92.862,
......@@ -95,7 +95,7 @@ class ResNet50Weights(Weights):
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 80.674,
"acc@5": 95.166,
......@@ -108,7 +108,7 @@ class ResNet101Weights(Weights):
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 77.374,
"acc@5": 93.546,
......@@ -118,7 +118,7 @@ class ResNet101Weights(Weights):
url="https://download.pytorch.org/models/resnet101-b641f3a9.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 81.728,
"acc@5": 95.670,
......@@ -131,7 +131,7 @@ class ResNet152Weights(Weights):
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 78.312,
"acc@5": 94.046,
......@@ -141,7 +141,7 @@ class ResNet152Weights(Weights):
url="https://download.pytorch.org/models/resnet152-089c0848.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 82.042,
"acc@5": 95.926,
......@@ -154,7 +154,7 @@ class ResNeXt50_32x4dWeights(Weights):
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 77.618,
"acc@5": 93.698,
......@@ -164,7 +164,7 @@ class ResNeXt50_32x4dWeights(Weights):
url="https://download.pytorch.org/models/resnext50_32x4d-b260af35.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 81.116,
"acc@5": 95.478,
......@@ -177,7 +177,7 @@ class ResNeXt101_32x8dWeights(Weights):
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 79.312,
"acc@5": 94.526,
......@@ -190,7 +190,7 @@ class WideResNet50_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.468,
"acc@5": 94.086,
......@@ -200,7 +200,7 @@ class WideResNet50_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 81.602,
"acc@5": 95.758,
......@@ -213,7 +213,7 @@ class WideResNet101_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.848,
"acc@5": 94.284,
......@@ -223,7 +223,7 @@ class WideResNet101_2Weights(Weights):
url="https://download.pytorch.org/models/wide_resnet101_2-b8680a8c.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 82.492,
"acc@5": 96.110,
......@@ -233,7 +233,7 @@ class WideResNet101_2Weights(Weights):
def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet18Weights.verify(weights)
......@@ -243,7 +243,7 @@ def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, *
def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet34Weights.verify(weights)
......@@ -253,7 +253,7 @@ def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, *
def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet50Weights.verify(weights)
......@@ -262,7 +262,7 @@ def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, *
def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet101Weights.verify(weights)
......@@ -272,7 +272,7 @@ def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True,
def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet152Weights.verify(weights)
......@@ -282,7 +282,7 @@ def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True,
def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNeXt50_32x4dWeights.verify(weights)
......@@ -293,7 +293,7 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress:
def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNeXt101_32x8dWeights.verify(weights)
......@@ -304,7 +304,7 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = WideResNet50_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = WideResNet50_2Weights.verify(weights)
......@@ -314,7 +314,7 @@ def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: b
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = WideResNet101_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = WideResNet101_2Weights.verify(weights)
......
......@@ -24,7 +24,7 @@ __all__ = [
]
_common_meta = {
_COMMON_META = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
......@@ -35,7 +35,7 @@ class DeepLabV3ResNet50Weights(Weights):
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(VocEval, resize_size=520),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
"mIoU": 66.4,
"acc": 92.4,
......@@ -48,7 +48,7 @@ class DeepLabV3ResNet101Weights(Weights):
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(VocEval, resize_size=520),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
"mIoU": 67.4,
"acc": 92.4,
......@@ -61,7 +61,7 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(VocEval, resize_size=520),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
"mIoU": 60.3,
"acc": 91.2,
......@@ -78,12 +78,12 @@ def deeplabv3_resnet50(
**kwargs: Any,
) -> DeepLabV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3ResNet50Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)
......@@ -96,7 +96,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -110,12 +110,12 @@ def deeplabv3_resnet101(
**kwargs: Any,
) -> DeepLabV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3ResNet101Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet101Weights.verify(weights_backbone)
......@@ -128,7 +128,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -142,12 +142,12 @@ def deeplabv3_mobilenet_v3_large(
**kwargs: Any,
) -> DeepLabV3:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = DeepLabV3MobileNetV3LargeWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
......@@ -160,6 +160,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -14,7 +14,7 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
_common_meta = {
_COMMON_META = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
......@@ -25,7 +25,7 @@ class FCNResNet50Weights(Weights):
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(VocEval, resize_size=520),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
"mIoU": 60.5,
"acc": 91.4,
......@@ -38,7 +38,7 @@ class FCNResNet101Weights(Weights):
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(VocEval, resize_size=520),
meta={
**_common_meta,
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
"mIoU": 63.7,
"acc": 91.9,
......@@ -55,11 +55,11 @@ def fcn_resnet50(
**kwargs: Any,
) -> FCN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FCNResNet50Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = FCNResNet50Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet50Weights.verify(weights_backbone)
......@@ -72,7 +72,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -86,11 +86,11 @@ def fcn_resnet101(
**kwargs: Any,
) -> FCN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FCNResNet101Weights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = FCNResNet101Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = ResNet101Weights.verify(weights_backbone)
......@@ -103,6 +103,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -39,11 +39,11 @@ def lraspp_mobilenet_v3_large(
raise NotImplementedError("This model does not use auxiliary loss")
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = LRASPPMobileNetV3LargeWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
......@@ -55,6 +55,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -35,12 +35,12 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -53,7 +53,7 @@ class ShuffleNetV2_x0_5Weights(Weights):
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 69.362,
"acc@5": 88.316,
},
......@@ -65,7 +65,7 @@ class ShuffleNetV2_x1_0Weights(Weights):
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 60.552,
"acc@5": 81.746,
},
......@@ -84,7 +84,7 @@ def shufflenet_v2_x0_5(
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ShuffleNetV2_x0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x0_5Weights.verify(weights)
......@@ -95,7 +95,7 @@ def shufflenet_v2_x1_0(
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = ShuffleNetV2_x1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x1_0Weights.verify(weights)
......@@ -106,7 +106,7 @@ def shufflenet_v2_x1_5(
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x1_5")
weights = ShuffleNetV2_x1_5Weights.verify(weights)
......@@ -118,7 +118,7 @@ def shufflenet_v2_x2_0(
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
raise ValueError("No checkpoint is available for model type shufflenet_v2_x2_0")
weights = ShuffleNetV2_x2_0Weights.verify(weights)
......
......@@ -13,7 +13,7 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -26,7 +26,7 @@ class SqueezeNet1_0Weights(Weights):
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 58.092,
"acc@5": 80.420,
},
......@@ -38,7 +38,7 @@ class SqueezeNet1_1Weights(Weights):
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 58.178,
"acc@5": 80.624,
},
......@@ -47,7 +47,7 @@ class SqueezeNet1_1Weights(Weights):
def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_0Weights.verify(weights)
if weights is not None:
......@@ -56,14 +56,14 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool
model = SqueezeNet("1_0", **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = SqueezeNet1_1Weights.verify(weights)
if weights is not None:
......@@ -72,6 +72,6 @@ def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool
model = SqueezeNet("1_1", **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
......@@ -36,11 +36,11 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool,
kwargs["num_classes"] = len(weights.meta["categories"])
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -53,7 +53,7 @@ class VGG11Weights(Weights):
url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 69.020,
"acc@5": 88.628,
},
......@@ -65,7 +65,7 @@ class VGG11BNWeights(Weights):
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 70.370,
"acc@5": 89.810,
},
......@@ -77,7 +77,7 @@ class VGG13Weights(Weights):
url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 69.928,
"acc@5": 89.246,
},
......@@ -89,7 +89,7 @@ class VGG13BNWeights(Weights):
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 71.586,
"acc@5": 90.374,
},
......@@ -101,7 +101,7 @@ class VGG16Weights(Weights):
url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 71.592,
"acc@5": 90.382,
},
......@@ -130,7 +130,7 @@ class VGG16BNWeights(Weights):
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 73.360,
"acc@5": 91.516,
},
......@@ -142,7 +142,7 @@ class VGG19Weights(Weights):
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 72.376,
"acc@5": 90.876,
},
......@@ -154,7 +154,7 @@ class VGG19BNWeights(Weights):
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 74.218,
"acc@5": 91.842,
},
......@@ -163,7 +163,7 @@ class VGG19BNWeights(Weights):
def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11Weights.verify(weights)
......@@ -172,7 +172,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11BNWeights.verify(weights)
......@@ -181,7 +181,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13Weights.verify(weights)
......@@ -190,7 +190,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13BNWeights.verify(weights)
......@@ -199,7 +199,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16Weights.verify(weights)
......@@ -208,7 +208,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16BNWeights.verify(weights)
......@@ -217,7 +217,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19Weights.verify(weights)
......@@ -226,7 +226,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19BNWeights.verify(weights)
......
......@@ -46,12 +46,12 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_common_meta = {
_COMMON_META = {
"size": (112, 112),
"categories": _KINETICS400_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
......@@ -64,7 +64,7 @@ class R3D_18Weights(Weights):
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 52.75,
"acc@5": 75.45,
},
......@@ -76,7 +76,7 @@ class MC3_18Weights(Weights):
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 53.90,
"acc@5": 76.29,
},
......@@ -88,7 +88,7 @@ class R2Plus1D_18Weights(Weights):
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
**_COMMON_META,
"acc@1": 57.50,
"acc@5": 78.81,
},
......@@ -97,7 +97,7 @@ class R2Plus1D_18Weights(Weights):
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = R3D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = R3D_18Weights.verify(weights)
......@@ -114,7 +114,7 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MC3_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = MC3_18Weights.verify(weights)
......@@ -131,7 +131,7 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = R2Plus1D_18Weights.Kinetics400_RefV1 if kwargs.pop("pretrained") else None
weights = R2Plus1D_18Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment