Unverified Commit b5aa0915 authored by Vasilis Vryniotis, committed by GitHub

Improved meta-data for models (#5170)

* Improved meta-data for models.

* Addressing comments from code-review.

* Add parameter count.

* Fix linter.
parent 5dc61cb0
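The num_params values added throughout this diff are plain parameter counts. Below is a minimal sketch of how such a count can be reproduced, using the stable torchvision.models namespace purely for illustration; the helper name is hypothetical and mirrors the check added to test_schema_meta_validation.

import torchvision.models as models

# Hypothetical helper: count a model's parameters the same way the updated
# test does, i.e. sum(p.numel() for p in model.parameters()).
def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(count_params(models.alexnet()))  # 61100840, the value recorded in AlexNet_Weights below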
@@ -94,10 +94,11 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
 )
+@run_if_test_with_prototype
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5"]
     defaults = {
-        "all": ["interpolation", "recipe"],
+        "all": ["task", "architecture", "publication_year", "interpolation", "recipe", "num_params"],
         "models": classification_fields,
         "detection": ["categories", "map"],
         "quantization": classification_fields + ["backend", "quantization", "unquantized"],
@@ -105,18 +106,35 @@ def test_schema_meta_validation(model_fn):
         "video": classification_fields,
         "optical_flow": [],
     }
+    model_name = model_fn.__name__
     module_name = model_fn.__module__.split(".")[-2]
     fields = set(defaults["all"] + defaults[module_name])
 
     weights_enum = _get_model_weights(model_fn)
+    if len(weights_enum) == 0:
+        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
 
     problematic_weights = {}
+    incorrect_params = []
     for w in weights_enum:
         missing_fields = fields - set(w.meta.keys())
         if missing_fields:
             problematic_weights[w] = missing_fields
+        if w == weights_enum.default:
+            if module_name == "quantization":
+                # parameters() count doesn't work well with quantization, so we check against the non-quantized
+                unquantized_w = w.meta.get("unquantized")
+                if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
+                    incorrect_params.append(w)
+            else:
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
+        else:
+            if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
+                incorrect_params.append(w)
 
     assert not problematic_weights
+    assert not incorrect_params
 
 
 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
...
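With the schema above in place, the new fields can be read straight off a weights enum's meta dictionary. A minimal sketch, assuming the prototype torchvision.prototype.models namespace that this PR targets:

from torchvision.prototype.models import AlexNet_Weights

# Each enum member carries a meta dict; the keys below are the ones this commit adds.
for w in AlexNet_Weights:
    print(w, w.meta["task"], w.meta["architecture"], w.meta["publication_year"], w.meta["num_params"])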
@@ -18,7 +18,9 @@ model_urls = {
 class FCN(_SimpleSegmentationModel):
     """
-    Implements a Fully-Convolutional Network for semantic segmentation.
+    Implements FCN model from
+    `"Fully Convolutional Networks for Semantic Segmentation"
+    <https://arxiv.org/abs/1411.4038>`_.
 
     Args:
         backbone (nn.Module): the network used to compute the features for the model.
...
@@ -18,6 +18,10 @@ class AlexNet_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
+            "task": "image_classification",
+            "architecture": "AlexNet",
+            "publication_year": 2012,
+            "num_params": 61100840,
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -64,6 +64,9 @@ def _densenet(
 
 _COMMON_META = {
+    "task": "image_classification",
+    "architecture": "DenseNet",
+    "publication_year": 2016,
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
@@ -77,6 +80,7 @@ class DenseNet121_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 7978856,
             "acc@1": 74.434,
             "acc@5": 91.972,
         },
@@ -90,6 +94,7 @@ class DenseNet161_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 28681000,
             "acc@1": 77.138,
             "acc@5": 93.560,
         },
@@ -103,6 +108,7 @@ class DenseNet169_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 14149480,
             "acc@1": 75.600,
             "acc@5": 92.806,
         },
@@ -116,6 +122,7 @@ class DenseNet201_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 20013928,
             "acc@1": 76.896,
             "acc@5": 93.370,
         },
...
@@ -31,6 +31,9 @@ __all__ = [
 
 _COMMON_META = {
+    "task": "image_object_detection",
+    "architecture": "FasterRCNN",
+    "publication_year": 2015,
     "categories": _COCO_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
 }
@@ -42,6 +45,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 41755286,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
             "map": 37.0,
         },
@@ -55,6 +59,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 19386354,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
             "map": 32.8,
         },
@@ -68,6 +73,7 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 19386354,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
             "map": 22.8,
         },
...
@@ -24,6 +24,9 @@ __all__ = [
 
 _COMMON_META = {
+    "task": "image_object_detection",
+    "architecture": "KeypointRCNN",
+    "publication_year": 2017,
     "categories": _COCO_PERSON_CATEGORIES,
     "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
     "interpolation": InterpolationMode.BILINEAR,
@@ -36,6 +39,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 59137258,
             "recipe": "https://github.com/pytorch/vision/issues/1606",
             "map": 50.6,
             "map_kp": 61.1,
@@ -46,6 +50,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
         transforms=CocoEval,
         meta={
             **_COMMON_META,
+            "num_params": 59137258,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
             "map": 54.6,
             "map_kp": 65.0,
...
@@ -28,6 +28,10 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "MaskRCNN",
+            "publication_year": 2017,
+            "num_params": 44401393,
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
...
@@ -29,6 +29,10 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "RetinaNet",
+            "publication_year": 2017,
+            "num_params": 34014999,
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
...
@@ -27,6 +27,10 @@ class SSD300_VGG16_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "SSD",
+            "publication_year": 2015,
+            "num_params": 35641826,
             "size": (300, 300),
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -32,6 +32,10 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
         transforms=CocoEval,
         meta={
+            "task": "image_object_detection",
+            "architecture": "SSDLite",
+            "publication_year": 2018,
+            "num_params": 3440060,
             "size": (320, 320),
             "categories": _COCO_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -63,6 +63,9 @@ def _efficientnet(
 
 _COMMON_META = {
+    "task": "image_classification",
+    "architecture": "EfficientNet",
+    "publication_year": 2019,
     "categories": _IMAGENET_CATEGORIES,
     "interpolation": InterpolationMode.BICUBIC,
     "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
@@ -75,6 +78,7 @@ class EfficientNet_B0_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 5288548,
             "size": (224, 224),
             "acc@1": 77.692,
             "acc@5": 93.532,
@@ -89,6 +93,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 7794184,
             "size": (240, 240),
             "acc@1": 78.642,
             "acc@5": 94.186,
@@ -99,6 +104,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR),
         meta={
             **_COMMON_META,
+            "num_params": 7794184,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
             "interpolation": InterpolationMode.BILINEAR,
             "size": (240, 240),
@@ -115,6 +121,7 @@ class EfficientNet_B2_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 9109994,
             "size": (288, 288),
             "acc@1": 80.608,
             "acc@5": 95.310,
@@ -129,6 +136,7 @@ class EfficientNet_B3_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 12233232,
             "size": (300, 300),
             "acc@1": 82.008,
             "acc@5": 96.054,
@@ -143,6 +151,7 @@ class EfficientNet_B4_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 19341616,
             "size": (380, 380),
             "acc@1": 83.384,
             "acc@5": 96.594,
@@ -157,6 +166,7 @@ class EfficientNet_B5_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 30389784,
             "size": (456, 456),
             "acc@1": 83.444,
             "acc@5": 96.628,
@@ -171,6 +181,7 @@ class EfficientNet_B6_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 43040704,
             "size": (528, 528),
             "acc@1": 84.008,
             "acc@5": 96.916,
@@ -185,6 +196,7 @@ class EfficientNet_B7_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
         meta={
             **_COMMON_META,
+            "num_params": 66347960,
             "size": (600, 600),
             "acc@1": 84.122,
             "acc@5": 96.908,
...
@@ -19,6 +19,10 @@ class GoogLeNet_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/googlenet-1378be20.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
+            "task": "image_classification",
+            "architecture": "GoogLeNet",
+            "publication_year": 2014,
+            "num_params": 6624904,
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -18,6 +18,10 @@ class Inception_V3_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
         transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
         meta={
+            "task": "image_classification",
+            "architecture": "InceptionV3",
+            "publication_year": 2015,
+            "num_params": 27161264,
             "size": (299, 299),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -24,6 +24,9 @@ __all__ = [
 
 _COMMON_META = {
+    "task": "image_classification",
+    "architecture": "MNASNet",
+    "publication_year": 2018,
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
@@ -37,6 +40,7 @@ class MNASNet0_5_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 2218512,
             "acc@1": 67.734,
             "acc@5": 87.490,
         },
@@ -55,6 +59,7 @@ class MNASNet1_0_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 4383312,
             "acc@1": 73.456,
             "acc@5": 91.510,
         },
...
@@ -18,6 +18,10 @@ class MobileNet_V2_Weights(WeightsEnum):
         url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
+            "task": "image_classification",
+            "architecture": "MobileNetV2",
+            "publication_year": 2018,
+            "num_params": 3504872,
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -38,6 +38,9 @@ def _mobilenet_v3(
 
 _COMMON_META = {
+    "task": "image_classification",
+    "architecture": "MobileNetV3",
+    "publication_year": 2019,
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
@@ -50,6 +53,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 5483032,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
             "acc@1": 74.042,
             "acc@5": 91.340,
@@ -60,6 +64,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
+            "num_params": 5483032,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
             "acc@1": 75.274,
             "acc@5": 92.566,
@@ -74,6 +79,7 @@ class MobileNet_V3_Small_Weights(WeightsEnum):
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_COMMON_META,
+            "num_params": 2542856,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
             "acc@1": 67.668,
             "acc@5": 87.402,
...
@@ -21,7 +21,12 @@ __all__ = (
 )
 
-_COMMON_META = {"interpolation": InterpolationMode.BILINEAR}
+_COMMON_META = {
+    "task": "optical_flow",
+    "architecture": "RAFT",
+    "publication_year": 2020,
+    "interpolation": InterpolationMode.BILINEAR,
+}
 
 
 class Raft_Large_Weights(WeightsEnum):
@@ -31,6 +36,7 @@ class Raft_Large_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 5257536,
             "recipe": "https://github.com/princeton-vl/RAFT",
             "sintel_train_cleanpass_epe": 1.4411,
             "sintel_train_finalpass_epe": 2.7894,
@@ -45,6 +51,7 @@ class Raft_Large_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 5257536,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
             "sintel_train_cleanpass_epe": 1.3822,
             "sintel_train_finalpass_epe": 2.7161,
@@ -59,6 +66,7 @@ class Raft_Large_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 5257536,
             "recipe": "https://github.com/princeton-vl/RAFT",
             "sintel_test_cleanpass_epe": 1.94,
             "sintel_test_finalpass_epe": 3.18,
@@ -73,6 +81,7 @@ class Raft_Large_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 5257536,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
             "sintel_test_cleanpass_epe": 1.819,
             "sintel_test_finalpass_epe": 3.067,
@@ -85,6 +94,7 @@ class Raft_Large_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 5257536,
             "recipe": "https://github.com/princeton-vl/RAFT",
             "kitti_test_f1-all": 5.10,
         },
@@ -99,6 +109,7 @@ class Raft_Large_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 5257536,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
             "kitti_test_f1-all": 5.19,
         },
@@ -114,6 +125,7 @@ class Raft_Small_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 990162,
             "recipe": "https://github.com/princeton-vl/RAFT",
             "sintel_train_cleanpass_epe": 2.1231,
             "sintel_train_finalpass_epe": 3.2790,
@@ -127,6 +139,7 @@ class Raft_Small_Weights(WeightsEnum):
         transforms=RaftEval,
         meta={
             **_COMMON_META,
+            "num_params": 990162,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
             "sintel_train_cleanpass_epe": 1.9901,
             "sintel_train_finalpass_epe": 3.2831,
...
@@ -28,6 +28,10 @@ class GoogLeNet_QuantizedWeights(WeightsEnum):
         url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
+            "task": "image_classification",
+            "architecture": "GoogLeNet",
+            "publication_year": 2014,
+            "num_params": 6624904,
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -27,6 +27,10 @@ class Inception_V3_QuantizedWeights(WeightsEnum):
         url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
         transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
         meta={
+            "task": "image_classification",
+            "architecture": "InceptionV3",
+            "publication_year": 2015,
+            "num_params": 27161264,
             "size": (299, 299),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...
@@ -28,6 +28,10 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum):
         url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
+            "task": "image_classification",
+            "architecture": "MobileNetV2",
+            "publication_year": 2018,
+            "num_params": 3504872,
             "size": (224, 224),
             "categories": _IMAGENET_CATEGORIES,
             "interpolation": InterpolationMode.BILINEAR,
...