Unverified Commit a911b856 authored by Yuge Zhang, committed by GitHub

Resolve conflicts for #4760 (#4762)

parent 14d2966b
@@ -19,7 +19,7 @@ from torch.optim.lr_scheduler import MultiStepLR
 from basic_pruners_torch import get_data
 from pathlib import Path
-sys.path.append(str(Path(__file__).absolute().parents[1] / 'models'))
+sys.path.append(str(Path(__file__).absolute().parents[2] / 'models'))
 from mnist.lenet import LeNet
 from cifar10.vgg import VGG
@@ -217,9 +217,9 @@ def parse_args():
     parser.add_argument('--agp_n_epochs_per_iter', type=int, default=1,
                         help='number of epochs per iteration for agp')
-    # speed-up
-    parser.add_argument('--speed_up', action='store_true', default=False,
-                        help='Whether to speed-up the pruned model')
+    # speedup
+    parser.add_argument('--speedup', action='store_true', default=False,
+                        help='Whether to speedup the pruned model')
     # finetuning parameters
     parser.add_argument('--n_workers', type=int, default=16,
@@ -336,7 +336,7 @@ def run_pruning(args):
     # model speedup
     pruner._unwrap_model()
-    if args.speed_up:
+    if args.speedup:
         dummy_input = torch.rand(1,3,224,224).to(device)
         ms = ModelSpeedup(model, dummy_input, args.experiment_dir + './mask_temp.pth')
         ms.speedup_model()
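For context, the renamed `--speedup` flag gates NNI's mask-based model speedup pass. Below is a minimal sketch of the surrounding flow, assuming an NNI v2.x pruner that exposes `export_model`, and reusing the names from the diff context (`pruner`, `model`, `device`, `args`); the mask file path here is illustrative, not the repository's exact value.

import torch
from nni.compression.pytorch import ModelSpeedup

# Assumes `pruner`, `model`, `device`, and `args` exist as in the diff context;
# 'mask_temp.pth' below is an illustrative path.
if args.speedup:
    # Export the weight masks produced by the pruner, then remove its wrappers
    # so the bare model can be rewritten in place.
    pruner.export_model(model_path='model_temp.pth', mask_path='mask_temp.pth')
    pruner._unwrap_model()

    # Replace masked channels/filters with genuinely smaller layers.
    dummy_input = torch.rand(1, 3, 224, 224).to(device)
    ModelSpeedup(model, dummy_input, 'mask_temp.pth').speedup_model()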