Unverified Commit 1d0786b0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update references to use the new Model Registration API (#6369)

* Expose on Hub the public methods of the registration API

* Limit methods and update docs.

* Update references to use the new Model Registration API
parent c72b2843
...@@ -221,7 +221,7 @@ def main(args): ...@@ -221,7 +221,7 @@ def main(args):
) )
print("Creating model") print("Creating model")
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
......
...@@ -46,7 +46,11 @@ def main(args): ...@@ -46,7 +46,11 @@ def main(args):
print("Creating model", args.model) print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model # when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) prefix = "quantized_"
model_name = args.model
if not model_name.startswith(prefix):
model_name = prefix + model_name
model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
model.to(device) model.to(device)
if not (args.test_only or args.post_training_quantize): if not (args.test_only or args.post_training_quantize):
......
...@@ -216,8 +216,8 @@ def main(args): ...@@ -216,8 +216,8 @@ def main(args):
if "rcnn" in args.model: if "rcnn" in args.model:
if args.rpn_score_thresh is not None: if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model]( model = torchvision.models.get_model(
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
) )
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
......
...@@ -215,7 +215,7 @@ def main(args): ...@@ -215,7 +215,7 @@ def main(args):
else: else:
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) model = torchvision.models.get_model(args.model, weights=args.weights)
if args.distributed: if args.distributed:
model = model.to(args.local_rank) model = model.to(args.local_rank)
......
...@@ -156,8 +156,12 @@ def main(args): ...@@ -156,8 +156,12 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
) )
model = torchvision.models.segmentation.__dict__[args.model]( model = torchvision.models.get_model(
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss args.model,
weights=args.weights,
weights_backbone=args.weights_backbone,
num_classes=num_classes,
aux_loss=args.aux_loss,
) )
model.to(device) model.to(device)
if args.distributed: if args.distributed:
......
...@@ -246,7 +246,7 @@ def main(args): ...@@ -246,7 +246,7 @@ def main(args):
) )
print("Creating model") print("Creating model")
model = torchvision.models.video.__dict__[args.model](weights=args.weights) model = torchvision.models.get_model(args.model, weights=args.weights)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment