Unverified commit 6a60b9bc authored by Vasilis Vryniotis, committed by GitHub

Update training references from legacy models (#4830)

* Update training references from legacy models.

* Refactoring to share common parts.
parent d24ef4c4
......@@ -42,20 +42,20 @@ torchrun --nproc_per_node=8 train.py --model inception_v3\
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
```
-### ResNext-50 32x4d
+### ResNet
```
-torchrun --nproc_per_node=8 train.py\
-    --model resnext50_32x4d --epochs 100
+torchrun --nproc_per_node=8 train.py --model $MODEL
```
+Here `$MODEL` is one of `resnet18`, `resnet34`, `resnet50`, `resnet101` or `resnet152`.
-### ResNext-101 32x8d
+### ResNext
```
torchrun --nproc_per_node=8 train.py\
-    --model resnext101_32x8d --epochs 100
+    --model $MODEL --epochs 100
```
+Here `$MODEL` is one of `resnext50_32x4d` or `resnext101_32x8d`.
Note that the above command corresponds to a single node with 8 GPUs. If you use
a different number of GPUs and/or a different batch size, then the learning rate
should be scaled accordingly. For example, the pretrained model provided by
......
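The note above describes the linear learning-rate scaling rule in prose; a minimal sketch of the arithmetic follows. The baseline of 8 GPUs, a per-GPU batch size of 32 and a base LR of 0.1 is an illustrative assumption, not necessarily each reference script's actual default:
```
# Hedged sketch of the linear LR scaling rule referenced in the README note.
# Baseline values below are assumptions for illustration only.
def scaled_lr(base_lr: float, base_gpus: int, base_batch: int,
              gpus: int, batch: int) -> float:
    """Scale the LR in proportion to the change in total batch size."""
    return base_lr * (gpus * batch) / (base_gpus * base_batch)

# Dropping from 8 GPUs to 2 at the same per-GPU batch size quarters the LR.
print(scaled_lr(0.1, 8, 32, 2, 32))  # 0.025
```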
......@@ -13,15 +13,14 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class AlexNetWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
+    **_common_meta,
-    "size": (224, 224),
-    "categories": _IMAGENET_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 56.522,
"acc@5": 79.066,
......
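A recurring refactor in this commit moves shared meta fields into a module-level `_common_meta` dict that each `WeightEntry` merges via `**` unpacking. A self-contained sketch of the merge semantics; the values here are placeholders, not real torchvision meta:
```
# Keys listed after the ** unpacking override the shared defaults.
_common_meta = {"size": (224, 224), "interpolation": "bilinear"}

entry_meta = {
    **_common_meta,
    "acc@1": 56.522,  # entry-specific field
}
assert entry_meta["size"] == (224, 224)  # inherited from _common_meta

override_meta = {**_common_meta, "size": (299, 299)}  # later keys win
assert override_meta["size"] == (299, 299)
```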
......@@ -63,16 +63,20 @@ def _densenet(
return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": None, # weights ported from LuaTorch
}
class DenseNet121Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 74.434,
"acc@5": 91.972,
},
......@@ -80,12 +84,11 @@ class DenseNet121Weights(Weights):
class DenseNet161Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 77.138,
"acc@5": 93.560,
},
......@@ -93,12 +96,11 @@ class DenseNet161Weights(Weights):
class DenseNet169Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 75.600,
"acc@5": 92.806,
},
......@@ -106,12 +108,11 @@ class DenseNet169Weights(Weights):
class DenseNet201Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 76.896,
"acc@5": 93.370,
},
......@@ -121,7 +122,7 @@ class DenseNet201Weights(Weights):
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = DenseNet121Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet121Weights.verify(weights)
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
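Every builder in this diff keeps the legacy `pretrained=True` flag working by translating it into the new enum before verification. A stripped-down sketch of that pattern; `DemoWeights` and `demo_model` are hypothetical stand-ins, not torchvision API:
```
import warnings
from enum import Enum
from typing import Any, Optional

class DemoWeights(Enum):  # hypothetical stand-in for a Weights enum
    ImageNet1K_Community = "demo-v1"

def demo_model(weights: Optional[DemoWeights] = None, **kwargs: Any) -> str:
    # Translate the deprecated boolean flag into the new-style enum.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        weights = DemoWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None
    return f"building with weights={weights}"

print(demo_model(pretrained=True))   # warns, then maps to the Community weights
print(demo_model(weights=DemoWeights.ImageNet1K_Community))  # new-style call
```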
......@@ -130,7 +131,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = DenseNet161Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet161Weights.verify(weights)
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
......@@ -139,7 +140,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = DenseNet169Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet169Weights.verify(weights)
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
......@@ -148,7 +149,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = DenseNet201Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = DenseNet201Weights.verify(weights)
return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
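Each builder above also routes the requested weights through `verify`. Its implementation is not part of this diff; a plausible sketch of such a guard (an assumption, not the actual prototype code) might look like:
```
from enum import Enum

class Weights(Enum):  # hypothetical sketch, not the actual prototype class
    @classmethod
    def verify(cls, obj):
        # Pass through None or a member of this enum; resolve a string by name.
        if obj is None or isinstance(obj, cls):
            return obj
        if isinstance(obj, str):
            return cls[obj]
        raise TypeError(f"Invalid weights value: {obj!r}")

class DemoWeights(Weights):
    ImageNet1K_Community = "demo-v1"

print(DemoWeights.verify(None))                    # None passes through
print(DemoWeights.verify("ImageNet1K_Community"))  # name lookup
```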
......@@ -62,7 +62,11 @@ def _efficientnet(
return model
_common_meta = {"categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BICUBIC}
_common_meta = {
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BICUBIC,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
}
class EfficientNetB0Weights(Weights):
......@@ -72,7 +76,6 @@ class EfficientNetB0Weights(Weights):
meta={
**_common_meta,
"size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 77.692,
"acc@5": 93.532,
},
......@@ -86,7 +89,6 @@ class EfficientNetB1Weights(Weights):
meta={
**_common_meta,
"size": (240, 240),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 78.642,
"acc@5": 94.186,
},
......@@ -100,7 +102,6 @@ class EfficientNetB2Weights(Weights):
meta={
**_common_meta,
"size": (288, 288),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 80.608,
"acc@5": 95.310,
},
......@@ -114,7 +115,6 @@ class EfficientNetB3Weights(Weights):
meta={
**_common_meta,
"size": (300, 300),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 82.008,
"acc@5": 96.054,
},
......@@ -128,7 +128,6 @@ class EfficientNetB4Weights(Weights):
meta={
**_common_meta,
"size": (380, 380),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 83.384,
"acc@5": 96.594,
},
......@@ -142,7 +141,6 @@ class EfficientNetB5Weights(Weights):
meta={
**_common_meta,
"size": (456, 456),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 83.444,
"acc@5": 96.628,
},
......@@ -156,7 +154,6 @@ class EfficientNetB6Weights(Weights):
meta={
**_common_meta,
"size": (528, 528),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 84.008,
"acc@5": 96.916,
},
......@@ -170,7 +167,6 @@ class EfficientNetB7Weights(Weights):
meta={
**_common_meta,
"size": (600, 600),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"acc@1": 84.122,
"acc@5": 96.908,
},
......
......@@ -13,15 +13,14 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class GoogLeNetWeights(Weights):
ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
+    **_common_meta,
-    "size": (224, 224),
-    "categories": _IMAGENET_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet",
"acc@1": 69.778,
"acc@5": 89.530,
......
......@@ -13,15 +13,14 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
_common_meta = {"size": (299, 299), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class Inception3Weights(Weights):
ImageNet1K_TFV1 = WeightEntry(
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
meta={
+    **_common_meta,
-    "size": (299, 299),
-    "categories": _IMAGENET_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
"acc@1": 77.294,
"acc@5": 93.450,
......
......@@ -23,7 +23,12 @@ __all__ = [
]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/1e100/mnasnet_trainer",
}
class MNASNet0_5Weights(Weights):
......@@ -32,7 +37,6 @@ class MNASNet0_5Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/1e100/mnasnet_trainer",
"acc@1": 67.734,
"acc@5": 87.490,
},
......@@ -50,7 +54,6 @@ class MNASNet1_0Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/1e100/mnasnet_trainer",
"acc@1": 73.456,
"acc@5": 91.510,
},
......
......@@ -13,15 +13,14 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class MobileNetV2Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
+    **_common_meta,
-    "size": (224, 224),
-    "categories": _IMAGENET_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
"acc@1": 71.878,
"acc@5": 90.286,
......
......@@ -37,7 +37,12 @@ def _mobilenet_v3(
return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
}
class MobileNetV3LargeWeights(Weights):
......@@ -46,7 +51,6 @@ class MobileNetV3LargeWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
"acc@1": 74.042,
"acc@5": 91.340,
},
......@@ -59,7 +63,6 @@ class MobileNetV3SmallWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
"acc@1": 67.668,
"acc@5": 87.402,
},
......
......@@ -60,7 +60,7 @@ class ResNet18Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 69.758,
"acc@5": 89.078,
},
......@@ -73,7 +73,7 @@ class ResNet34Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 73.314,
"acc@5": 91.420,
},
......@@ -86,7 +86,7 @@ class ResNet50Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 76.130,
"acc@5": 92.862,
},
......@@ -109,7 +109,7 @@ class ResNet101Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 77.374,
"acc@5": 93.546,
},
......@@ -132,7 +132,7 @@ class ResNet152Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
"acc@1": 78.312,
"acc@5": 94.046,
},
......@@ -155,7 +155,7 @@ class ResNeXt50_32x4dWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 77.618,
"acc@5": 93.698,
},
......@@ -168,7 +168,7 @@ class ResNeXt101_32x8dWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
"acc@1": 79.312,
"acc@5": 94.526,
},
......@@ -176,12 +176,12 @@ class ResNeXt101_32x8dWeights(Weights):
class WideResNet50_2Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.468,
"acc@5": 94.086,
},
......@@ -189,12 +189,12 @@ class WideResNet50_2Weights(Weights):
class WideResNet101_2Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
"acc@1": 78.848,
"acc@5": 94.284,
},
......@@ -275,7 +275,7 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress
def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = WideResNet50_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = WideResNet50_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = WideResNet50_2Weights.verify(weights)
kwargs["width_per_group"] = 64 * 2
......@@ -285,7 +285,7 @@ def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: b
def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = WideResNet101_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = WideResNet101_2Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = WideResNet101_2Weights.verify(weights)
kwargs["width_per_group"] = 64 * 2
......
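For context, this is roughly how the renamed entries above are consumed. The `torchvision.prototype.models` import path is an assumption about where these files lived at the time, and the snippet is a hedged sketch rather than documented API:
```
from torchvision.prototype.models import wide_resnet50_2, WideResNet50_2Weights

# New-style: pass the enum member introduced above.
model = wide_resnet50_2(weights=WideResNet50_2Weights.ImageNet1K_Community)

# Legacy: still accepted, but warns and maps to the same enum internally.
legacy = wide_resnet50_2(pretrained=True)
```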
......@@ -22,12 +22,15 @@ __all__ = [
]
_common_meta = {"categories": _VOC_CATEGORIES}
class DeepLabV3ResNet50Weights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50",
"mIoU": 66.4,
"acc": 92.4,
......@@ -40,7 +43,7 @@ class DeepLabV3ResNet101Weights(Weights):
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101",
"mIoU": 67.4,
"acc": 92.4,
......@@ -53,7 +56,7 @@ class DeepLabV3MobileNetV3LargeWeights(Weights):
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large",
"mIoU": 60.3,
"acc": 91.2,
......
......@@ -12,12 +12,15 @@ from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
_common_meta = {"categories": _VOC_CATEGORIES}
class FCNResNet50Weights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50",
"mIoU": 60.5,
"acc": 91.4,
......@@ -30,7 +33,7 @@ class FCNResNet101Weights(Weights):
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101",
"mIoU": 63.7,
"acc": 91.9,
......
......@@ -40,16 +40,20 @@ def _shufflenetv2(
return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
}
class ShuffleNetV2_x0_5Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 69.362,
"acc@5": 88.316,
},
......@@ -57,12 +61,11 @@ class ShuffleNetV2_x0_5Weights(Weights):
class ShuffleNetV2_x1_0Weights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+    ImageNet1K_Community = WeightEntry(
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 60.552,
"acc@5": 81.746,
},
......@@ -82,7 +85,7 @@ def shufflenet_v2_x0_5(
) -> ShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = ShuffleNetV2_x0_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = ShuffleNetV2_x0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x0_5Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
......@@ -93,7 +96,7 @@ def shufflenet_v2_x1_0(
) -> ShuffleNetV2:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-    weights = ShuffleNetV2_x1_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+    weights = ShuffleNetV2_x1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = ShuffleNetV2_x1_0Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
......
......@@ -13,7 +13,12 @@ from ._meta import _IMAGENET_CATEGORIES
__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
}
class SqueezeNet1_0Weights(Weights):
......@@ -22,7 +27,6 @@ class SqueezeNet1_0Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
"acc@1": 58.092,
"acc@5": 80.420,
},
......@@ -35,7 +39,6 @@ class SqueezeNet1_1Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
"acc@1": 58.178,
"acc@5": 80.624,
},
......
......@@ -40,7 +40,12 @@ def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], prog
return model
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
}
class VGG11Weights(Weights):
......@@ -49,7 +54,6 @@ class VGG11Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 69.020,
"acc@5": 88.628,
},
......@@ -62,7 +66,6 @@ class VGG11BNWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 70.370,
"acc@5": 89.810,
},
......@@ -75,7 +78,6 @@ class VGG13Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 69.928,
"acc@5": 89.246,
},
......@@ -88,7 +90,6 @@ class VGG13BNWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 71.586,
"acc@5": 90.374,
},
......@@ -101,7 +102,6 @@ class VGG16Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 71.592,
"acc@5": 90.382,
},
......@@ -114,7 +114,6 @@ class VGG16BNWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 73.360,
"acc@5": 91.516,
},
......@@ -127,7 +126,6 @@ class VGG19Weights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 72.376,
"acc@5": 90.876,
},
......@@ -140,7 +138,6 @@ class VGG19BNWeights(Weights):
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
"acc@1": 74.218,
"acc@5": 91.842,
},
......
......@@ -51,7 +51,12 @@ def _video_resnet(
return model
_common_meta = {"size": (112, 112), "categories": _KINETICS400_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
_common_meta = {
"size": (112, 112),
"categories": _KINETICS400_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
}
class R3D_18Weights(Weights):
......@@ -60,7 +65,6 @@ class R3D_18Weights(Weights):
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
"acc@1": 52.75,
"acc@5": 75.45,
},
......@@ -73,7 +77,6 @@ class MC3_18Weights(Weights):
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
"acc@1": 53.90,
"acc@5": 76.29,
},
......@@ -86,7 +89,6 @@ class R2Plus1D_18Weights(Weights):
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
"acc@1": 57.50,
"acc@5": 78.81,
},
......
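Throughout the diff, each `WeightEntry` binds its evaluation preset with `functools.partial`, so the preset's fixed arguments travel with the weights. A minimal sketch with a hypothetical stand-in preset (`demo_eval` is not a torchvision function):
```
from functools import partial

def demo_eval(img, crop_size, resize_size=256):  # hypothetical preset
    return f"resize {img} to {resize_size}, center-crop to {crop_size}"

transforms = partial(demo_eval, crop_size=224)  # stored alongside the weights
print(transforms("input.jpg"))  # called later with just the input
```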