You need to sign in or sign up before continuing.
Unverified Commit 1d0786b0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update references to use the new Model Registration API (#6369)

* Expose on Hub the public methods of the registration API

* Limit methods and update docs.

* Update references to use the new Model Registration API
parent c72b2843
...@@ -221,7 +221,7 @@ def main(args): ...@@ -221,7 +221,7 @@ def main(args):
) )
print("Creating model") print("Creating model")
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
......
...@@ -46,7 +46,11 @@ def main(args): ...@@ -46,7 +46,11 @@ def main(args):
print("Creating model", args.model) print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model # when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) prefix = "quantized_"
model_name = args.model
if not model_name.startswith(prefix):
model_name = prefix + model_name
model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
model.to(device) model.to(device)
if not (args.test_only or args.post_training_quantize): if not (args.test_only or args.post_training_quantize):
......
...@@ -216,8 +216,8 @@ def main(args): ...@@ -216,8 +216,8 @@ def main(args):
if "rcnn" in args.model: if "rcnn" in args.model:
if args.rpn_score_thresh is not None: if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model]( model = torchvision.models.get_model(
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
) )
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
......
...@@ -215,7 +215,7 @@ def main(args): ...@@ -215,7 +215,7 @@ def main(args):
else: else:
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) model = torchvision.models.get_model(args.model, weights=args.weights)
if args.distributed: if args.distributed:
model = model.to(args.local_rank) model = model.to(args.local_rank)
......
...@@ -156,8 +156,12 @@ def main(args): ...@@ -156,8 +156,12 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
) )
model = torchvision.models.segmentation.__dict__[args.model]( model = torchvision.models.get_model(
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss args.model,
weights=args.weights,
weights_backbone=args.weights_backbone,
num_classes=num_classes,
aux_loss=args.aux_loss,
) )
model.to(device) model.to(device)
if args.distributed: if args.distributed:
......
...@@ -246,7 +246,7 @@ def main(args): ...@@ -246,7 +246,7 @@ def main(args):
) )
print("Creating model") print("Creating model")
model = torchvision.models.video.__dict__[args.model](weights=args.weights) model = torchvision.models.get_model(args.model, weights=args.weights)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment