Unverified Commit 3d8723d5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor the `get_weights` API (#5006)

* Change the `default` weights mechanism to sue Enum aliases.

* Change `get_weights` to work with full Enum names and make it public.

* Applying improvements from code review.
parent 65cdaeab
...@@ -158,8 +158,7 @@ def load_data(traindir, valdir, args): ...@@ -158,8 +158,7 @@ def load_data(traindir, valdir, args):
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
) )
else: else:
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model] weights = PM.get_weight(args.weights)
weights = PM._api.get_weight(fn, args.weights)
preprocessing = weights.transforms() preprocessing = weights.transforms()
dataset_test = torchvision.datasets.ImageFolder( dataset_test = torchvision.datasets.ImageFolder(
......
...@@ -53,8 +53,7 @@ def get_transform(train, args): ...@@ -53,8 +53,7 @@ def get_transform(train, args):
elif not args.weights: elif not args.weights:
return presets.DetectionPresetEval() return presets.DetectionPresetEval()
else: else:
fn = PM.detection.__dict__[args.model] weights = PM.get_weight(args.weights)
weights = PM._api.get_weight(fn, args.weights)
return weights.transforms() return weights.transforms()
......
...@@ -38,8 +38,7 @@ def get_transform(train, args): ...@@ -38,8 +38,7 @@ def get_transform(train, args):
elif not args.weights: elif not args.weights:
return presets.SegmentationPresetEval(base_size=520) return presets.SegmentationPresetEval(base_size=520)
else: else:
fn = PM.segmentation.__dict__[args.model] weights = PM.get_weight(args.weights)
weights = PM._api.get_weight(fn, args.weights)
return weights.transforms() return weights.transforms()
......
...@@ -160,8 +160,7 @@ def main(args): ...@@ -160,8 +160,7 @@ def main(args):
if not args.weights: if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
else: else:
fn = PM.video.__dict__[args.model] weights = PM.get_weight(args.weights)
weights = PM._api.get_weight(fn, args.weights)
transform_test = weights.transforms() transform_test = weights.transforms()
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
......
...@@ -22,7 +22,11 @@ ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" ...@@ -22,7 +22,11 @@ ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
def get_models_from_module(module): def get_models_from_module(module):
# TODO add a registration mechanism to torchvision.models # TODO add a registration mechanism to torchvision.models
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]
@pytest.fixture @pytest.fixture
......
...@@ -24,6 +24,19 @@ def _get_parent_module(model_fn): ...@@ -24,6 +24,19 @@ def _get_parent_module(model_fn):
return module return module
def _get_model_weights(model_fn):
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
try:
return next(
v
for k, v in module.__dict__.items()
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
)
except StopIteration:
return None
def _build_model(fn, **kwargs): def _build_model(fn, **kwargs):
try: try:
model = fn(**kwargs) model = fn(**kwargs)
...@@ -36,24 +49,22 @@ def _build_model(fn, **kwargs): ...@@ -36,24 +49,22 @@ def _build_model(fn, **kwargs):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn, name, weight", "name, weight",
[ [
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), ("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2), ("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
( (
models.quantization.resnet50, "ResNet50_QuantizedWeights.default",
"default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2, models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
), ),
( (
models.quantization.resnet50, "ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
"ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1, models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
), ),
], ],
) )
def test_get_weight(model_fn, name, weight): def test_get_weight(name, weight):
assert models._api.get_weight(model_fn, name) == weight assert models.get_weight(name) == weight
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight): ...@@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight):
+ TM.get_models_from_module(models.video), + TM.get_models_from_module(models.video),
) )
def test_naming_conventions(model_fn): def test_naming_conventions(model_fn):
model_name = model_fn.__name__ weights_enum = _get_model_weights(model_fn)
module = _get_parent_module(model_fn) assert weights_enum is not None
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
......
...@@ -15,3 +15,4 @@ from . import detection ...@@ -15,3 +15,4 @@ from . import detection
from . import quantization from . import quantization
from . import segmentation from . import segmentation
from . import video from . import video
from ._api import get_weight
import importlib
import inspect
import sys
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum from enum import Enum
from inspect import signature
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
...@@ -30,7 +32,6 @@ class Weights: ...@@ -30,7 +32,6 @@ class Weights:
url: str url: str
transforms: Callable transforms: Callable
meta: Dict[str, Any] meta: Dict[str, Any]
default: bool
class WeightsEnum(Enum): class WeightsEnum(Enum):
...@@ -50,7 +51,7 @@ class WeightsEnum(Enum): ...@@ -50,7 +51,7 @@ class WeightsEnum(Enum):
def verify(cls, obj: Any) -> Any: def verify(cls, obj: Any) -> Any:
if obj is not None: if obj is not None:
if type(obj) is str: if type(obj) is str:
obj = cls.from_str(obj) obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
elif not isinstance(obj, cls): elif not isinstance(obj, cls):
raise TypeError( raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
...@@ -59,8 +60,8 @@ class WeightsEnum(Enum): ...@@ -59,8 +60,8 @@ class WeightsEnum(Enum):
@classmethod @classmethod
def from_str(cls, value: str) -> "WeightsEnum": def from_str(cls, value: str) -> "WeightsEnum":
for v in cls: for k, v in cls.__members__.items():
if v._name_ == value or (value == "default" and v.default): if k == value:
return v return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
...@@ -78,41 +79,35 @@ class WeightsEnum(Enum): ...@@ -78,41 +79,35 @@ class WeightsEnum(Enum):
return super().__getattr__(name) return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> WeightsEnum: def get_weight(name: str) -> WeightsEnum:
""" """
Gets the weight enum of a specific model builder method and weight name combination. Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
Args: Args:
fn (Callable): The builder method used to create the model. name (str): The name of the weight enum entry.
weight_name (str): The name of the weight enum entry of the specific model.
Returns: Returns:
WeightsEnum: The requested weight enum. WeightsEnum: The requested weight enum.
""" """
sig = signature(fn) try:
if "weights" not in sig.parameters: enum_name, value_name = name.split(".")
raise ValueError("The method is missing the 'weights' parameter.") except ValueError:
raise ValueError(f"Invalid weight name provided: '{name}'.")
base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
base_module = importlib.import_module(base_module_name)
model_modules = [base_module] + [
x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
]
ann = signature(fn).parameters["weights"].annotation
weights_enum = None weights_enum = None
if isinstance(ann, type) and issubclass(ann, WeightsEnum): for m in model_modules:
weights_enum = ann potential_class = m.__dict__.get(enum_name, None)
else: if potential_class is not None and issubclass(potential_class, WeightsEnum):
# handle cases like Union[Optional, T] weights_enum = potential_class
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 break
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_enum = t
break
if weights_enum is None: if weights_enum is None:
raise ValueError( raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
return weights_enum.from_str(weight_name) return weights_enum.from_str(value_name)
...@@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum): ...@@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum):
"acc@1": 56.522, "acc@1": 56.522,
"acc@5": 79.066, "acc@5": 79.066,
}, },
default=True,
) )
default = ImageNet1K_V1
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
......
...@@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum): ...@@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum):
"acc@1": 74.434, "acc@1": 74.434,
"acc@5": 91.972, "acc@5": 91.972,
}, },
default=True,
) )
default = ImageNet1K_V1
class DenseNet161_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum):
...@@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum): ...@@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum):
"acc@1": 77.138, "acc@1": 77.138,
"acc@5": 93.560, "acc@5": 93.560,
}, },
default=True,
) )
default = ImageNet1K_V1
class DenseNet169_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum):
...@@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum): ...@@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum):
"acc@1": 75.600, "acc@1": 75.600,
"acc@5": 92.806, "acc@5": 92.806,
}, },
default=True,
) )
default = ImageNet1K_V1
class DenseNet201_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum):
...@@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum): ...@@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum):
"acc@1": 76.896, "acc@1": 76.896,
"acc@5": 93.370, "acc@5": 93.370,
}, },
default=True,
) )
default = ImageNet1K_V1
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
......
...@@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0, "map": 37.0,
}, },
default=True,
) )
default = Coco_V1
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
...@@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): ...@@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8, "map": 32.8,
}, },
default=True,
) )
default = Coco_V1
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
...@@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): ...@@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8, "map": 22.8,
}, },
default=True,
) )
default = Coco_V1
def fasterrcnn_resnet50_fpn( def fasterrcnn_resnet50_fpn(
......
...@@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 50.6, "box_map": 50.6,
"kp_map": 61.1, "kp_map": 61.1,
}, },
default=False,
) )
Coco_V1 = Weights( Coco_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
...@@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 54.6, "box_map": 54.6,
"kp_map": 65.0, "kp_map": 65.0,
}, },
default=True,
) )
default = Coco_V1
def keypointrcnn_resnet50_fpn( def keypointrcnn_resnet50_fpn(
......
...@@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 37.9, "box_map": 37.9,
"mask_map": 34.6, "mask_map": 34.6,
}, },
default=True,
) )
default = Coco_V1
def maskrcnn_resnet50_fpn( def maskrcnn_resnet50_fpn(
......
...@@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): ...@@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4, "map": 36.4,
}, },
default=True,
) )
default = Coco_V1
def retinanet_resnet50_fpn( def retinanet_resnet50_fpn(
......
...@@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum): ...@@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1, "map": 25.1,
}, },
default=True,
) )
default = Coco_V1
def ssd300_vgg16( def ssd300_vgg16(
......
...@@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3, "map": 21.3,
}, },
default=True,
) )
default = Coco_V1
def ssdlite320_mobilenet_v3_large( def ssdlite320_mobilenet_v3_large(
......
...@@ -79,8 +79,8 @@ class EfficientNet_B0_Weights(WeightsEnum): ...@@ -79,8 +79,8 @@ class EfficientNet_B0_Weights(WeightsEnum):
"acc@1": 77.692, "acc@1": 77.692,
"acc@5": 93.532, "acc@5": 93.532,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B1_Weights(WeightsEnum): class EfficientNet_B1_Weights(WeightsEnum):
...@@ -93,8 +93,8 @@ class EfficientNet_B1_Weights(WeightsEnum): ...@@ -93,8 +93,8 @@ class EfficientNet_B1_Weights(WeightsEnum):
"acc@1": 78.642, "acc@1": 78.642,
"acc@5": 94.186, "acc@5": 94.186,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B2_Weights(WeightsEnum): class EfficientNet_B2_Weights(WeightsEnum):
...@@ -107,8 +107,8 @@ class EfficientNet_B2_Weights(WeightsEnum): ...@@ -107,8 +107,8 @@ class EfficientNet_B2_Weights(WeightsEnum):
"acc@1": 80.608, "acc@1": 80.608,
"acc@5": 95.310, "acc@5": 95.310,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B3_Weights(WeightsEnum): class EfficientNet_B3_Weights(WeightsEnum):
...@@ -121,8 +121,8 @@ class EfficientNet_B3_Weights(WeightsEnum): ...@@ -121,8 +121,8 @@ class EfficientNet_B3_Weights(WeightsEnum):
"acc@1": 82.008, "acc@1": 82.008,
"acc@5": 96.054, "acc@5": 96.054,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B4_Weights(WeightsEnum): class EfficientNet_B4_Weights(WeightsEnum):
...@@ -135,8 +135,8 @@ class EfficientNet_B4_Weights(WeightsEnum): ...@@ -135,8 +135,8 @@ class EfficientNet_B4_Weights(WeightsEnum):
"acc@1": 83.384, "acc@1": 83.384,
"acc@5": 96.594, "acc@5": 96.594,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B5_Weights(WeightsEnum): class EfficientNet_B5_Weights(WeightsEnum):
...@@ -149,8 +149,8 @@ class EfficientNet_B5_Weights(WeightsEnum): ...@@ -149,8 +149,8 @@ class EfficientNet_B5_Weights(WeightsEnum):
"acc@1": 83.444, "acc@1": 83.444,
"acc@5": 96.628, "acc@5": 96.628,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B6_Weights(WeightsEnum): class EfficientNet_B6_Weights(WeightsEnum):
...@@ -163,8 +163,8 @@ class EfficientNet_B6_Weights(WeightsEnum): ...@@ -163,8 +163,8 @@ class EfficientNet_B6_Weights(WeightsEnum):
"acc@1": 84.008, "acc@1": 84.008,
"acc@5": 96.916, "acc@5": 96.916,
}, },
default=True,
) )
default = ImageNet1K_V1
class EfficientNet_B7_Weights(WeightsEnum): class EfficientNet_B7_Weights(WeightsEnum):
...@@ -177,8 +177,8 @@ class EfficientNet_B7_Weights(WeightsEnum): ...@@ -177,8 +177,8 @@ class EfficientNet_B7_Weights(WeightsEnum):
"acc@1": 84.122, "acc@1": 84.122,
"acc@5": 96.908, "acc@5": 96.908,
}, },
default=True,
) )
default = ImageNet1K_V1
def efficientnet_b0( def efficientnet_b0(
......
...@@ -26,8 +26,8 @@ class GoogLeNet_Weights(WeightsEnum): ...@@ -26,8 +26,8 @@ class GoogLeNet_Weights(WeightsEnum):
"acc@1": 69.778, "acc@1": 69.778,
"acc@5": 89.530, "acc@5": 89.530,
}, },
default=True,
) )
default = ImageNet1K_V1
def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
......
...@@ -25,8 +25,8 @@ class Inception_V3_Weights(WeightsEnum): ...@@ -25,8 +25,8 @@ class Inception_V3_Weights(WeightsEnum):
"acc@1": 77.294, "acc@1": 77.294,
"acc@5": 93.450, "acc@5": 93.450,
}, },
default=True,
) )
default = ImageNet1K_V1
def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
......
...@@ -40,8 +40,8 @@ class MNASNet0_5_Weights(WeightsEnum): ...@@ -40,8 +40,8 @@ class MNASNet0_5_Weights(WeightsEnum):
"acc@1": 67.734, "acc@1": 67.734,
"acc@5": 87.490, "acc@5": 87.490,
}, },
default=True,
) )
default = ImageNet1K_V1
class MNASNet0_75_Weights(WeightsEnum): class MNASNet0_75_Weights(WeightsEnum):
...@@ -58,8 +58,8 @@ class MNASNet1_0_Weights(WeightsEnum): ...@@ -58,8 +58,8 @@ class MNASNet1_0_Weights(WeightsEnum):
"acc@1": 73.456, "acc@1": 73.456,
"acc@5": 91.510, "acc@5": 91.510,
}, },
default=True,
) )
default = ImageNet1K_V1
class MNASNet1_3_Weights(WeightsEnum): class MNASNet1_3_Weights(WeightsEnum):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment