Unverified Commit c717ce57 authored by J-shang, committed by GitHub

update pruning example (#3844)

parent 507595b0
@@ -143,7 +143,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion):
     model.load_state_dict(torch.load(args.pretrained_model_dir))
     best_acc = test(args, model, device, criterion, test_loader)
-    # setup new opotimizer for fine-tuning
+    # setup new optimizer for pruning
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
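For readers skimming the hunk above: this is the standard PyTorch pairing of a fresh SGD optimizer with a `MultiStepLR` schedule that cuts the learning rate by 10x at 50% and 75% of the epoch budget. A minimal runnable sketch, with a hypothetical toy model and an assumed epoch count standing in for the example's `args`:

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(10, 2)  # hypothetical stand-in for the example's network
pretrain_epochs = 160           # assumed value; the example reads args.pretrain_epochs

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# LR drops by gamma=0.1 at 50% and 75% of the schedule: epochs 80 and 120 here.
scheduler = MultiStepLR(
    optimizer,
    milestones=[int(pretrain_epochs * 0.5), int(pretrain_epochs * 0.75)],
    gamma=0.1,
)

for epoch in range(pretrain_epochs):
    # ... one epoch of training ...
    scheduler.step()  # advance the schedule once per epoch
```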
@@ -192,10 +192,10 @@ def main(args):
     # prepare model and data
     train_loader, test_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
-    model, optimizer, scheduler = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion)
+    model, optimizer, _ = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion)
     dummy_input = get_dummy_input(args, device)
-    flops, params, results = count_flops_params(model, dummy_input)
+    flops, params, _ = count_flops_params(model, dummy_input)
     print(f"FLOPs: {flops}, params: {params}")
     print(f'start {args.pruner} pruning...')
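As the second hunk reflects, NNI's `count_flops_params` returns a `(flops, params, results)` triple; the third element is a per-layer breakdown the updated example now discards. A minimal sketch of the call, assuming an NNI 2.x install (the import path may vary across releases) and a small stand-in model:

```python
import torch
from nni.compression.pytorch.utils.counter import count_flops_params

# Small stand-in; the example instead profiles the model returned by
# get_model_optimizer_scheduler.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 32 * 32, 10),
)
dummy_input = torch.randn(1, 3, 32, 32)

# Returns total FLOPs, total parameter count, and a per-layer breakdown;
# the updated example keeps only the two totals.
flops, params, _ = count_flops_params(model, dummy_input)
print(f"FLOPs: {flops}, params: {params}")
```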
@@ -273,11 +273,16 @@ def main(args):
     if args.speed_up:
         # Unwrap all modules to normal state
         pruner._unwrap_model()
         m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
         m_speedup.speedup_model()
     print('start finetuning...')
+    # The optimizer used by the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
     best_top1 = 0
     save_path = os.path.join(args.experiment_data_dir, 'finetuned.pth')
     for epoch in range(args.fine_tune_epochs):
......
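The added lines are the substance of this hunk: `speedup_model()` physically replaces masked layers with smaller ones, so an optimizer created earlier points at stale (and possibly pruner-patched) parameters and must be rebuilt from `model.parameters()`. A minimal end-to-end sketch of the unwrap → speedup → fresh-optimizer sequence, assuming NNI 2.x APIs, with an `L1FilterPruner` swapped in for whichever pruner the example selects; file names and milestone values are illustrative:

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch import ModelSpeedup

device = torch.device('cpu')
# Small stand-in model; the example prunes its CIFAR-style network instead.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, padding=1),
).to(device)
dummy_input = torch.randn(1, 3, 32, 32, device=device)

# Prune 50% of the filters in every Conv2d layer, then export the masks.
pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model(model_path='pruned.pth', mask_path='mask.pth')

pruner._unwrap_model()  # restore wrapped modules to their normal state
# Replace masked channels with a genuinely smaller network.
ModelSpeedup(model, dummy_input, 'mask.pth', device).speedup_model()

# The old optimizer may have been patched by the pruner and still references
# the pre-speedup tensors, so build a fresh one over the compacted model.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)  # assumed 160-epoch budget
```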