Unverified Commit 61a52b93 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add --prototype flag to quantization scripts. (#5334)

parent 9d7177fe
...@@ -13,14 +13,16 @@ from train import train_one_epoch, evaluate, load_data ...@@ -13,14 +13,16 @@ from train import train_one_epoch, evaluate, load_data
try: try:
from torchvision.prototype import models as PM from torchvision import prototype
except ImportError: except ImportError:
PM = None prototype = None
def main(args): def main(args):
if args.weights and PM is None: if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -54,10 +56,10 @@ def main(args): ...@@ -54,10 +56,10 @@ def main(args):
print("Creating model", args.model) print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model # when training quantized models, we always start from a pre-trained fp32 reference model
if not args.weights: if not args.prototype:
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
else: else:
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device) model.to(device)
if not (args.test_only or args.post_training_quantize): if not (args.test_only or args.post_training_quantize):
...@@ -264,6 +266,12 @@ def get_args_parser(add_help=True): ...@@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
# Prototype models only # Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment