Unverified Commit dc113995 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding new ResNet50 weights (#4734)

* Update model checkpoint for resnet50.

* Add get_weight method to retrieve weights from name.

* Update the references to support prototype weights.

* Fixing mypy typing.

* Switching to a python 3.6 supported equivalent.

* Add unit-test.

* Add optional num_classes.
parent 77cc4ef5
......@@ -14,6 +14,12 @@ from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
......@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
else:
if not args.weights:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
fn = PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
preprocessing = weights.transforms()
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
),
preprocessing,
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
......@@ -206,7 +219,12 @@ def main(args):
)
print("Creating model")
if not args.weights:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device)
if args.distributed and args.sync_bn:
......@@ -455,6 +473,9 @@ def get_args_parser(add_help=True):
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser
......
......@@ -12,6 +12,12 @@ def get_available_classification_models():
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
def test_get_weight():
fn = models.resnet50
weight_name = "ImageNet1K_RefV2"
assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
@pytest.mark.parametrize("model_name", get_available_classification_models())
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
......
from collections import OrderedDict
from dataclasses import dataclass, fields
from enum import Enum
from inspect import signature
from typing import Any, Callable, Dict
from ..._internally_replaced_utils import load_state_dict_from_url
__all__ = ["Weights", "WeightEntry"]
__all__ = ["Weights", "WeightEntry", "get_weight"]
@dataclass
......@@ -74,3 +75,38 @@ class Weights(Enum):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> Weights:
"""
Gets the weight enum of a specific model builder method and weight name combination.
Args:
fn (Callable): The builder method used to create the model.
weight_name (str): The name of the weight enum entry of the specific model.
Returns:
Weights: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.")
ann = signature(fn).parameters["weights"].annotation
weights_class = None
if isinstance(ann, type) and issubclass(ann, Weights):
weights_class = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, Weights):
weights_class = t
break
if weights_class is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is " "correct."
)
return weights_class.from_str(weight_name)
......@@ -92,13 +92,13 @@ class ResNet50Weights(Weights):
},
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnet50-tmp.pth",
transforms=partial(ImageNetEval, crop_size=224),
url="https://download.pytorch.org/models/resnet50-f46c3f97.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 80.352,
"acc@5": 95.148,
"acc@1": 80.674,
"acc@5": 95.166,
},
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment