Unverified Commit 61a52b93 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add --prototype flag to quantization scripts. (#5334)

parent 9d7177fe
......@@ -13,14 +13,16 @@ from train import train_one_epoch, evaluate, load_data
try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None
def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -54,10 +56,10 @@ def main(args):
print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
if not args.weights:
if not args.prototype:
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
else:
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device)
if not (args.test_only or args.post_training_quantize):
......@@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment