Unverified commit f1587a20, authored by Vasilis Vryniotis and committed by GitHub

Clean up purely informational fields from Weight Meta-data (#5852)

* Removing `task`, `architecture` and `quantization`

* Fix mypy

* Remove size field

* Remove unused import.

* Fix mypy

* Remove size from schema list.

* update todo

* Simplify with assert

* Adding min_size to all models.

* Update RAFT min size to 128
parent 7998cdfa
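
The net effect for users of the multi-weight API: `weights.meta` keeps only fields that code actually consumes (`num_params`, `min_size`, `categories`, `recipe`, the accuracy metrics, and the quantization-specific `backend`/`unquantized`), while the purely descriptive `task`, `architecture`, `quantization` and the ambiguous `size` entries are dropped. A minimal sketch of reading the remaining fields, assuming a torchvision build that already contains this commit (on older builds the enums may still live under `torchvision.prototype.models`):

# Minimal sketch: inspect the meta-data that survives this clean-up.
# Values are taken from the diff below; the import location assumes the
# multi-weight API is exposed under torchvision.models.
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
print(weights.meta["num_params"])       # 25557032
print(weights.meta["min_size"])         # (1, 1) -- smallest supported input
print(len(weights.meta["categories"]))  # 1000 ImageNet classes
# "task", "architecture", "size" and "quantization" are no longer present:
assert "size" not in weights.meta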
@@ -69,14 +69,10 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
  url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
  transforms=partial(ImageClassification, crop_size=224),
  meta={
-     "task": "image_classification",
-     "architecture": "MobileNetV2",
      "num_params": 3504872,
-     "size": (224, 224),
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
      "backend": "qnnpack",
-     "quantization": "Quantization Aware Training",
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
      "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1,
      "acc@1": 71.658,

@@ -159,14 +159,10 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
  url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
  transforms=partial(ImageClassification, crop_size=224),
  meta={
-     "task": "image_classification",
-     "architecture": "MobileNetV3",
      "num_params": 5483032,
-     "size": (224, 224),
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
      "backend": "qnnpack",
-     "quantization": "Quantization Aware Training",
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
      "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
      "acc@1": 73.004,
@@ -147,12 +147,9 @@ def _resnet(
  _COMMON_META = {
-     "task": "image_classification",
-     "size": (224, 224),
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
      "backend": "fbgemm",
-     "quantization": "Post Training Quantization",
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
  }

@@ -163,7 +160,6 @@ class ResNet18_QuantizedWeights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 11689512,
      "unquantized": ResNet18_Weights.IMAGENET1K_V1,
      "acc@1": 69.494,

@@ -179,7 +175,6 @@ class ResNet50_QuantizedWeights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 25557032,
      "unquantized": ResNet50_Weights.IMAGENET1K_V1,
      "acc@1": 75.920,

@@ -191,7 +186,6 @@ class ResNet50_QuantizedWeights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 25557032,
      "unquantized": ResNet50_Weights.IMAGENET1K_V2,
      "acc@1": 80.282,

@@ -207,7 +201,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNeXt",
      "num_params": 88791336,
      "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
      "acc@1": 78.986,

@@ -219,7 +212,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNeXt",
      "num_params": 88791336,
      "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2,
      "acc@1": 82.574,
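
Note that the quantized enums keep the functional cross-reference fields (`backend`, `unquantized`) that the builders rely on; only the descriptive strings go away. A hedged sketch of following the `unquantized` link shown in the hunks above (enum member names come from the library and may differ between versions, so the sketch iterates instead of hard-coding one):

# Hedged sketch: every quantized weight still points back at its float
# counterpart via meta["unquantized"], as shown in the hunks above.
from torchvision.models.quantization import ResNet50_QuantizedWeights

for w in ResNet50_QuantizedWeights:
    print(w, "->", w.meta["unquantized"], w.meta["backend"])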
@@ -102,13 +102,9 @@ def _shufflenetv2(
  _COMMON_META = {
-     "task": "image_classification",
-     "architecture": "ShuffleNetV2",
-     "size": (224, 224),
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
      "backend": "fbgemm",
-     "quantization": "Post Training Quantization",
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
  }
  import math
  from collections import OrderedDict
  from functools import partial
- from typing import Any, Callable, List, Optional, Tuple
+ from typing import Any, Callable, Dict, List, Optional, Tuple

  import torch
  from torch import nn, Tensor

@@ -402,17 +402,13 @@ def _regnet(
  return model

- _COMMON_META = {
-     "task": "image_classification",
-     "architecture": "RegNet",
-     "size": (224, 224),
+ _COMMON_META: Dict[str, Any] = {
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
  }

  _COMMON_SWAG_META = {
      **_COMMON_META,
-     "size": (384, 384),
      "recipe": "https://github.com/facebookresearch/SWAG",
      "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
  }
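
The shared `_COMMON_META` dictionaries are merged into each per-weight `meta` with plain `**` unpacking, so later keys override earlier ones; that is what lets an individual weight carry its own `min_size`, `recipe` or accuracy on top of the shared defaults. A small self-contained illustration (values modelled on hunks from this commit; the `min_size` override is illustrative, not a claim about any specific weight):

# Self-contained illustration of the **-unpacking pattern used above.
_IMAGENET_CATEGORIES = ["tench", "goldfish"]  # stand-in for the real 1000-entry list

_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}

meta = {
    **_COMMON_SWAG_META,
    "num_params": 86859496,
    "min_size": (384, 384),  # a per-weight entry can override the shared default
}
assert meta["min_size"] == (384, 384) and meta["license"].endswith("LICENSE")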
@@ -302,8 +302,6 @@ def _resnet(
  _COMMON_META = {
-     "task": "image_classification",
-     "size": (224, 224),
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
  }

@@ -315,7 +313,6 @@ class ResNet18_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 11689512,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
      "acc@1": 69.758,

@@ -331,7 +328,6 @@ class ResNet34_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 21797672,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
      "acc@1": 73.314,

@@ -347,7 +343,6 @@ class ResNet50_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 25557032,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
      "acc@1": 76.130,

@@ -359,7 +354,6 @@ class ResNet50_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 25557032,
      "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
      "acc@1": 80.858,

@@ -375,7 +369,6 @@ class ResNet101_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 44549160,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
      "acc@1": 77.374,

@@ -387,7 +380,6 @@ class ResNet101_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 44549160,
      "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
      "acc@1": 81.886,

@@ -403,7 +395,6 @@ class ResNet152_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 60192808,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
      "acc@1": 78.312,

@@ -415,7 +406,6 @@ class ResNet152_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNet",
      "num_params": 60192808,
      "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
      "acc@1": 82.284,

@@ -431,7 +421,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNeXt",
      "num_params": 25028904,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
      "acc@1": 77.618,

@@ -443,7 +432,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNeXt",
      "num_params": 25028904,
      "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
      "acc@1": 81.198,

@@ -459,7 +447,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "ResNeXt",
      "num_params": 88791336,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
      "acc@1": 79.312,

@@ -471,7 +458,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "ResNeXt",
      "num_params": 88791336,
      "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
      "acc@1": 82.834,

@@ -487,7 +473,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "WideResNet",
      "num_params": 68883240,
      "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
      "acc@1": 78.468,

@@ -499,7 +484,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "WideResNet",
      "num_params": 68883240,
      "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
      "acc@1": 81.602,

@@ -515,7 +499,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224),
  meta={
      **_COMMON_META,
-     "architecture": "WideResNet",
      "num_params": 126886696,
      "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
      "acc@1": 78.848,

@@ -527,7 +510,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
  transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  meta={
      **_COMMON_META,
-     "architecture": "WideResNet",
      "num_params": 126886696,
      "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
      "acc@1": 82.510,
@@ -129,9 +129,8 @@ def _deeplabv3_resnet(
  _COMMON_META = {
-     "task": "image_semantic_segmentation",
-     "architecture": "DeepLabV3",
      "categories": _VOC_CATEGORIES,
+     "min_size": (1, 1),
  }
@@ -48,9 +48,8 @@ class FCNHead(nn.Sequential):
  _COMMON_META = {
-     "task": "image_semantic_segmentation",
-     "architecture": "FCN",
      "categories": _VOC_CATEGORIES,
+     "min_size": (1, 1),
  }
@@ -98,10 +98,9 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
  url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
  transforms=partial(SemanticSegmentation, resize_size=520),
  meta={
-     "task": "image_semantic_segmentation",
-     "architecture": "LRASPP",
      "num_params": 3221538,
      "categories": _VOC_CATEGORIES,
+     "min_size": (1, 1),
      "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
      "mIoU": 57.9,
      "acc": 91.2,
@@ -184,9 +184,6 @@ def _shufflenetv2(
  _COMMON_META = {
-     "task": "image_classification",
-     "architecture": "ShuffleNetV2",
-     "size": (224, 224),
      "min_size": (1, 1),
      "categories": _IMAGENET_CATEGORIES,
      "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
@@ -115,9 +115,6 @@ def _squeezenet(
  _COMMON_META = {
-     "task": "image_classification",
-     "architecture": "SqueezeNet",
-     "size": (224, 224),
      "categories": _IMAGENET_CATEGORIES,
      "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
  }
@@ -107,9 +107,6 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
  _COMMON_META = {
-     "task": "image_classification",
-     "architecture": "VGG",
-     "size": (224, 224),
      "min_size": (32, 32),
      "categories": _IMAGENET_CATEGORIES,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
@@ -309,8 +309,6 @@ def _video_resnet(
  _COMMON_META = {
-     "task": "video_classification",
-     "size": (112, 112),
      "min_size": (1, 1),
      "categories": _KINETICS400_CATEGORIES,
      "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",

@@ -323,7 +321,6 @@ class R3D_18_Weights(WeightsEnum):
  transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  meta={
      **_COMMON_META,
-     "architecture": "R3D",
      "num_params": 33371472,
      "acc@1": 52.75,
      "acc@5": 75.45,

@@ -338,7 +335,6 @@ class MC3_18_Weights(WeightsEnum):
  transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  meta={
      **_COMMON_META,
-     "architecture": "MC3",
      "num_params": 11695440,
      "acc@1": 53.90,
      "acc@5": 76.29,

@@ -353,7 +349,6 @@ class R2Plus1D_18_Weights(WeightsEnum):
  transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
  meta={
      **_COMMON_META,
-     "architecture": "R(2+1)D",
      "num_params": 31505325,
      "acc@1": 57.50,
      "acc@5": 78.81,
  import math
  from collections import OrderedDict
  from functools import partial
- from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Dict
+ from typing import Any, Callable, List, NamedTuple, Optional, Dict

  import torch
  import torch.nn as nn

@@ -288,18 +288,8 @@ def _vision_transformer(
  ) -> VisionTransformer:
      if weights is not None:
          _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
-         if isinstance(weights.meta["size"], int):
-             _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
-         elif isinstance(weights.meta["size"], Sequence):
-             if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
-                 raise ValueError(
-                     f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
-                 )
-             _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
-         else:
-             raise ValueError(
-                 f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
-             )
+         assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+         _ovewrite_named_param(kwargs, "image_size", weights.meta["min_size"][0])
      image_size = kwargs.pop("image_size", 224)

  model = VisionTransformer(
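
The removed branch in the hunk above was the only consumer of the old `size` field; the replacement relies on every ViT weight declaring a square `min_size` equal to its training resolution, so `image_size` can be read straight from it. A hedged usage sketch of the resulting behaviour (builder and enum names as in torchvision at the time; constructing the model downloads the checkpoint):

# Hedged sketch of the behaviour after the simplification above: image_size is
# now derived from the square min_size of the selected weights.
from torchvision.models import vit_b_16, ViT_B_16_Weights

w = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1  # min_size == (384, 384), see the hunks below
model = vit_b_16(weights=w)                  # downloads the checkpoint
assert model.image_size == w.meta["min_size"][0] == 384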
@@ -319,12 +309,10 @@ def _vision_transformer(
  _COMMON_META: Dict[str, Any] = {
-     "task": "image_classification",
-     "architecture": "ViT",
      "categories": _IMAGENET_CATEGORIES,
  }

- _COMMON_SWAG_META: Dict[str, Any] = {
+ _COMMON_SWAG_META = {
      **_COMMON_META,
      "recipe": "https://github.com/facebookresearch/SWAG",
      "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
@@ -338,7 +326,6 @@ class ViT_B_16_Weights(WeightsEnum):
  meta={
      **_COMMON_META,
      "num_params": 86567656,
-     "size": (224, 224),
      "min_size": (224, 224),
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
      "acc@1": 81.072,

@@ -356,7 +343,6 @@ class ViT_B_16_Weights(WeightsEnum):
  meta={
      **_COMMON_SWAG_META,
      "num_params": 86859496,
-     "size": (384, 384),
      "min_size": (384, 384),
      "acc@1": 85.304,
      "acc@5": 97.650,

@@ -374,7 +360,6 @@ class ViT_B_16_Weights(WeightsEnum):
      **_COMMON_SWAG_META,
      "recipe": "https://github.com/pytorch/vision/pull/5793",
      "num_params": 86567656,
-     "size": (224, 224),
      "min_size": (224, 224),
      "acc@1": 81.886,
      "acc@5": 96.180,

@@ -390,7 +375,6 @@ class ViT_B_32_Weights(WeightsEnum):
  meta={
      **_COMMON_META,
      "num_params": 88224232,
-     "size": (224, 224),
      "min_size": (224, 224),
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
      "acc@1": 75.912,

@@ -407,7 +391,6 @@ class ViT_L_16_Weights(WeightsEnum):
  meta={
      **_COMMON_META,
      "num_params": 304326632,
-     "size": (224, 224),
      "min_size": (224, 224),
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
      "acc@1": 79.662,

@@ -425,7 +408,6 @@ class ViT_L_16_Weights(WeightsEnum):
  meta={
      **_COMMON_SWAG_META,
      "num_params": 305174504,
-     "size": (512, 512),
      "min_size": (512, 512),
      "acc@1": 88.064,
      "acc@5": 98.512,

@@ -443,7 +425,6 @@ class ViT_L_16_Weights(WeightsEnum):
      **_COMMON_SWAG_META,
      "recipe": "https://github.com/pytorch/vision/pull/5793",
      "num_params": 304326632,
-     "size": (224, 224),
      "min_size": (224, 224),
      "acc@1": 85.146,
      "acc@5": 97.422,

@@ -459,7 +440,6 @@ class ViT_L_32_Weights(WeightsEnum):
  meta={
      **_COMMON_META,
      "num_params": 306535400,
-     "size": (224, 224),
      "min_size": (224, 224),
      "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
      "acc@1": 76.972,

@@ -481,7 +461,6 @@ class ViT_H_14_Weights(WeightsEnum):
  meta={
      **_COMMON_SWAG_META,
      "num_params": 633470440,
-     "size": (518, 518),
      "min_size": (518, 518),
      "acc@1": 88.552,
      "acc@5": 98.694,

@@ -499,7 +478,6 @@ class ViT_H_14_Weights(WeightsEnum):
      **_COMMON_SWAG_META,
      "recipe": "https://github.com/pytorch/vision/pull/5793",
      "num_params": 632045800,
-     "size": (224, 224),
      "min_size": (224, 224),
      "acc@1": 85.708,
      "acc@5": 97.730,
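
For code that previously read `weights.meta["size"]`: the inference resolution is still recoverable from the bundled preset transforms, and `min_size` stays in `meta`. A hedged sketch; the `crop_size` attribute name follows the `ImageClassification` preset at the time and should be treated as an assumption:

# Hedged sketch: where to find size information now that meta["size"] is gone.
from torchvision.models import ViT_B_16_Weights

w = ViT_B_16_Weights.IMAGENET1K_V1
preset = w.transforms()        # the bundled ImageClassification preset
print(preset.crop_size)        # assumed attribute; [224] for this preset
print(w.meta["min_size"])      # (224, 224) per the diff above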