Unverified Commit cd23bc41 authored by HeekangPark's avatar HeekangPark Committed by GitHub
Browse files

Fix Error in SPOS Example supernet.py (#2961)


Co-authored-by: default avatarliuzhe-lz <40699903+liuzhe-lz@users.noreply.github.com>
parent 88a225f8
...@@ -45,6 +45,7 @@ if __name__ == "__main__": ...@@ -45,6 +45,7 @@ if __name__ == "__main__":
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
model = ShuffleNetV2OneShot() model = ShuffleNetV2OneShot()
flops_func = model.get_candidate_flops
if args.load_checkpoint: if args.load_checkpoint:
if not args.spos_preprocessing: if not args.spos_preprocessing:
logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.") logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
...@@ -52,7 +53,7 @@ if __name__ == "__main__": ...@@ -52,7 +53,7 @@ if __name__ == "__main__":
model.cuda() model.cuda()
if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1))) model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
mutator = SPOSSupernetTrainingMutator(model, flops_func=model.module.get_candidate_flops, mutator = SPOSSupernetTrainingMutator(model, flops_func=flops_func,
flops_lb=290E6, flops_ub=360E6) flops_lb=290E6, flops_ub=360E6)
criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing) criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment