Unverified Commit cca452f0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding schema validation for model meta-data (#5049)

* Adding test that verifies meta schema.

* Adding missing interpolation to keypoint model.

* Change assertion.

* Adding map metric in meta schema for detection and renaming metrics for kp and mask rcnn.
parent 4d00ae0f
...@@ -83,6 +83,38 @@ def test_naming_conventions(model_fn): ...@@ -83,6 +83,38 @@ def test_naming_conventions(model_fn):
assert len(weights_enum) == 0 or hasattr(weights_enum, "default") assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
@pytest.mark.parametrize(
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
)
def test_schema_meta_validation(model_fn):
classification_fields = ["size", "categories", "acc@1", "acc@5"]
defaults = {
"all": ["interpolation", "recipe"],
"models": classification_fields,
"detection": ["categories", "map"],
"quantization": classification_fields + ["backend", "quantization", "unquantized"],
"segmentation": ["categories", "mIoU", "acc"],
"video": classification_fields,
}
module_name = model_fn.__module__.split(".")[-2]
fields = set(defaults["all"] + defaults[module_name])
weights_enum = _get_model_weights(model_fn)
problematic_weights = {}
for w in weights_enum:
missing_fields = fields - set(w.meta.keys())
if missing_fields:
problematic_weights[w] = missing_fields
assert not problematic_weights
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype @run_if_test_with_prototype
......
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import CocoEval
from torchvision.transforms.functional import InterpolationMode
from ....models.detection.keypoint_rcnn import ( from ....models.detection.keypoint_rcnn import (
_resnet_fpn_extractor, _resnet_fpn_extractor,
...@@ -22,7 +23,11 @@ __all__ = [ ...@@ -22,7 +23,11 @@ __all__ = [
] ]
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES} _COMMON_META = {
"categories": _COCO_PERSON_CATEGORIES,
"keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
"interpolation": InterpolationMode.BILINEAR,
}
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
...@@ -32,8 +37,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -32,8 +37,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
meta={ meta={
**_COMMON_META, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/1606", "recipe": "https://github.com/pytorch/vision/issues/1606",
"box_map": 50.6, "map": 50.6,
"kp_map": 61.1, "map_kp": 61.1,
}, },
) )
Coco_V1 = Weights( Coco_V1 = Weights(
...@@ -42,8 +47,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -42,8 +47,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
meta={ meta={
**_COMMON_META, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
"box_map": 54.6, "map": 54.6,
"kp_map": 65.0, "map_kp": 65.0,
}, },
) )
default = Coco_V1 default = Coco_V1
......
...@@ -31,8 +31,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -31,8 +31,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"categories": _COCO_CATEGORIES, "categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR, "interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
"box_map": 37.9, "map": 37.9,
"mask_map": 34.6, "map_mask": 34.6,
}, },
) )
default = Coco_V1 default = Coco_V1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment