Unverified Commit 3300692c authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Moving the check for prototype support in all references. (#4849)

parent dd1adb07
...@@ -182,6 +182,8 @@ def load_data(traindir, valdir, args): ...@@ -182,6 +182,8 @@ def load_data(traindir, valdir, args):
def main(args): def main(args):
if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -226,8 +228,6 @@ def main(args): ...@@ -226,8 +228,6 @@ def main(args):
if not args.weights: if not args.weights:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else: else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes) model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device) model.to(device)
......
...@@ -19,6 +19,8 @@ except ImportError: ...@@ -19,6 +19,8 @@ except ImportError:
def main(args): def main(args):
if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -55,8 +57,6 @@ def main(args): ...@@ -55,8 +57,6 @@ def main(args):
if not args.weights: if not args.weights:
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
else: else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device) model.to(device)
......
...@@ -148,6 +148,8 @@ def get_args_parser(add_help=True): ...@@ -148,6 +148,8 @@ def get_args_parser(add_help=True):
def main(args): def main(args):
if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -194,8 +196,6 @@ def main(args): ...@@ -194,8 +196,6 @@ def main(args):
pretrained=args.pretrained, num_classes=num_classes, **kwargs pretrained=args.pretrained, num_classes=num_classes, **kwargs
) )
else: else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
......
...@@ -92,6 +92,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi ...@@ -92,6 +92,8 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
def main(args): def main(args):
if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -130,8 +132,6 @@ def main(args): ...@@ -130,8 +132,6 @@ def main(args):
aux_loss=args.aux_loss, aux_loss=args.aux_loss,
) )
else: else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.segmentation.__dict__[args.model]( model = PM.segmentation.__dict__[args.model](
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
) )
......
...@@ -99,6 +99,8 @@ def collate_fn(batch): ...@@ -99,6 +99,8 @@ def collate_fn(batch):
def main(args): def main(args):
if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.apex and amp is None: if args.apex and amp is None:
raise RuntimeError( raise RuntimeError(
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
...@@ -214,8 +216,6 @@ def main(args): ...@@ -214,8 +216,6 @@ def main(args):
if not args.weights: if not args.weights:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else: else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.video.__dict__[args.model](weights=args.weights) model = PM.video.__dict__[args.model](weights=args.weights)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment