Unverified Commit decb78ee authored by lin bin's avatar lin bin Committed by GitHub
Browse files

fix flops counter bug in auto_pruners_torch.py (#3265)

parent 99bc4594
...@@ -186,7 +186,7 @@ def get_trained_model_optimizer(args, device, train_loader, val_loader, criterio ...@@ -186,7 +186,7 @@ def get_trained_model_optimizer(args, device, train_loader, val_loader, criterio
if args.save_model: if args.save_model:
torch.save(state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth')) torch.save(state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
print('Model trained saved to %s', args.experiment_data_dir) print('Model trained saved to %s' % args.experiment_data_dir)
return model, optimizer return model, optimizer
...@@ -312,7 +312,7 @@ def main(args): ...@@ -312,7 +312,7 @@ def main(args):
if args.save_model: if args.save_model:
pruner.export_model( pruner.export_model(
os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth')) os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
print('Masked model saved to %s', args.experiment_data_dir) print('Masked model saved to %s' % args.experiment_data_dir)
# model speed up # model speed up
if args.speed_up: if args.speed_up:
...@@ -336,7 +336,7 @@ def main(args): ...@@ -336,7 +336,7 @@ def main(args):
result['performance']['speedup'] = evaluation_result result['performance']['speedup'] = evaluation_result
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth')) torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
print('Speed up model saved to %s', args.experiment_data_dir) print('Speed up model saved to %s' % args.experiment_data_dir)
flops, params, _ = count_flops_params(model, get_input_size(args.dataset)) flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['speedup'] = flops result['flops']['speedup'] = flops
result['params']['speedup'] = params result['params']['speedup'] = params
...@@ -367,7 +367,7 @@ def main(args): ...@@ -367,7 +367,7 @@ def main(args):
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth')) torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))
print('Evaluation result (fine tuned): %s' % best_acc) print('Evaluation result (fine tuned): %s' % best_acc)
print('Fined tuned model saved to %s', args.experiment_data_dir) print('Fined tuned model saved to %s' % args.experiment_data_dir)
result['performance']['finetuned'] = best_acc result['performance']['finetuned'] = best_acc
with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f: with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment