Unverified commit 3d8723d5 authored by Vasilis Vryniotis, committed by GitHub

Refactor the `get_weights` API (#5006)

* Change the `default` weights mechanism to use Enum aliases.

* Change `get_weights` to work with full Enum names and make it public.

* Applying improvements from code review.
parent 65cdaeab
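A quick illustration of the two changes described above (a sketch, not part of the diff; the `torchvision.prototype.models` import path is an assumption based on the tests and reference scripts touched below):

```python
from torchvision.prototype import models  # assumed namespace for the new API

# `get_weight` is now public and resolves a weight by its full enum name.
weights = models.get_weight("ResNet50_Weights.ImageNet1K_V1")
assert weights is models.ResNet50_Weights.ImageNet1K_V1

# "default" weights are now plain Enum aliases rather than a boolean flag.
assert models.get_weight("ResNet50_Weights.default") is models.ResNet50_Weights.ImageNet1K_V2
```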
@@ -158,8 +158,7 @@ def load_data(traindir, valdir, args):
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
dataset_test = torchvision.datasets.ImageFolder(
@@ -53,8 +53,7 @@ def get_transform(train, args):
elif not args.weights:
return presets.DetectionPresetEval()
else:
fn = PM.detection.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
return weights.transforms()
@@ -38,8 +38,7 @@ def get_transform(train, args):
elif not args.weights:
return presets.SegmentationPresetEval(base_size=520)
else:
fn = PM.segmentation.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
return weights.transforms()
@@ -160,8 +160,7 @@ def main(args):
if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
else:
fn = PM.video.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
transform_test = weights.transforms()
if args.cache_dataset and os.path.exists(cache_path):
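The four reference-script hunks above all converge on the same pattern; a hedged sketch of it (assuming `PM` is the models namespace those scripts import and that `args.weights` now carries a full enum name):

```python
# Before this change the scripts had to look up the model builder just to
# resolve a weight name, e.g.:
#   fn = PM.detection.__dict__[args.model]
#   weights = PM._api.get_weight(fn, args.weights)
# Now the builder lookup and the private helper are gone; args.weights is
# assumed to hold a full name such as "RetinaNet_ResNet50_FPN_Weights.Coco_V1".
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
```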
@@ -22,7 +22,11 @@ ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
def get_models_from_module(module):
# TODO add a registration mechanism to torchvision.models
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]
@pytest.fixture
@@ -24,6 +24,19 @@ def _get_parent_module(model_fn):
return module
def _get_model_weights(model_fn):
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
try:
return next(
v
for k, v in module.__dict__.items()
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
)
except StopIteration:
return None
def _build_model(fn, **kwargs):
try:
model = fn(**kwargs)
@@ -36,24 +49,22 @@ def _build_model(fn, **kwargs):
@pytest.mark.parametrize(
"model_fn, name, weight",
"name, weight",
[
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
(
models.quantization.resnet50,
"default",
"ResNet50_QuantizedWeights.default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_V1",
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
),
],
)
def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight
def test_get_weight(name, weight):
assert models.get_weight(name) == weight
@pytest.mark.parametrize(
@@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight):
+ TM.get_models_from_module(models.video),
)
def test_naming_conventions(model_fn):
model_name = model_fn.__name__
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
weights_enum = _get_model_weights(model_fn)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@@ -15,3 +15,4 @@ from . import detection
from . import quantization
from . import segmentation
from . import video
from ._api import get_weight
import importlib
import inspect
import sys
from collections import OrderedDict
from dataclasses import dataclass, fields
from enum import Enum
from inspect import signature
from typing import Any, Callable, Dict
from ..._internally_replaced_utils import load_state_dict_from_url
@@ -30,7 +32,6 @@ class Weights:
url: str
transforms: Callable
meta: Dict[str, Any]
default: bool
class WeightsEnum(Enum):
@@ -50,7 +51,7 @@ class WeightsEnum(Enum):
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj)
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
@@ -59,8 +60,8 @@ class WeightsEnum(Enum):
@classmethod
def from_str(cls, value: str) -> "WeightsEnum":
for v in cls:
if v._name_ == value or (value == "default" and v.default):
for k, v in cls.__members__.items():
if k == value:
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
@@ -78,41 +79,35 @@ class WeightsEnum(Enum):
return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum of a specific model builder method and weight name combination.
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
Args:
fn (Callable): The builder method used to create the model.
weight_name (str): The name of the weight enum entry of the specific model.
name (str): The name of the weight enum entry.
Returns:
WeightsEnum: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' parameter.")
try:
enum_name, value_name = name.split(".")
except ValueError:
raise ValueError(f"Invalid weight name provided: '{name}'.")
base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
base_module = importlib.import_module(base_module_name)
model_modules = [base_module] + [
x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
]
ann = signature(fn).parameters["weights"].annotation
weights_enum = None
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
weights_enum = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_enum = t
break
for m in model_modules:
potential_class = m.__dict__.get(enum_name, None)
if potential_class is not None and issubclass(potential_class, WeightsEnum):
weights_enum = potential_class
break
if weights_enum is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
return weights_enum.from_str(weight_name)
return weights_enum.from_str(value_name)
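A hedged usage sketch of the resolution logic above: the full name is split into an enum class and a member, the class is located by scanning the base models module and its sub-modules, and malformed or unknown names raise `ValueError` (again assuming the prototype models namespace the tests import):

```python
import pytest
from torchvision.prototype import models  # assumed namespace

# The enum class may live in a sub-module (quantization, detection, video, ...);
# get_weight scans them all for a matching WeightsEnum subclass.
w = models.get_weight("ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1")
assert w is models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1

# A name without the "EnumClass.member" structure is rejected early ...
with pytest.raises(ValueError):
    models.get_weight("ImageNet1K_FBGEMM_V1")

# ... and so is an enum class that cannot be found in any model module.
with pytest.raises(ValueError):
    models.get_weight("NoSuchWeights.ImageNet1K_V1")
```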
@@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum):
"acc@1": 56.522,
"acc@5": 79.066,
},
default=True,
)
default = ImageNet1K_V1
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
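The `default = ImageNet1K_V1` lines in these weight-enum hunks rely on standard Enum aliasing; a toy example (not torchvision code) of how that behaves:

```python
from enum import Enum

class _DemoWeights(Enum):
    V1 = "v1"
    V2 = "v2"
    default = V2  # same value as V2, so "default" becomes an alias, not a new member

assert _DemoWeights.default is _DemoWeights.V2
assert "default" in _DemoWeights.__members__            # visible to from_str()
assert [m.name for m in _DemoWeights] == ["V1", "V2"]   # aliases are skipped when iterating
```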
@@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum):
"acc@1": 74.434,
"acc@5": 91.972,
},
default=True,
)
default = ImageNet1K_V1
class DenseNet161_Weights(WeightsEnum):
@@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum):
"acc@1": 77.138,
"acc@5": 93.560,
},
default=True,
)
default = ImageNet1K_V1
class DenseNet169_Weights(WeightsEnum):
@@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum):
"acc@1": 75.600,
"acc@5": 92.806,
},
default=True,
)
default = ImageNet1K_V1
class DenseNet201_Weights(WeightsEnum):
@@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum):
"acc@1": 76.896,
"acc@5": 93.370,
},
default=True,
)
default = ImageNet1K_V1
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
@@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
default=True,
)
default = Coco_V1
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
@@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
default=True,
)
default = Coco_V1
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
@@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
default=True,
)
default = Coco_V1
def fasterrcnn_resnet50_fpn(
@@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 50.6,
"kp_map": 61.1,
},
default=False,
)
Coco_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
@@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 54.6,
"kp_map": 65.0,
},
default=True,
)
default = Coco_V1
def keypointrcnn_resnet50_fpn(
@@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 37.9,
"mask_map": 34.6,
},
default=True,
)
default = Coco_V1
def maskrcnn_resnet50_fpn(
@@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
default=True,
)
default = Coco_V1
def retinanet_resnet50_fpn(
@@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1,
},
default=True,
)
default = Coco_V1
def ssd300_vgg16(
@@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3,
},
default=True,
)
default = Coco_V1
def ssdlite320_mobilenet_v3_large(
@@ -79,8 +79,8 @@ class EfficientNet_B0_Weights(WeightsEnum):
"acc@1": 77.692,
"acc@5": 93.532,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B1_Weights(WeightsEnum):
@@ -93,8 +93,8 @@ class EfficientNet_B1_Weights(WeightsEnum):
"acc@1": 78.642,
"acc@5": 94.186,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B2_Weights(WeightsEnum):
@@ -107,8 +107,8 @@ class EfficientNet_B2_Weights(WeightsEnum):
"acc@1": 80.608,
"acc@5": 95.310,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B3_Weights(WeightsEnum):
@@ -121,8 +121,8 @@ class EfficientNet_B3_Weights(WeightsEnum):
"acc@1": 82.008,
"acc@5": 96.054,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B4_Weights(WeightsEnum):
@@ -135,8 +135,8 @@ class EfficientNet_B4_Weights(WeightsEnum):
"acc@1": 83.384,
"acc@5": 96.594,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B5_Weights(WeightsEnum):
@@ -149,8 +149,8 @@ class EfficientNet_B5_Weights(WeightsEnum):
"acc@1": 83.444,
"acc@5": 96.628,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B6_Weights(WeightsEnum):
@@ -163,8 +163,8 @@ class EfficientNet_B6_Weights(WeightsEnum):
"acc@1": 84.008,
"acc@5": 96.916,
},
default=True,
)
default = ImageNet1K_V1
class EfficientNet_B7_Weights(WeightsEnum):
@@ -177,8 +177,8 @@ class EfficientNet_B7_Weights(WeightsEnum):
"acc@1": 84.122,
"acc@5": 96.908,
},
default=True,
)
default = ImageNet1K_V1
def efficientnet_b0(
@@ -26,8 +26,8 @@ class GoogLeNet_Weights(WeightsEnum):
"acc@1": 69.778,
"acc@5": 89.530,
},
default=True,
)
default = ImageNet1K_V1
def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
@@ -25,8 +25,8 @@ class Inception_V3_Weights(WeightsEnum):
"acc@1": 77.294,
"acc@5": 93.450,
},
default=True,
)
default = ImageNet1K_V1
def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
@@ -40,8 +40,8 @@ class MNASNet0_5_Weights(WeightsEnum):
"acc@1": 67.734,
"acc@5": 87.490,
},
default=True,
)
default = ImageNet1K_V1
class MNASNet0_75_Weights(WeightsEnum):
@@ -58,8 +58,8 @@ class MNASNet1_0_Weights(WeightsEnum):
"acc@1": 73.456,
"acc@5": 91.510,
},
default=True,
)
default = ImageNet1K_V1
class MNASNet1_3_Weights(WeightsEnum):