Unverified Commit 4bf6c6e4 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding prototype flag on reference scripts (#5248)

* Adding prototype flag on reference scripts.

* Import prototype instead of models/transforms.

* Correcting exception type.

* fixing none referencing
parent 7d4bdd43
......@@ -16,9 +16,9 @@ from torchvision.transforms.functional import InterpolationMode
try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
......@@ -154,13 +154,18 @@ def load_data(traindir, valdir, args):
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
else:
if not args.weights:
if not args.prototype:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.ImageNetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
dataset_test = torchvision.datasets.ImageFolder(
valdir,
......@@ -186,8 +191,10 @@ def load_data(traindir, valdir, args):
def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -229,10 +236,10 @@ def main(args):
)
print("Creating model")
if not args.weights:
if not args.prototype:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else:
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device)
if args.distributed and args.sync_bn:
......@@ -491,6 +498,12 @@ def get_args_parser(add_help=True):
)
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser
......
......@@ -34,9 +34,9 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_group
try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None
def get_dataset(name, image_set, transform, data_path):
......@@ -50,11 +50,14 @@ def get_dataset(name, image_set, transform, data_path):
def get_transform(train, args):
if train:
return presets.DetectionPresetTrain(args.data_augmentation)
elif not args.weights:
elif not args.prototype:
return presets.DetectionPresetEval()
else:
weights = PM.get_weight(args.weights)
return weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
return weights.transforms()
else:
return prototype.transforms.CocoEval()
def get_args_parser(add_help=True):
......@@ -141,6 +144,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters
......@@ -150,8 +159,10 @@ def get_args_parser(add_help=True):
def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -193,12 +204,12 @@ def main(args):
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
if not args.weights:
if not args.prototype:
model = torchvision.models.detection.__dict__[args.model](
pretrained=args.pretrained, num_classes=num_classes, **kwargs
)
else:
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......
......@@ -10,10 +10,9 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
try:
from torchvision.prototype import models as PM
from torchvision.prototype.models import optical_flow as PMOF
from torchvision import prototype
except ImportError:
PM = PMOF = None
prototype = None
def get_train_dataset(stage, dataset_root):
......@@ -133,9 +132,12 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
def validate(model, args):
val_datasets = args.val_dataset or []
if args.weights:
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
if args.prototype:
if args.weights:
weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.RaftEval()
else:
preprocessing = OpticalFlowPresetEval()
......@@ -192,10 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
def main(args):
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
utils.setup_ddp(args)
if args.weights:
model = PMOF.__dict__[args.model](weights=args.weights)
if args.prototype:
model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
......@@ -317,7 +323,6 @@ def get_args_parser(add_help=True):
)
# TODO: resume, pretrained, and weights should be in an exclusive arg group
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument(
"--num_flow_updates",
......@@ -336,6 +341,15 @@ def get_args_parser(add_help=True):
required=True,
)
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
return parser
......
......@@ -12,9 +12,9 @@ from torch import nn
try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None
def get_dataset(dir_path, name, image_set, transform):
......@@ -35,11 +35,14 @@ def get_dataset(dir_path, name, image_set, transform):
def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
elif not args.weights:
elif not args.prototype:
return presets.SegmentationPresetEval(base_size=520)
else:
weights = PM.get_weight(args.weights)
return weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
return weights.transforms()
else:
return prototype.transforms.VocEval(resize_size=520)
def criterion(inputs, target):
......@@ -97,8 +100,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -130,14 +135,14 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)
if not args.weights:
if not args.prototype:
model = torchvision.models.segmentation.__dict__[args.model](
pretrained=args.pretrained,
num_classes=num_classes,
aux_loss=args.aux_loss,
)
else:
model = PM.segmentation.__dict__[args.model](
model = prototype.models.segmentation.__dict__[args.model](
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
)
model.to(device)
......@@ -278,6 +283,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters
......
......@@ -13,9 +13,9 @@ from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
......@@ -96,9 +96,10 @@ def collate_fn(batch):
def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -149,11 +150,14 @@ def main(args):
print("Loading validation data")
cache_path = _get_cache_path(valdir)
if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
if not args.prototype:
transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112))
else:
weights = PM.get_weight(args.weights)
transform_test = weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
transform_test = weights.transforms()
else:
transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}")
......@@ -204,10 +208,10 @@ def main(args):
)
print("Creating model")
if not args.weights:
if not args.prototype:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else:
model = PM.video.__dict__[args.model](weights=args.weights)
model = prototype.models.video.__dict__[args.model](weights=args.weights)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......@@ -360,6 +364,12 @@ def parse_args():
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment