Unverified commit f1587a20, authored by Vasilis Vryniotis and committed by GitHub

Clean up purely informational fields from Weight Meta-data (#5852)

* Removing `task`, `architecture` and `quantization`

* Fix mypy

* Remove size field

* Remove unused import.

* Fix mypy

* Remove size from schema list.

* Update todo

* Simplify with assert

* Adding min_size to all models.

* Update RAFT min size to 128
parent 7998cdfa
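
For context on what remains after the cleanup: `weights.meta` now carries only fields that are either consumed programmatically or not derivable from the model itself (e.g. `num_params`, `min_size`, `categories`, `recipe`, `acc@1`, plus `backend` and `unquantized` for quantized weights). Below is a minimal sketch of the user-facing effect, assuming the multi-weight API on torchvision `main` at the time of this PR; the values are taken from the ResNet50 entries in the diff below.

```python
# Minimal sketch, assuming the multi-weight API on torchvision main at the time of this PR.
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2

# Fields kept by this PR (values from the ResNet50 hunk below):
print(weights.meta["num_params"])  # 25557032
print(weights.meta["min_size"])    # (1, 1)
print(weights.meta["acc@1"])       # 80.858

# Purely informational fields are gone after this change:
for removed_key in ("task", "architecture", "size"):
    assert removed_key not in weights.meta
```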
@@ -69,14 +69,10 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
         url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
         transforms=partial(ImageClassification, crop_size=224),
         meta={
-            "task": "image_classification",
-            "architecture": "MobileNetV2",
             "num_params": 3504872,
-            "size": (224, 224),
             "min_size": (1, 1),
             "categories": _IMAGENET_CATEGORIES,
             "backend": "qnnpack",
-            "quantization": "Quantization Aware Training",
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
             "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1,
             "acc@1": 71.658,
...
@@ -159,14 +159,10 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
         url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
         transforms=partial(ImageClassification, crop_size=224),
         meta={
-            "task": "image_classification",
-            "architecture": "MobileNetV3",
             "num_params": 5483032,
-            "size": (224, 224),
             "min_size": (1, 1),
             "categories": _IMAGENET_CATEGORIES,
             "backend": "qnnpack",
-            "quantization": "Quantization Aware Training",
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
             "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
             "acc@1": 73.004,
...
@@ -147,12 +147,9 @@ def _resnet(
 _COMMON_META = {
-    "task": "image_classification",
-    "size": (224, 224),
     "min_size": (1, 1),
     "categories": _IMAGENET_CATEGORIES,
     "backend": "fbgemm",
-    "quantization": "Post Training Quantization",
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
 }
@@ -163,7 +160,6 @@ class ResNet18_QuantizedWeights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 11689512,
             "unquantized": ResNet18_Weights.IMAGENET1K_V1,
             "acc@1": 69.494,
@@ -179,7 +175,6 @@ class ResNet50_QuantizedWeights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 25557032,
             "unquantized": ResNet50_Weights.IMAGENET1K_V1,
             "acc@1": 75.920,
@@ -191,7+186,6 @@ class ResNet50_QuantizedWeights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 25557032,
             "unquantized": ResNet50_Weights.IMAGENET1K_V2,
             "acc@1": 80.282,
@@ -207,7 +201,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNeXt",
             "num_params": 88791336,
             "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
             "acc@1": 78.986,
@@ -219,7 +212,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNeXt",
             "num_params": 88791336,
             "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
             "acc@1": 82.574,
...
@@ -102,13 +102,9 @@ def _shufflenetv2(
 _COMMON_META = {
-    "task": "image_classification",
-    "architecture": "ShuffleNetV2",
-    "size": (224, 224),
     "min_size": (1, 1),
     "categories": _IMAGENET_CATEGORIES,
     "backend": "fbgemm",
-    "quantization": "Post Training Quantization",
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
 }
...
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 from torch import nn, Tensor
@@ -402,17 +402,13 @@ def _regnet(
     return model
 
-_COMMON_META = {
-    "task": "image_classification",
-    "architecture": "RegNet",
-    "size": (224, 224),
+_COMMON_META: Dict[str, Any] = {
     "min_size": (1, 1),
     "categories": _IMAGENET_CATEGORIES,
 }
 
 _COMMON_SWAG_META = {
     **_COMMON_META,
-    "size": (384, 384),
     "recipe": "https://github.com/facebookresearch/SWAG",
     "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
 }
...
@@ -302,8 +302,6 @@ def _resnet(
 _COMMON_META = {
-    "task": "image_classification",
-    "size": (224, 224),
     "min_size": (1, 1),
     "categories": _IMAGENET_CATEGORIES,
 }
@@ -315,7 +313,6 @@ class ResNet18_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 11689512,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
             "acc@1": 69.758,
@@ -331,7 +328,6 @@ class ResNet34_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 21797672,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
             "acc@1": 73.314,
@@ -347,7 +343,6 @@ class ResNet50_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 25557032,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
             "acc@1": 76.130,
@@ -359,7 +354,6 @@ class ResNet50_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 25557032,
             "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
             "acc@1": 80.858,
@@ -375,7 +369,6 @@ class ResNet101_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 44549160,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
             "acc@1": 77.374,
@@ -387,7 +380,6 @@ class ResNet101_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 44549160,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
             "acc@1": 81.886,
@@ -403,7 +395,6 @@ class ResNet152_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 60192808,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
             "acc@1": 78.312,
@@ -415,7 +406,6 @@ class ResNet152_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNet",
             "num_params": 60192808,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
             "acc@1": 82.284,
@@ -431,7 +421,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNeXt",
             "num_params": 25028904,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
             "acc@1": 77.618,
@@ -443,7 +432,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNeXt",
             "num_params": 25028904,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
             "acc@1": 81.198,
@@ -459,7 +447,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "ResNeXt",
             "num_params": 88791336,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
             "acc@1": 79.312,
@@ -471,7 +458,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "ResNeXt",
             "num_params": 88791336,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
             "acc@1": 82.834,
@@ -487,7 +473,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "WideResNet",
             "num_params": 68883240,
             "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
             "acc@1": 78.468,
@@ -499,7 +484,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "WideResNet",
             "num_params": 68883240,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
             "acc@1": 81.602,
@@ -515,7 +499,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,
-            "architecture": "WideResNet",
             "num_params": 126886696,
             "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
             "acc@1": 78.848,
@@ -527,7 +510,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
         transforms=partial(ImageClassification, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
-            "architecture": "WideResNet",
             "num_params": 126886696,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
             "acc@1": 82.510,
...
@@ -129,9 +129,8 @@ def _deeplabv3_resnet(
 _COMMON_META = {
-    "task": "image_semantic_segmentation",
-    "architecture": "DeepLabV3",
     "categories": _VOC_CATEGORIES,
+    "min_size": (1, 1),
 }
...
@@ -48,9 +48,8 @@ class FCNHead(nn.Sequential):
 _COMMON_META = {
-    "task": "image_semantic_segmentation",
-    "architecture": "FCN",
     "categories": _VOC_CATEGORIES,
+    "min_size": (1, 1),
 }
...
@@ -98,10 +98,9 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
         transforms=partial(SemanticSegmentation, resize_size=520),
         meta={
-            "task": "image_semantic_segmentation",
-            "architecture": "LRASPP",
             "num_params": 3221538,
             "categories": _VOC_CATEGORIES,
+            "min_size": (1, 1),
             "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
             "mIoU": 57.9,
             "acc": 91.2,
...
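
The segmentation weights (DeepLabV3, FCN and LRASPP above) gain a `min_size` entry as part of "Adding min_size to all models". A quick sanity check; the enum member name is assumed from the torchvision API of the time, and the `(1, 1)` value comes from the LRASPP hunk above:

```python
from torchvision.models.segmentation import LRASPP_MobileNet_V3_Large_Weights

# Assumed member name for illustration; the (1, 1) value is taken from the diff above.
weights = LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1
assert weights.meta["min_size"] == (1, 1)  # newly added for segmentation weights
```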
@@ -184,9 +184,6 @@ def _shufflenetv2(
 _COMMON_META = {
-    "task": "image_classification",
-    "architecture": "ShuffleNetV2",
-    "size": (224, 224),
     "min_size": (1, 1),
     "categories": _IMAGENET_CATEGORIES,
     "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
...
@@ -115,9 +115,6 @@ def _squeezenet(
 _COMMON_META = {
-    "task": "image_classification",
-    "architecture": "SqueezeNet",
-    "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
 }
...
@@ -107,9 +107,6 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
 _COMMON_META = {
-    "task": "image_classification",
-    "architecture": "VGG",
-    "size": (224, 224),
     "min_size": (32, 32),
     "categories": _IMAGENET_CATEGORIES,
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
...
@@ -309,8 +309,6 @@ def _video_resnet(
 _COMMON_META = {
-    "task": "video_classification",
-    "size": (112, 112),
     "min_size": (1, 1),
     "categories": _KINETICS400_CATEGORIES,
     "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
@@ -323,7 +321,6 @@ class R3D_18_Weights(WeightsEnum):
         transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
-            "architecture": "R3D",
             "num_params": 33371472,
             "acc@1": 52.75,
             "acc@5": 75.45,
@@ -338,7 +335,6 @@ class MC3_18_Weights(WeightsEnum):
         transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
-            "architecture": "MC3",
             "num_params": 11695440,
             "acc@1": 53.90,
             "acc@5": 76.29,
@@ -353,7 +349,6 @@ class R2Plus1D_18_Weights(WeightsEnum):
         transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
-            "architecture": "R(2+1)D",
             "num_params": 31505325,
             "acc@1": 57.50,
             "acc@5": 78.81,
...
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Dict
+from typing import Any, Callable, List, NamedTuple, Optional, Dict
 import torch
 import torch.nn as nn
@@ -288,18 +288,8 @@ def _vision_transformer(
 ) -> VisionTransformer:
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
-        if isinstance(weights.meta["size"], int):
-            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
-        elif isinstance(weights.meta["size"], Sequence):
-            if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
-                raise ValueError(
-                    f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
-                )
-            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
-        else:
-            raise ValueError(
-                f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
-            )
+        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+        _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0])
 
     image_size = kwargs.pop("image_size", 224)
     model = VisionTransformer(
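
The hunk above drops the `size`-based branching in `_vision_transformer` and instead asserts that the recorded `min_size` is square, reading the input resolution from it. A short sketch of the effect from the caller's side; the builder call and enum member are assumed from the torchvision API of the time, and the `(224, 224)` value appears in the ViT_B_16 hunk below:

```python
from torchvision.models import ViT_B_16_Weights, vit_b_16

weights = ViT_B_16_Weights.IMAGENET1K_V1

# min_size is square for every ViT weight entry, which is exactly what the new
# assert + _ovewrite_named_param pair in _vision_transformer relies on.
h, w = weights.meta["min_size"]
assert h == w
image_size = h  # 224 for IMAGENET1K_V1

# Building with these weights therefore configures the model for 224x224 inputs implicitly.
model = vit_b_16(weights=weights)
```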
@@ -319,12 +309,10 @@ def _vision_transformer(
 _COMMON_META: Dict[str, Any] = {
-    "task": "image_classification",
-    "architecture": "ViT",
     "categories": _IMAGENET_CATEGORIES,
 }
 
-_COMMON_SWAG_META: Dict[str, Any] = {
+_COMMON_SWAG_META = {
     **_COMMON_META,
     "recipe": "https://github.com/facebookresearch/SWAG",
     "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
@@ -338,7 +326,6 @@ class ViT_B_16_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 86567656,
-            "size": (224, 224),
             "min_size": (224, 224),
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
             "acc@1": 81.072,
@@ -356,7 +343,6 @@ class ViT_B_16_Weights(WeightsEnum):
         meta={
             **_COMMON_SWAG_META,
             "num_params": 86859496,
-            "size": (384, 384),
             "min_size": (384, 384),
             "acc@1": 85.304,
             "acc@5": 97.650,
@@ -374,7 +360,6 @@ class ViT_B_16_Weights(WeightsEnum):
             **_COMMON_SWAG_META,
             "recipe": "https://github.com/pytorch/vision/pull/5793",
             "num_params": 86567656,
-            "size": (224, 224),
             "min_size": (224, 224),
             "acc@1": 81.886,
             "acc@5": 96.180,
@@ -390,7 +375,6 @@ class ViT_B_32_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 88224232,
-            "size": (224, 224),
             "min_size": (224, 224),
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
             "acc@1": 75.912,
@@ -407,7 +391,6 @@ class ViT_L_16_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 304326632,
-            "size": (224, 224),
             "min_size": (224, 224),
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
             "acc@1": 79.662,
@@ -425,7 +408,6 @@ class ViT_L_16_Weights(WeightsEnum):
         meta={
             **_COMMON_SWAG_META,
             "num_params": 305174504,
-            "size": (512, 512),
             "min_size": (512, 512),
             "acc@1": 88.064,
             "acc@5": 98.512,
@@ -443,7 +425,6 @@ class ViT_L_16_Weights(WeightsEnum):
             **_COMMON_SWAG_META,
             "recipe": "https://github.com/pytorch/vision/pull/5793",
             "num_params": 304326632,
-            "size": (224, 224),
             "min_size": (224, 224),
             "acc@1": 85.146,
             "acc@5": 97.422,
@@ -459,7 +440,6 @@ class ViT_L_32_Weights(WeightsEnum):
         meta={
             **_COMMON_META,
             "num_params": 306535400,
-            "size": (224, 224),
             "min_size": (224, 224),
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
             "acc@1": 76.972,
@@ -481,7 +461,6 @@ class ViT_H_14_Weights(WeightsEnum):
         meta={
             **_COMMON_SWAG_META,
             "num_params": 633470440,
-            "size": (518, 518),
             "min_size": (518, 518),
             "acc@1": 88.552,
             "acc@5": 98.694,
@@ -499,7 +478,6 @@ class ViT_H_14_Weights(WeightsEnum):
             **_COMMON_SWAG_META,
             "recipe": "https://github.com/pytorch/vision/pull/5793",
             "num_params": 632045800,
-            "size": (224, 224),
             "min_size": (224, 224),
             "acc@1": 85.708,
             "acc@5": 97.730,
...