Unverified Commit 4bf6c6e4 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding prototype flag on reference scripts (#5248)

* Adding prototype flag on reference scripts.

* Import prototype instead of models/transforms.

* Correcting exception type.

* fixing none referencing
parent 7d4bdd43
...@@ -16,9 +16,9 @@ from torchvision.transforms.functional import InterpolationMode ...@@ -16,9 +16,9 @@ from torchvision.transforms.functional import InterpolationMode
try: try:
from torchvision.prototype import models as PM from torchvision import prototype
except ImportError: except ImportError:
PM = None prototype = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
...@@ -154,13 +154,18 @@ def load_data(traindir, valdir, args): ...@@ -154,13 +154,18 @@ def load_data(traindir, valdir, args):
print(f"Loading dataset_test from {cache_path}") print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path) dataset_test, _ = torch.load(cache_path)
else: else:
if not args.weights: if not args.prototype:
preprocessing = presets.ClassificationPresetEval( preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
) )
else: else:
weights = PM.get_weight(args.weights) if args.weights:
weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms() preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.ImageNetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
dataset_test = torchvision.datasets.ImageFolder( dataset_test = torchvision.datasets.ImageFolder(
valdir, valdir,
...@@ -186,8 +191,10 @@ def load_data(traindir, valdir, args): ...@@ -186,8 +191,10 @@ def load_data(traindir, valdir, args):
def main(args): def main(args):
if args.weights and PM is None: if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -229,10 +236,10 @@ def main(args): ...@@ -229,10 +236,10 @@ def main(args):
) )
print("Creating model") print("Creating model")
if not args.weights: if not args.prototype:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else: else:
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes) model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
...@@ -491,6 +498,12 @@ def get_args_parser(add_help=True): ...@@ -491,6 +498,12 @@ def get_args_parser(add_help=True):
) )
# Prototype models only # Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser return parser
......
...@@ -34,9 +34,9 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_group ...@@ -34,9 +34,9 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_group
try: try:
from torchvision.prototype import models as PM from torchvision import prototype
except ImportError: except ImportError:
PM = None prototype = None
def get_dataset(name, image_set, transform, data_path): def get_dataset(name, image_set, transform, data_path):
...@@ -50,11 +50,14 @@ def get_dataset(name, image_set, transform, data_path): ...@@ -50,11 +50,14 @@ def get_dataset(name, image_set, transform, data_path):
def get_transform(train, args): def get_transform(train, args):
if train: if train:
return presets.DetectionPresetTrain(args.data_augmentation) return presets.DetectionPresetTrain(args.data_augmentation)
elif not args.weights: elif not args.prototype:
return presets.DetectionPresetEval() return presets.DetectionPresetEval()
else: else:
weights = PM.get_weight(args.weights) if args.weights:
weights = prototype.models.get_weight(args.weights)
return weights.transforms() return weights.transforms()
else:
return prototype.transforms.CocoEval()
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
...@@ -141,6 +144,12 @@ def get_args_parser(add_help=True): ...@@ -141,6 +144,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only # Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters # Mixed precision training parameters
...@@ -150,8 +159,10 @@ def get_args_parser(add_help=True): ...@@ -150,8 +159,10 @@ def get_args_parser(add_help=True):
def main(args): def main(args):
if args.weights and PM is None: if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -193,12 +204,12 @@ def main(args): ...@@ -193,12 +204,12 @@ def main(args):
if "rcnn" in args.model: if "rcnn" in args.model:
if args.rpn_score_thresh is not None: if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh kwargs["rpn_score_thresh"] = args.rpn_score_thresh
if not args.weights: if not args.prototype:
model = torchvision.models.detection.__dict__[args.model]( model = torchvision.models.detection.__dict__[args.model](
pretrained=args.pretrained, num_classes=num_classes, **kwargs pretrained=args.pretrained, num_classes=num_classes, **kwargs
) )
else: else:
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......
...@@ -10,10 +10,9 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval ...@@ -10,10 +10,9 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
try: try:
from torchvision.prototype import models as PM from torchvision import prototype
from torchvision.prototype.models import optical_flow as PMOF
except ImportError: except ImportError:
PM = PMOF = None prototype = None
def get_train_dataset(stage, dataset_root): def get_train_dataset(stage, dataset_root):
...@@ -133,9 +132,12 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b ...@@ -133,9 +132,12 @@ def _validate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, b
def validate(model, args): def validate(model, args):
val_datasets = args.val_dataset or [] val_datasets = args.val_dataset or []
if args.prototype:
if args.weights: if args.weights:
weights = PM.get_weight(args.weights) weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms() preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.RaftEval()
else: else:
preprocessing = OpticalFlowPresetEval() preprocessing = OpticalFlowPresetEval()
...@@ -192,10 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): ...@@ -192,10 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
def main(args): def main(args):
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
utils.setup_ddp(args) utils.setup_ddp(args)
if args.weights: if args.prototype:
model = PMOF.__dict__[args.model](weights=args.weights) model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
else: else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
...@@ -317,7 +323,6 @@ def get_args_parser(add_help=True): ...@@ -317,7 +323,6 @@ def get_args_parser(add_help=True):
) )
# TODO: resume, pretrained, and weights should be in an exclusive arg group # TODO: resume, pretrained, and weights should be in an exclusive arg group
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument( parser.add_argument(
"--num_flow_updates", "--num_flow_updates",
...@@ -336,6 +341,15 @@ def get_args_parser(add_help=True): ...@@ -336,6 +341,15 @@ def get_args_parser(add_help=True):
required=True, required=True,
) )
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
return parser return parser
......
...@@ -12,9 +12,9 @@ from torch import nn ...@@ -12,9 +12,9 @@ from torch import nn
try: try:
from torchvision.prototype import models as PM from torchvision import prototype
except ImportError: except ImportError:
PM = None prototype = None
def get_dataset(dir_path, name, image_set, transform): def get_dataset(dir_path, name, image_set, transform):
...@@ -35,11 +35,14 @@ def get_dataset(dir_path, name, image_set, transform): ...@@ -35,11 +35,14 @@ def get_dataset(dir_path, name, image_set, transform):
def get_transform(train, args): def get_transform(train, args):
if train: if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480) return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
elif not args.weights: elif not args.prototype:
return presets.SegmentationPresetEval(base_size=520) return presets.SegmentationPresetEval(base_size=520)
else: else:
weights = PM.get_weight(args.weights) if args.weights:
weights = prototype.models.get_weight(args.weights)
return weights.transforms() return weights.transforms()
else:
return prototype.transforms.VocEval(resize_size=520)
def criterion(inputs, target): def criterion(inputs, target):
...@@ -97,8 +100,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi ...@@ -97,8 +100,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def main(args): def main(args):
if args.weights and PM is None: if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -130,14 +135,14 @@ def main(args): ...@@ -130,14 +135,14 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
) )
if not args.weights: if not args.prototype:
model = torchvision.models.segmentation.__dict__[args.model]( model = torchvision.models.segmentation.__dict__[args.model](
pretrained=args.pretrained, pretrained=args.pretrained,
num_classes=num_classes, num_classes=num_classes,
aux_loss=args.aux_loss, aux_loss=args.aux_loss,
) )
else: else:
model = PM.segmentation.__dict__[args.model]( model = prototype.models.segmentation.__dict__[args.model](
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
) )
model.to(device) model.to(device)
...@@ -278,6 +283,12 @@ def get_args_parser(add_help=True): ...@@ -278,6 +283,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only # Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters # Mixed precision training parameters
......
...@@ -13,9 +13,9 @@ from torch.utils.data.dataloader import default_collate ...@@ -13,9 +13,9 @@ from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
try: try:
from torchvision.prototype import models as PM from torchvision import prototype
except ImportError: except ImportError:
PM = None prototype = None
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
...@@ -96,9 +96,10 @@ def collate_fn(batch): ...@@ -96,9 +96,10 @@ def collate_fn(batch):
def main(args): def main(args):
if args.weights and PM is None: if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -149,11 +150,14 @@ def main(args): ...@@ -149,11 +150,14 @@ def main(args):
print("Loading validation data") print("Loading validation data")
cache_path = _get_cache_path(valdir) cache_path = _get_cache_path(valdir)
if not args.weights: if not args.prototype:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112))
else: else:
weights = PM.get_weight(args.weights) if args.weights:
weights = prototype.models.get_weight(args.weights)
transform_test = weights.transforms() transform_test = weights.transforms()
else:
transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}") print(f"Loading dataset_test from {cache_path}")
...@@ -204,10 +208,10 @@ def main(args): ...@@ -204,10 +208,10 @@ def main(args):
) )
print("Creating model") print("Creating model")
if not args.weights: if not args.prototype:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else: else:
model = PM.video.__dict__[args.model](weights=args.weights) model = prototype.models.video.__dict__[args.model](weights=args.weights)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -360,6 +364,12 @@ def parse_args(): ...@@ -360,6 +364,12 @@ def parse_args():
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only # Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters # Mixed precision training parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment