Unverified Commit 115d2eb7 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add pretrained arg to reference scripts (#935)

Allows for easily evaluating the pre-trained models in the modelzoo
parent 6e5599e4
...@@ -144,7 +144,7 @@ def main(args): ...@@ -144,7 +144,7 @@ def main(args):
sampler=test_sampler, num_workers=args.workers, pin_memory=True) sampler=test_sampler, num_workers=args.workers, pin_memory=True)
print("Creating model") print("Creating model")
model = torchvision.models.__dict__[args.model]() model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -242,6 +242,12 @@ def parse_args(): ...@@ -242,6 +242,12 @@ def parse_args():
help="Only test the model", help="Only test the model",
action="store_true", action="store_true",
) )
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
# distributed training parameters # distributed training parameters
parser.add_argument('--world-size', default=1, type=int, parser.add_argument('--world-size', default=1, type=int,
......
...@@ -76,7 +76,8 @@ def main(args): ...@@ -76,7 +76,8 @@ def main(args):
collate_fn=utils.collate_fn) collate_fn=utils.collate_fn)
print("Creating model") print("Creating model")
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes) model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
pretrained=args.pretrained)
model.to(device) model.to(device)
model_without_ddp = model model_without_ddp = model
...@@ -156,6 +157,12 @@ if __name__ == "__main__": ...@@ -156,6 +157,12 @@ if __name__ == "__main__":
help="Only test the model", help="Only test the model",
action="store_true", action="store_true",
) )
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
# distributed training parameters # distributed training parameters
parser.add_argument('--world-size', default=1, type=int, parser.add_argument('--world-size', default=1, type=int,
......
...@@ -121,7 +121,9 @@ def main(args): ...@@ -121,7 +121,9 @@ def main(args):
sampler=test_sampler, num_workers=args.workers, sampler=test_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn) collate_fn=utils.collate_fn)
model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss) model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,
aux_loss=args.aux_loss,
pretrained=args.pretrained)
model.to(device) model.to(device)
if args.distributed: if args.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -205,6 +207,12 @@ def parse_args(): ...@@ -205,6 +207,12 @@ def parse_args():
help="Only test the model", help="Only test the model",
action="store_true", action="store_true",
) )
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
# distributed training parameters # distributed training parameters
parser.add_argument('--world-size', default=1, type=int, parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes') help='number of distributed processes')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment