Unverified Commit dc113995 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding new ResNet50 weights (#4734)

* Update model checkpoint for resnet50.

* Add get_weight method to retrieve weights from name.

* Update the references to support prototype weights.

* Fixing mypy typing.

* Switching to a python 3.6 supported equivalent.

* Add unit-test.

* Add optional num_classes.
parent 77cc4ef5
...@@ -14,6 +14,12 @@ from torch.utils.data.dataloader import default_collate ...@@ -14,6 +14,12 @@ from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
...@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args): ...@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
print("Loading dataset_test from {}".format(cache_path)) print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path) dataset_test, _ = torch.load(cache_path)
else: else:
if not args.weights:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
fn = PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
preprocessing = weights.transforms()
dataset_test = torchvision.datasets.ImageFolder( dataset_test = torchvision.datasets.ImageFolder(
valdir, valdir,
presets.ClassificationPresetEval( preprocessing,
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
),
) )
if args.cache_dataset: if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path)) print("Saving dataset_test to {}".format(cache_path))
...@@ -206,7 +219,12 @@ def main(args): ...@@ -206,7 +219,12 @@ def main(args):
) )
print("Creating model") print("Creating model")
if not args.weights:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
...@@ -455,6 +473,9 @@ def get_args_parser(add_help=True): ...@@ -455,6 +473,9 @@ def get_args_parser(add_help=True):
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
) )
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser return parser
......
...@@ -12,6 +12,12 @@ def get_available_classification_models(): ...@@ -12,6 +12,12 @@ def get_available_classification_models():
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def test_get_weight():
fn = models.resnet50
weight_name = "ImageNet1K_RefV2"
assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
@pytest.mark.parametrize("model_name", get_available_classification_models()) @pytest.mark.parametrize("model_name", get_available_classification_models())
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
......
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from enum import Enum from enum import Enum
from inspect import signature
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
__all__ = ["Weights", "WeightEntry"] __all__ = ["Weights", "WeightEntry", "get_weight"]
@dataclass @dataclass
...@@ -74,3 +75,38 @@ class Weights(Enum): ...@@ -74,3 +75,38 @@ class Weights(Enum):
if f.name == name: if f.name == name:
return object.__getattribute__(self.value, name) return object.__getattribute__(self.value, name)
return super().__getattr__(name) return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> Weights:
"""
Gets the weight enum of a specific model builder method and weight name combination.
Args:
fn (Callable): The builder method used to create the model.
weight_name (str): The name of the weight enum entry of the specific model.
Returns:
Weights: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.")
ann = signature(fn).parameters["weights"].annotation
weights_class = None
if isinstance(ann, type) and issubclass(ann, Weights):
weights_class = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, Weights):
weights_class = t
break
if weights_class is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is " "correct."
)
return weights_class.from_str(weight_name)
...@@ -92,13 +92,13 @@ class ResNet50Weights(Weights): ...@@ -92,13 +92,13 @@ class ResNet50Weights(Weights):
}, },
) )
ImageNet1K_RefV2 = WeightEntry( ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnet50-tmp.pth", url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={ meta={
**_common_meta, **_common_meta,
"recipe": "https://github.com/pytorch/vision/issues/3995", "recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 80.352, "acc@1": 80.674,
"acc@5": 95.148, "acc@5": 95.166,
}, },
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment