Unverified commit 11bd2eaa authored by Vasilis Vryniotis, committed by GitHub

Port Multi-weight support from prototype to main (#5618)



* Moving base files outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet.

* Porting googlenet

* Porting inception

* Porting mnasnet

* Porting mobilenetv2

* Porting mobilenetv3

* Porting regnet

* Porting resnet

* Porting shufflenetv2

* Porting squeezenet

* Porting vgg

* Porting vit

* Fix docstrings

* Fixing imports

* Adding missing import

* Fix mobilenet imports

* Fix tests

* Fix prototype tests

* Exclude get_weight from models on test

* Fix init files

* Porting googlenet

* Porting inception

* porting mobilenetv2

* porting mobilenetv3

* porting resnet

* porting shufflenetv2

* Fix test and linter

* Fixing docs.

* Porting Detection models (#5617)

* fix inits

* fix docs

* Port faster_rcnn

* Port fcos

* Port keypoint_rcnn

* Port mask_rcnn

* Port retinanet

* Port ssd

* Port ssdlite

* Fix linter

* Fixing tests

* Fixing tests

* Fixing vgg test

* Porting Optical Flow, Segmentation, Video models (#5619)

* Porting raft

* Porting video resnet

* Porting deeplabv3

* Porting fcn and lraspp

* Fixing the tests and linter

* Porting docs, examples, tutorials and galleries (#5620)

* Fix examples, tutorials and gallery

* Update gallery/plot_optical_flow.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Fix import

* Revert hardcoded normalization

* fix uncommitted changes

* Fix bug

* Fix more bugs

* Making resize optional for segmentation

* Fixing preset

* Fix mypy

* Fixing documentation strings

* Fix flake8

* minor refactoring
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Resolve conflict

* Porting model tests (#5622)

* Porting tests

* Remove unnecessary variable

* Fix linter

* Move prototype to extended tests

* Fix download models job

* Update CI on Multiweight branch to use the new weight download approach (#5628)

* port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal

* Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624)
Co-authored-by: Anton Thomma <anton@pri.co.nz>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* pre-download model weights in CI docs build (#5625)

* pre-download model weights in CI docs build

* move changes into template

* change docs image

* Regenerated config.yml
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>

* Porting reference scripts and updating presets (#5629)

* Making _preset.py classes

* Remove support of targets on presets.

* Rewriting the video preset

* Adding tests to check that the bundled transforms are JIT scriptable

* Rename all presets from *Eval to *Inference

* Minor refactoring

* Remove --prototype and --pretrained from reference scripts

* Remove pretrained_backbone refs

* Corrections and simplifications

* Fixing bug

* Fixing linter

* Fix flake8

* restore documentation example

* minor fixes

* fix optical flow missing param

* Fixing commands

* Adding weights_backbone support in detection and segmentation

* Updating the commands for InceptionV3

* Setting `weights_backbone` to its fully BC value (#5653)

* Replace default `weights_backbone=None` with its BC values.

* Fixing tests

* Fix linter

* Update docs.

* Update preprocessing on reference scripts.

* Change qat/ptq to their full values.

* Refactoring preprocessing

* Fix video preset

* No initialization on VGG if pretrained

* Fix warning messages for backbone utils.

* Adding star to all preset constructors.

* Fix mypy.
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarAnton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: default avatarAnton Thomma <anton@pri.co.nz>
parent 375e4ab2
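For context before the file diffs: the multi-weight API replaces the old boolean `pretrained` flag with per-model weight enums that bundle the checkpoint, its metadata, and the matching inference transforms. A minimal sketch of the new call pattern (ResNet-50 is just an illustrative choice; every ported model follows the same shape):

```python
from torchvision.models import resnet50, ResNet50_Weights

# Old interface, deprecated by this PR: resnet50(pretrained=True)
weights = ResNet50_Weights.DEFAULT  # resolves to the best available checkpoint
model = resnet50(weights=weights).eval()

preprocess = weights.transforms()        # inference transforms matching the checkpoint
categories = weights.meta["categories"]  # metadata travels with the weights
```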
@@ -14,30 +14,30 @@ You must modify the following flags:

 ## fcn_resnet50
 ```
-torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss
+torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```

 ## fcn_resnet101
 ```
-torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss
+torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
 ```

 ## deeplabv3_resnet50
 ```
-torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss
+torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```

 ## deeplabv3_resnet101
 ```
-torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
+torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1
 ```

 ## deeplabv3_mobilenet_v3_large
 ```
-torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
+torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
 ```

 ## lraspp_mobilenet_v3_large
 ```
-torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
+torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
 ```
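Roughly, the `--weights-backbone` flag above maps to the `weights_backbone` parameter of the segmentation builders. A sketch of the equivalent Python call (the `num_classes=21` value mirrors the COCO-subset setup used by these reference scripts, but is an assumption):

```python
from torchvision.models import ResNet50_Weights
from torchvision.models.segmentation import fcn_resnet50

# Train from scratch, but initialize the backbone from ImageNet weights,
# mirroring `--weights-backbone ResNet50_Weights.IMAGENET1K_V1` above.
model = fcn_resnet50(weights=None, weights_backbone=ResNet50_Weights.IMAGENET1K_V1, num_classes=21)
```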
@@ -3,7 +3,7 @@ import transforms as T

 class SegmentationPresetTrain:
-    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
         min_size = int(0.5 * base_size)
         max_size = int(2.0 * base_size)
@@ -25,7 +25,7 @@ class SegmentationPresetTrain:

 class SegmentationPresetEval:
-    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
         self.transforms = T.Compose(
             [
                 T.RandomResize(base_size, base_size),
@@ -9,12 +9,7 @@ import torchvision
 import utils
 from coco_utils import get_coco
 from torch import nn
+from torchvision.transforms import functional as F, InterpolationMode

-try:
-    from torchvision import prototype
-except ImportError:
-    prototype = None


 def get_dataset(dir_path, name, image_set, transform):
@@ -35,14 +30,19 @@ def get_dataset(dir_path, name, image_set, transform):

 def get_transform(train, args):
     if train:
         return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
-    elif not args.prototype:
-        return presets.SegmentationPresetEval(base_size=520)
+    elif args.weights and args.test_only:
+        weights = torchvision.models.get_weight(args.weights)
+        trans = weights.transforms()
+
+        def preprocessing(img, target):
+            img = trans(img)
+            size = F.get_dimensions(img)[1:]
+            target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
+            return img, F.pil_to_tensor(target)
+
+        return preprocessing
     else:
-        if args.weights:
-            weights = prototype.models.get_weight(args.weights)
-            return weights.transforms()
-        else:
-            return prototype.transforms.SemanticSegmentationEval(resize_size=520)
+        return presets.SegmentationPresetEval(base_size=520)


 def criterion(inputs, target):
@@ -100,10 +100,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi

 def main(args):
-    if args.prototype and prototype is None:
-        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if not args.prototype and args.weights:
-        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -135,16 +131,9 @@ def main(args):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
     )

-    if not args.prototype:
-        model = torchvision.models.segmentation.__dict__[args.model](
-            pretrained=args.pretrained,
-            num_classes=num_classes,
-            aux_loss=args.aux_loss,
-        )
-    else:
-        model = prototype.models.segmentation.__dict__[args.model](
-            weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
-        )
+    model = torchvision.models.segmentation.__dict__[args.model](
+        weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss
+    )
     model.to(device)
     if args.distributed:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -272,24 +261,12 @@ def get_args_parser(add_help=True):
         help="Only test the model",
         action="store_true",
     )
-    parser.add_argument(
-        "--pretrained",
-        dest="pretrained",
-        help="Use pre-trained models from the modelzoo",
-        action="store_true",
-    )

     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
-    # Prototype models only
-    parser.add_argument(
-        "--prototype",
-        dest="prototype",
-        help="Use prototype model builders instead those from main area",
-        action="store_true",
-    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")

     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
@@ -6,8 +6,9 @@ from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW

 class VideoClassificationPresetTrain:
     def __init__(
         self,
-        resize_size,
+        *,
         crop_size,
+        resize_size,
         mean=(0.43216, 0.394666, 0.37645),
         std=(0.22803, 0.22145, 0.216989),
         hflip_prob=0.5,
@@ -27,7 +28,7 @@ class VideoClassificationPresetTrain:

 class VideoClassificationPresetEval:
-    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
+    def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
         self.transforms = transforms.Compose(
             [
                 ConvertBHWCtoBCHW(),
@@ -12,11 +12,6 @@ from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler

-try:
-    from torchvision import prototype
-except ImportError:
-    prototype = None
-

 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
@@ -96,17 +91,11 @@ def collate_fn(batch):

 def main(args):
-    if args.prototype and prototype is None:
-        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if not args.prototype and args.weights:
-        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     if args.output_dir:
         utils.mkdir(args.output_dir)

     utils.init_distributed_mode(args)
     print(args)
-    print("torch version: ", torch.__version__)
-    print("torchvision version: ", torchvision.__version__)

     device = torch.device(args.device)
@@ -120,7 +109,7 @@ def main(args):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
+    transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171))

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_train from {cache_path}")
@@ -150,14 +139,11 @@ def main(args):

     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if not args.prototype:
-        transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112))
+    if args.weights and args.test_only:
+        weights = torchvision.models.get_weight(args.weights)
+        transform_test = weights.transforms()
     else:
-        if args.weights:
-            weights = prototype.models.get_weight(args.weights)
-            transform_test = weights.transforms()
-        else:
-            transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))
+        transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171))

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
@@ -208,10 +194,7 @@ def main(args):
     )

     print("Creating model")
-    if not args.prototype:
-        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
-    else:
-        model = prototype.models.video.__dict__[args.model](weights=args.weights)
+    model = torchvision.models.video.__dict__[args.model](weights=args.weights)
     model.to(device)
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -352,24 +335,11 @@ def parse_args():
         help="Only test the model",
         action="store_true",
     )
-    parser.add_argument(
-        "--pretrained",
-        dest="pretrained",
-        help="Use pre-trained models from the modelzoo",
-        action="store_true",
-    )

     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
-    # Prototype models only
-    parser.add_argument(
-        "--prototype",
-        dest="prototype",
-        help="Use prototype model builders instead those from main area",
-        action="store_true",
-    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

     # Mixed precision training parameters
@@ -13,36 +13,40 @@ from torchvision.models.feature_extraction import create_feature_extractor, get_

 def get_available_models():
     # TODO add a registration mechanism to torchvision.models
-    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
+    return [
+        k
+        for k, v in models.__dict__.items()
+        if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
+    ]


 @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
 def test_resnet_fpn_backbone(backbone_name):
     x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
-    model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)
+    model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
     assert isinstance(model, BackboneWithFPN)
     y = model(x)
     assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

     with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
-        resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6)
+        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
     with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
-        resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3])
+        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
     with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
-        resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5])
+        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])


 @pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
 def test_mobilenet_backbone(backbone_name):
     with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
-        mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1)
+        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
     with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
-        mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
+        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
     with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
-        mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
+        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])

-    model_fpn = mobilenet_backbone(backbone_name, False, fpn=True)
+    model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
     assert isinstance(model_fpn, BackboneWithFPN)

-    model = mobilenet_backbone(backbone_name, False, fpn=False)
+    model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
     assert isinstance(model, torch.nn.Sequential)
@@ -96,7 +100,7 @@ test_module_nodes = [
 class TestFxFeatureExtraction:
     inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
-    model_defaults = {"num_classes": 1, "pretrained": False}
+    model_defaults = {"num_classes": 1}
     leaf_modules = []

     def _create_feature_extractor(self, *args, **kwargs):
@@ -53,50 +53,49 @@ def read_image2():
     "see https://github.com/pytorch/vision/issues/1191",
 )
 class Tester(unittest.TestCase):
-    pretrained = False
     image = read_image1()

     def test_alexnet(self):
-        process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet")
+        process_model(models.alexnet(), self.image, _C_tests.forward_alexnet, "Alexnet")

     def test_vgg11(self):
-        process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11")
+        process_model(models.vgg11(), self.image, _C_tests.forward_vgg11, "VGG11")

     def test_vgg13(self):
-        process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13")
+        process_model(models.vgg13(), self.image, _C_tests.forward_vgg13, "VGG13")

     def test_vgg16(self):
-        process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16")
+        process_model(models.vgg16(), self.image, _C_tests.forward_vgg16, "VGG16")

     def test_vgg19(self):
-        process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19")
+        process_model(models.vgg19(), self.image, _C_tests.forward_vgg19, "VGG19")

     def test_vgg11_bn(self):
-        process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN")
+        process_model(models.vgg11_bn(), self.image, _C_tests.forward_vgg11bn, "VGG11BN")

     def test_vgg13_bn(self):
-        process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN")
+        process_model(models.vgg13_bn(), self.image, _C_tests.forward_vgg13bn, "VGG13BN")

     def test_vgg16_bn(self):
-        process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN")
+        process_model(models.vgg16_bn(), self.image, _C_tests.forward_vgg16bn, "VGG16BN")

     def test_vgg19_bn(self):
-        process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN")
+        process_model(models.vgg19_bn(), self.image, _C_tests.forward_vgg19bn, "VGG19BN")

     def test_resnet18(self):
-        process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18")
+        process_model(models.resnet18(), self.image, _C_tests.forward_resnet18, "Resnet18")

     def test_resnet34(self):
-        process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34")
+        process_model(models.resnet34(), self.image, _C_tests.forward_resnet34, "Resnet34")

     def test_resnet50(self):
-        process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50")
+        process_model(models.resnet50(), self.image, _C_tests.forward_resnet50, "Resnet50")

     def test_resnet101(self):
-        process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101")
+        process_model(models.resnet101(), self.image, _C_tests.forward_resnet101, "Resnet101")

     def test_resnet152(self):
-        process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152")
+        process_model(models.resnet152(), self.image, _C_tests.forward_resnet152, "Resnet152")

     def test_resnext50_32x4d(self):
         process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d")
@@ -111,48 +110,44 @@ class Tester(unittest.TestCase):
         process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2")

     def test_squeezenet1_0(self):
-        process_model(
-            models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0"
-        )
+        process_model(models.squeezenet1_0(), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0")

     def test_squeezenet1_1(self):
-        process_model(
-            models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1"
-        )
+        process_model(models.squeezenet1_1(), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1")

     def test_densenet121(self):
-        process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121")
+        process_model(models.densenet121(), self.image, _C_tests.forward_densenet121, "Densenet121")

     def test_densenet169(self):
-        process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169")
+        process_model(models.densenet169(), self.image, _C_tests.forward_densenet169, "Densenet169")

     def test_densenet201(self):
-        process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201")
+        process_model(models.densenet201(), self.image, _C_tests.forward_densenet201, "Densenet201")

     def test_densenet161(self):
-        process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161")
+        process_model(models.densenet161(), self.image, _C_tests.forward_densenet161, "Densenet161")

     def test_mobilenet_v2(self):
-        process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet")
+        process_model(models.mobilenet_v2(), self.image, _C_tests.forward_mobilenetv2, "MobileNet")

     def test_googlenet(self):
-        process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet")
+        process_model(models.googlenet(), self.image, _C_tests.forward_googlenet, "GoogLeNet")

     def test_mnasnet0_5(self):
-        process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")
+        process_model(models.mnasnet0_5(), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")

     def test_mnasnet0_75(self):
-        process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")
+        process_model(models.mnasnet0_75(), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")

     def test_mnasnet1_0(self):
-        process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")
+        process_model(models.mnasnet1_0(), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")

     def test_mnasnet1_3(self):
-        process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")
+        process_model(models.mnasnet1_3(), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")

     def test_inception_v3(self):
         self.image = read_image2()
-        process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3")
+        process_model(models.inception_v3(), self.image, _C_tests.forward_inceptionv3, "Inceptionv3")


 if __name__ == "__main__":
@@ -4,21 +4,15 @@ import os

 import pytest
 import test_models as TM
 import torch
-from common_utils import cpu_and_gpu, needs_cuda
-from torchvision.prototype import models
-from torchvision.prototype.models._api import WeightsEnum, Weights
-from torchvision.prototype.models._utils import handle_legacy_interface
+from torchvision import models
+from torchvision.models._api import WeightsEnum, Weights
+from torchvision.models._utils import handle_legacy_interface

-run_if_test_with_prototype = pytest.mark.skipif(
-    os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
-    reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
-)
-
-
-def _get_original_model(model_fn):
-    original_module_name = model_fn.__module__.replace(".prototype", "")
-    module = importlib.import_module(original_module_name)
-    return module.__dict__[model_fn.__name__]
+run_if_test_with_extended = pytest.mark.skipif(
+    os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
+    reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
+)


 def _get_parent_module(model_fn):
@@ -40,17 +34,6 @@ def _get_model_weights(model_fn):
     return None


-def _build_model(fn, **kwargs):
-    try:
-        model = fn(**kwargs)
-    except ValueError as e:
-        msg = str(e)
-        if "No checkpoint is available" in msg:
-            pytest.skip(msg)
-        raise e
-    return model.eval()
-
-
 @pytest.mark.parametrize(
     "name, weight",
     [
@@ -95,7 +78,7 @@ def test_naming_conventions(model_fn):
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
 )
-@run_if_test_with_prototype
+@run_if_test_with_extended
 def test_schema_meta_validation(model_fn):
     classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"]
     defaults = {
@@ -142,48 +125,6 @@ def test_schema_meta_validation(model_fn):
     assert not bad_names


-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_classification_model(model_fn, dev):
-    TM.test_classification_model(model_fn, dev)
-
-
-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_detection_model(model_fn, dev):
-    TM.test_detection_model(model_fn, dev)
-
-
-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
-@run_if_test_with_prototype
-def test_quantized_classification_model(model_fn):
-    TM.test_quantized_classification_model(model_fn)
-
-
-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_segmentation_model(model_fn, dev):
-    TM.test_segmentation_model(model_fn, dev)
-
-
-@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video))
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_video_model(model_fn, dev):
-    TM.test_video_model(model_fn, dev)
-
-
-@needs_cuda
-@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
-@pytest.mark.parametrize("scripted", (False, True))
-@run_if_test_with_prototype
-def test_raft(model_builder, scripted):
-    TM.test_raft(model_builder, scripted)
-
-
 @pytest.mark.parametrize(
     "model_fn",
     TM.get_models_from_module(models)
@@ -193,9 +134,13 @@ def test_raft(model_builder, scripted):
     + TM.get_models_from_module(models.video)
     + TM.get_models_from_module(models.optical_flow),
 )
-@pytest.mark.parametrize("dev", cpu_and_gpu())
-@run_if_test_with_prototype
-def test_old_vs_new_factory(model_fn, dev):
+@run_if_test_with_extended
+def test_transforms_jit(model_fn):
+    model_name = model_fn.__name__
+    weights_enum = _get_model_weights(model_fn)
+    if len(weights_enum) == 0:
+        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
+
     defaults = {
         "models": {
             "input_shape": (1, 3, 224, 224),
@@ -205,43 +150,36 @@ def test_old_vs_new_factory(model_fn, dev):
         },
         "quantization": {
             "input_shape": (1, 3, 224, 224),
-            "quantize": True,
         },
         "segmentation": {
             "input_shape": (1, 3, 520, 520),
         },
         "video": {
-            "input_shape": (1, 3, 4, 112, 112),
+            "input_shape": (1, 4, 112, 112, 3),
         },
         "optical_flow": {
            "input_shape": (1, 3, 128, 128),
         },
     }
-    model_name = model_fn.__name__
     module_name = model_fn.__module__.split(".")[-2]
-    kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
-    input_shape = kwargs.pop("input_shape")
-    kwargs.pop("num_classes", None)  # ignore this as it's an incompatible speed optimization for pre-trained models
-    x = torch.rand(input_shape).to(device=dev)
-    if module_name == "detection":
-        x = [x]
+
+    kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
+    input_shape = kwargs.pop("input_shape")
+    x = torch.rand(input_shape)
     if module_name == "optical_flow":
-        args = [x, x]  # RAFT model requires img1, img2 as input
+        args = (x, x)
     else:
-        args = [x]
+        args = (x,)

-    # compare with new model builder parameterized in the old fashion way
-    try:
-        model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
-        model_new = _build_model(model_fn, **kwargs).to(device=dev)
-    except ModuleNotFoundError:
-        pytest.skip(f"Model '{model_name}' not available in both modules.")
-    torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False)
-
-
-def test_smoke():
-    import torchvision.prototype.models  # noqa: F401
+    problematic_weights = []
+    for w in weights_enum:
+        transforms = w.transforms()
+        try:
+            TM._check_jit_scriptable(transforms, args)
+        except Exception:
+            problematic_weights.append(w)
+
+    assert not problematic_weights


 # With this filter, every unexpected warning will be turned into an error
@@ -26,13 +26,13 @@ class TestHub:
     # Python cache as we run all hub tests in the same python process.

     def test_load_from_github(self):
-        hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
+        hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
         assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)

     def test_set_dir(self):
         temp_dir = tempfile.gettempdir()
         hub.set_dir(temp_dir)
-        hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
+        hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
         assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
         assert os.path.exists(temp_dir + "/pytorch_vision_master")
         shutil.rmtree(temp_dir + "/pytorch_vision_master")
@@ -133,8 +133,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):

     if eager_out is None:
         with torch.no_grad(), freeze_rng_state():
-            if unwrapper:
-                eager_out = nn_module(*args)
+            eager_out = nn_module(*args)

     with torch.no_grad(), freeze_rng_state():
         script_out = sm(*args)
@@ -414,7 +413,6 @@ def test_mobilenet_norm_layer(model_fn):

 def test_inception_v3_eval():
-    # replacement for models.inception_v3(pretrained=True) that does not download weights
     kwargs = {}
     kwargs["transform_input"] = True
     kwargs["aux_logits"] = True
@@ -430,7 +428,7 @@ def test_inception_v3_eval():

 def test_fasterrcnn_double():
-    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
+    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
     model.double()
     model.eval()
     input_shape = (3, 300, 300)
@@ -446,7 +444,6 @@ def test_fasterrcnn_double():

 def test_googlenet_eval():
-    # replacement for models.googlenet(pretrained=True) that does not download weights
     kwargs = {}
     kwargs["transform_input"] = True
     kwargs["aux_logits"] = True
@@ -470,7 +467,7 @@ def test_fasterrcnn_switch_devices():
     assert "scores" in out[0]
     assert "labels" in out[0]

-    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
+    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
     model.cuda()
     model.eval()
     input_shape = (3, 300, 300)
@@ -586,7 +583,7 @@ def test_segmentation_model(model_fn, dev):
     set_rng_seed(0)
     defaults = {
         "num_classes": 10,
-        "pretrained_backbone": False,
+        "weights_backbone": None,
         "input_shape": (1, 3, 32, 32),
     }
     model_name = model_fn.__name__
@@ -648,7 +645,7 @@ def test_detection_model(model_fn, dev):
     set_rng_seed(0)
     defaults = {
         "num_classes": 50,
-        "pretrained_backbone": False,
+        "weights_backbone": None,
         "input_shape": (3, 300, 300),
     }
     model_name = model_fn.__name__
@@ -743,7 +740,7 @@ def test_detection_model(model_fn, dev):
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
 def test_detection_model_validation(model_fn):
     set_rng_seed(0)
-    model = model_fn(num_classes=50, pretrained_backbone=False)
+    model = model_fn(num_classes=50, weights=None, weights_backbone=None)
     input_shape = (3, 300, 300)
     x = [torch.rand(input_shape)]
@@ -807,7 +804,6 @@ def test_quantized_classification_model(model_fn):
     defaults = {
         "num_classes": 5,
         "input_shape": (1, 3, 224, 224),
-        "pretrained": False,
         "quantize": True,
     }
     model_name = model_fn.__name__
@@ -857,7 +853,7 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
     max_trainable = _model_tests_values[model_name]["max_trainable"]
     n_trainable_params = []
     for trainable_layers in range(0, max_trainable + 1):
-        model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers)
+        model = model_fn(weights=None, weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers)

         n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
     assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
@@ -100,7 +100,7 @@ class TestModelsDetectionNegativeSamples:
     )
     def test_forward_negative_sample_frcnn(self, name):
         model = torchvision.models.detection.__dict__[name](
-            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
         )

         images, targets = self._make_empty_sample()
@@ -111,7 +111,7 @@ class TestModelsDetectionNegativeSamples:

     def test_forward_negative_sample_mrcnn(self):
         model = torchvision.models.detection.maskrcnn_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
         )

         images, targets = self._make_empty_sample(add_masks=True)
@@ -123,7 +123,7 @@ class TestModelsDetectionNegativeSamples:

     def test_forward_negative_sample_krcnn(self):
         model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
         )

         images, targets = self._make_empty_sample(add_keypoints=True)
@@ -135,7 +135,7 @@ class TestModelsDetectionNegativeSamples:

     def test_forward_negative_sample_retinanet(self):
         model = torchvision.models.detection.retinanet_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
         )

         images, targets = self._make_empty_sample()
@@ -145,7 +145,7 @@ class TestModelsDetectionNegativeSamples:

     def test_forward_negative_sample_fcos(self):
         model = torchvision.models.detection.fcos_resnet50_fpn(
-            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+            weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100
         )

         images, targets = self._make_empty_sample()
@@ -155,7 +155,7 @@ class TestModelsDetectionNegativeSamples:
         assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0))

     def test_forward_negative_sample_ssd(self):
-        model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False)
+        model = torchvision.models.detection.ssd300_vgg16(weights=None, weights_backbone=None, num_classes=2)

         images, targets = self._make_empty_sample()
         loss_dict = model(images, targets)
@@ -40,7 +40,7 @@ class TestModelsDetectionUtils:
         # be frozen for each trainable_backbone_layers parameter value
         # i.e all 53 params are frozen if trainable_backbone_layers=0
         # ad first 24 params are frozen if trainable_backbone_layers=2
-        model = backbone_utils.resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=train_layers)
+        model = backbone_utils.resnet_fpn_backbone("resnet50", weights=None, trainable_layers=train_layers)
         # boolean list that is true if the param at that index is frozen
         is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()]
         # check that expected initial number of layers are frozen
@@ -49,18 +49,18 @@ class TestModelsDetectionUtils:
     def test_validate_resnet_inputs_detection(self):
         # default number of backbone layers to train
         ret = backbone_utils._validate_trainable_layers(
-            pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3
+            is_trained=True, trainable_backbone_layers=None, max_value=5, default_value=3
         )
         assert ret == 3
         # can't go beyond 5
         with pytest.raises(ValueError, match=r"Trainable backbone layers should be in the range"):
             ret = backbone_utils._validate_trainable_layers(
-                pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3
+                is_trained=True, trainable_backbone_layers=6, max_value=5, default_value=3
             )
-        # if not pretrained, should use all trainable layers and warn
+        # if not trained, should use all trainable layers and warn
         with pytest.warns(UserWarning):
             ret = backbone_utils._validate_trainable_layers(
-                pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3
+                is_trained=False, trainable_backbone_layers=0, max_value=5, default_value=3
             )
         assert ret == 5
@@ -430,7 +430,9 @@ class TestONNXExporter:
     def test_faster_rcnn(self):
         images, test_images = self.get_test_images()
         dummy_image = [torch.ones(3, 100, 100) * 0.3]
-        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
+        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
+            weights=models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
+        )
         model.eval()
         model(images)
         # Test exported model on images of different size, or dummy input
@@ -486,7 +488,9 @@ class TestONNXExporter:
     def test_mask_rcnn(self):
         images, test_images = self.get_test_images()
         dummy_image = [torch.ones(3, 100, 100) * 0.3]
-        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
+        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(
+            weights=models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
+        )
         model.eval()
         model(images)
         # Test exported model on images of different size, or dummy input
@@ -548,7 +552,9 @@ class TestONNXExporter:
     def test_keypoint_rcnn(self):
         images, test_images = self.get_test_images()
         dummy_images = [torch.ones(3, 100, 100) * 0.3]
-        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
+        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
+            weights=models.detection.keypoint_rcnn.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
+        )
         model.eval()
         model(images)
         self.run_model(
@@ -570,7 +576,7 @@ class TestONNXExporter:
     )

     def test_shufflenet_v2_dynamic_axes(self):
-        model = models.shufflenet_v2_x0_5(pretrained=True)
+        model = models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT)
         dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
         test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)
@@ -6,7 +6,7 @@ import torchvision
 HERE = osp.dirname(osp.abspath(__file__))
 ASSETS = osp.dirname(osp.dirname(HERE))

-model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
+model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
 model.eval()

 traced_model = torch.jit.script(model)
 from .alexnet import *
 from .convnext import *
-from .resnet import *
-from .vgg import *
-from .squeezenet import *
-from .inception import *
 from .densenet import *
+from .efficientnet import *
 from .googlenet import *
-from .mobilenet import *
+from .inception import *
 from .mnasnet import *
-from .shufflenetv2 import *
-from .efficientnet import *
+from .mobilenet import *
 from .regnet import *
+from .resnet import *
+from .shufflenetv2 import *
+from .squeezenet import *
+from .vgg import *
 from .vision_transformer import *
 from . import detection
-from . import feature_extraction
 from . import optical_flow
 from . import quantization
 from . import segmentation
 from . import video
+from ._api import get_weight
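The newly exported `get_weight` resolves a weights enum entry from its fully qualified string name; this is what lets the reference scripts accept flags like `--weights ResNet50_Weights.IMAGENET1K_V1`. A small sketch:

```python
from torchvision.models import get_weight

weights = get_weight("ResNet50_Weights.IMAGENET1K_V1")  # string -> enum entry
preprocess = weights.transforms()  # matching inference preprocessing
print(weights.meta["acc@1"])       # bundled metadata, e.g. ImageNet accuracy
```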
@@ -3,11 +3,12 @@ import inspect
 import sys
 from collections import OrderedDict
 from dataclasses import dataclass, fields
-from typing import Any, Callable, Dict
+from inspect import signature
+from typing import Any, Callable, Dict, cast

 from torchvision._utils import StrEnum

-from ..._internally_replaced_utils import load_state_dict_from_url
+from .._internally_replaced_utils import load_state_dict_from_url

 __all__ = ["WeightsEnum", "Weights", "get_weight"]
@@ -105,3 +106,38 @@ def get_weight(name: str) -> WeightsEnum:
         raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

     return weights_enum.from_str(value_name)
+
+
+def get_enum_from_fn(fn: Callable) -> WeightsEnum:
+    """
+    Internal method that gets the weight enum of a specific model builder method.
+    Might be removed after the handle_legacy_interface is removed.
+
+    Args:
+        fn (Callable): The builder method used to create the model.
+        weight_name (str): The name of the weight enum entry of the specific model.
+
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    sig = signature(fn)
+    if "weights" not in sig.parameters:
+        raise ValueError("The method is missing the 'weights' argument.")
+
+    ann = signature(fn).parameters["weights"].annotation
+    weights_enum = None
+    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
+        weights_enum = ann
+    else:
+        # handle cases like Union[Optional, T]
+        # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
+        for t in ann.__args__:  # type: ignore[union-attr]
+            if isinstance(t, type) and issubclass(t, WeightsEnum):
+                weights_enum = t
+                break
+
+    if weights_enum is None:
+        raise ValueError(
+            "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
+        )
+
+    return cast(WeightsEnum, weights_enum)
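As a quick illustration of `get_enum_from_fn` (a sketch; the helper is internal, so treat the import path as an implementation detail): it inspects the type annotation of a builder's `weights` parameter, unwrapping `Optional[...]`, and returns the enum class itself.

```python
from torchvision.models import resnet50
from torchvision.models._api import get_enum_from_fn

# resnet50 is annotated as `weights: Optional[ResNet50_Weights] = None`,
# so the helper digs the enum class out of the Optional/Union annotation.
assert get_enum_from_fn(resnet50).__name__ == "ResNet50_Weights"
```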
torchvision/models/_utils.py:

import functools
import inspect
import warnings
from collections import OrderedDict
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union

from torch import nn

from .._utils import sequence_to_str
from ._api import WeightsEnum
class IntermediateLayerGetter(nn.ModuleDict):
    """
@@ -26,7 +32,7 @@
    Examples::

        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
@@ -81,3 +87,158 @@
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
D = TypeVar("D")
def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
"""Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
.. code::
def old_fn(foo, bar, baz=None):
...
def new_fn(foo, *, bar, baz=None):
...
Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
and at the same time warn the user of the deprecation, this decorator can be used:
.. code::
@kwonly_to_pos_or_kw
def new_fn(foo, *, bar, baz=None):
...
new_fn("foo", "bar, "baz")
"""
params = inspect.signature(fn).parameters
try:
keyword_only_start_idx = next(
idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
)
except StopIteration:
raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> D:
args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
if keyword_only_args:
keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
warnings.warn(
f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
)
kwargs.update(keyword_only_kwargs)
return fn(*args, **kwargs)
return wrapper
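A short demonstration of the wrapper, mirroring the docstring example above:

@kwonly_to_pos_or_kw
def new_fn(foo, *, bar, baz=None):
    return foo, bar, baz

new_fn("foo", bar="bar")     # new style: no warning
new_fn("foo", "bar", "baz")  # legacy positional style: still works, but warns that using
                             # bar and baz as positional parameter(s) is deprecated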
W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
    """Decorates a model builder with the new interface to make it compatible with the old.

    In particular this handles two things:

    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
       :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
       ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.

    Args:
        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be
            in the dictionary is the deprecated parameter name passed as first element in the tuple. All other
            parameters should be accessed with :meth:`~dict.get`.
    """

    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
        @kwonly_to_pos_or_kw
        @functools.wraps(builder)
        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
                # If neither the weights nor the pretrained parameter was passed, or the weights argument already
                # uses the new-style arguments, there is nothing to do. Note that we cannot use `None` as sentinel
                # for the weights argument, since it is a valid value.
                sentinel = object()
                weights_arg = kwargs.get(weights_param, sentinel)
                if (
                    (weights_param not in kwargs and pretrained_param not in kwargs)
                    or isinstance(weights_arg, WeightsEnum)
                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
                    or weights_arg is None
                ):
                    continue

                # If the pretrained parameter was passed as positional argument, it is now mapped to
                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
                # signature to infer the names of positionally passed arguments and thus has no knowledge that there
                # used to be a pretrained parameter.
                pretrained_positional = weights_arg is not sentinel
                if pretrained_positional:
                    # We put the pretrained argument under its legacy name in the keyword argument dictionary to
                    # have a unified access to the value if the default value is a callable.
                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
                else:
                    pretrained_arg = kwargs[pretrained_param]

                if pretrained_arg:
                    default_weights_arg = default(kwargs) if callable(default) else default
                    if not isinstance(default_weights_arg, WeightsEnum):
                        raise ValueError(f"No weights available for model {builder.__name__}")
                else:
                    default_weights_arg = None

                if not pretrained_positional:
                    warnings.warn(
                        f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
                    )

                msg = (
                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
                )
                if pretrained_arg:
                    msg = (
                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
                        f"to get the most up-to-date weights."
                    )
                warnings.warn(msg)

                del kwargs[pretrained_param]
                kwargs[weights_param] = default_weights_arg

            return builder(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper
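Applied to the ported builders below, the decorator makes the old boolean flag behave like the matching enum value. A sketch of the resulting behavior, assuming the alexnet builder defined later in this diff:

from torchvision.models import alexnet, AlexNet_Weights

m1 = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)  # new interface, no warning
m2 = alexnet(pretrained=True)   # deprecated: warns and is rewritten to weights=AlexNet_Weights.IMAGENET1K_V1
m3 = alexnet(pretrained=False)  # deprecated: warns and is rewritten to weights=None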
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    if param in kwargs:
        if kwargs[param] != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    else:
        kwargs[param] = new_value


def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
    if param is not None:
        if param != new_value:
            # Note: `param` here is the value itself, not a parameter name.
            raise ValueError(f"The parameter value {param} does not match the expected value {new_value}.")
    return new_value
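A quick illustration of _ovewrite_named_param, which the builders use to keep user-supplied kwargs consistent with the metadata of the requested weights:

kwargs = {"num_classes": 1000}
_ovewrite_named_param(kwargs, "num_classes", 1000)  # matching value: no-op
_ovewrite_named_param(kwargs, "batch_norm_eps", 1e-5)  # hypothetical missing key: gets set
_ovewrite_named_param(kwargs, "num_classes", 10)    # mismatch: raises ValueError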
torchvision/models/alexnet.py (the model_urls dict is removed; the URL now lives in the weight enum):

from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn

from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param

__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class AlexNet(nn.Module):
@@ -53,17 +52,45 @@
        return x
class AlexNet_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "task": "image_classification",
            "architecture": "AlexNet",
            "publication_year": 2012,
            "num_params": 61100840,
            "size": (224, 224),
            "min_size": (63, 63),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            "acc@1": 56.522,
            "acc@5": 79.066,
        },
    )
    DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
    The required minimum input size of the model is 63x63.

    Args:
        weights (AlexNet_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
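With the enum in place, a full inference flow uses the bundled preset and metadata. A minimal sketch (the input image path is hypothetical):

from torchvision.io import read_image
from torchvision.models import alexnet, AlexNet_Weights

weights = AlexNet_Weights.DEFAULT  # currently an alias of IMAGENET1K_V1
model = alexnet(weights=weights)
model.eval()

preprocess = weights.transforms()  # instantiates the ImageClassification preset bundled above
img = read_image("path/to/image.jpg")  # hypothetical input
batch = preprocess(img).unsqueeze(0)
scores = model(batch).squeeze(0).softmax(0)
label = weights.meta["categories"][scores.argmax().item()]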
torchvision/models/convnext.py:

from functools import partial
from typing import Any, Callable, List, Optional, Sequence

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from ..ops.misc import Conv2dNormActivation
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
@@ -20,14 +27,6 @@
]


class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 2, 3, 1)
@@ -187,29 +186,101 @@
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
_COMMON_META = {
    "task": "image_classification",
    "architecture": "ConvNeXt",
    "publication_year": 2022,
    "size": (224, 224),
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
}


class ConvNeXt_Tiny_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "acc@1": 82.520,
            "acc@5": 96.146,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Small_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "acc@1": 83.616,
            "acc@5": 96.650,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Base_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "acc@1": 84.062,
            "acc@5": 96.870,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Large_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "acc@1": 84.414,
            "acc@5": 96.976,
        },
    )
    DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt Tiny model architecture from the r"""ConvNeXt Tiny model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ConvNeXt_Tiny_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ConvNeXt_Tiny_Weights.verify(weights)
block_setting = [ block_setting = [
CNBlockConfig(96, 192, 3), CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3), CNBlockConfig(192, 384, 3),
...@@ -217,16 +288,21 @@ def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: ...@@ -217,16 +288,21 @@ def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs:
CNBlockConfig(768, None, 3), CNBlockConfig(768, None, 3),
] ]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: @handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
*, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
r"""ConvNeXt Small model architecture from the r"""ConvNeXt Small model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ConvNeXt_Small_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ConvNeXt_Small_Weights.verify(weights)
block_setting = [ block_setting = [
CNBlockConfig(96, 192, 3), CNBlockConfig(96, 192, 3),
CNBlockConfig(192, 384, 3), CNBlockConfig(192, 384, 3),
...@@ -234,16 +310,19 @@ def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: ...@@ -234,16 +310,19 @@ def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs:
CNBlockConfig(768, None, 3), CNBlockConfig(768, None, 3),
] ]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: @handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt Base model architecture from the r"""ConvNeXt Base model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ConvNeXt_Base_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ConvNeXt_Base_Weights.verify(weights)
block_setting = [ block_setting = [
CNBlockConfig(128, 256, 3), CNBlockConfig(128, 256, 3),
CNBlockConfig(256, 512, 3), CNBlockConfig(256, 512, 3),
...@@ -251,16 +330,21 @@ def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: ...@@ -251,16 +330,21 @@ def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs:
CNBlockConfig(1024, None, 3), CNBlockConfig(1024, None, 3),
] ]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: @handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
*, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
r"""ConvNeXt Large model architecture from the r"""ConvNeXt Large model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet weights (ConvNeXt_Large_Weights, optional): The pretrained weights for the model
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
weights = ConvNeXt_Large_Weights.verify(weights)
block_setting = [ block_setting = [
CNBlockConfig(192, 384, 3), CNBlockConfig(192, 384, 3),
CNBlockConfig(384, 768, 3), CNBlockConfig(384, 768, 3),
...@@ -268,4 +352,4 @@ def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: ...@@ -268,4 +352,4 @@ def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs:
CNBlockConfig(1536, None, 3), CNBlockConfig(1536, None, 3),
] ]
stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)