# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

'''
Example for supported automatic pruning algorithms.
In this example, we present the usage of automatic pruners (NetAdapt, AutoCompressPruner). L1, L2, FPGM pruners are also executed for comparison purpose.
'''

import argparse
import copy
import os
import sys
import json
import torch
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, L2FilterPruner, FPGMPruner
from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner, ADMMPruner, NetAdaptPruner, AutoCompressPruner
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params

from pathlib import Path
sys.path.append(str(Path(__file__).absolute().parents[1] / 'models'))
from mnist.lenet import LeNet
from cifar10.vgg import VGG
from cifar10.resnet import ResNet18, ResNet50

# Base SGD learning rate used for both pre-training and fine-tuning of the
# CIFAR-10 models handled by this example.
_SGD_BASE_LR = {'vgg16': 0.01, 'resnet18': 0.1, 'resnet50': 0.1}


def _build_model(name, device):
    '''Instantiate the model called ``name`` on ``device``; raise ValueError if unknown.'''
    if name == 'LeNet':
        model = LeNet()
    elif name == 'vgg16':
        model = VGG(depth=16)
    elif name == 'resnet18':
        model = ResNet18()
    elif name == 'resnet50':
        model = ResNet50()
    else:
        raise ValueError("model not recognized")
    return model.to(device)


def _sgd(model, lr):
    '''Return the SGD optimizer configuration shared by all CIFAR-10 models.'''
    return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)


def _multistep(optimizer, epochs):
    '''Return a MultiStepLR that decays the LR by 10x at 50% and 75% of ``epochs``.'''
    return MultiStepLR(
        optimizer, milestones=[int(epochs * 0.5), int(epochs * 0.75)], gamma=0.1)


def get_data(dataset, data_dir, batch_size, test_batch_size):
    '''
    Build the data loaders and the loss function for ``dataset``.

    Parameters
    ----------
    dataset : str
        'mnist' or 'cifar10'.
    data_dir : str
        Directory the dataset is stored in (downloaded there if missing).
    batch_size : int
        Batch size of the training loader.
    test_batch_size : int
        Batch size of the validation loader.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, criterion)``.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the implemented datasets.
    '''
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    if dataset == 'mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_dir, train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True, **kwargs)
        # shuffle=True is kept on purpose: it does not affect the accuracy but
        # changing it would alter the RNG stream seeded in main().
        val_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_dir, train=False, transform=transform),
            batch_size=test_batch_size, shuffle=True, **kwargs)
        criterion = torch.nn.NLLLoss()
    elif dataset == 'cifar10':
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(data_dir, train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=batch_size, shuffle=True, **kwargs)
        # The validation loader must honour test_batch_size (it previously
        # reused the training batch size).
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(data_dir, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=test_batch_size, shuffle=False, **kwargs)
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError(
            "dataset '%s' is not implemented in this example (use 'mnist' or 'cifar10')" % dataset)
    return train_loader, val_loader, criterion


def train(args, model, device, train_loader, criterion, optimizer, epoch):
    '''
    Train ``model`` for one pass over ``train_loader``.

    Progress is printed every ``args.log_interval`` batches; ``epoch`` is only
    used for that log line.
    '''
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, criterion, val_loader):
    '''
    Evaluate ``model`` on ``val_loader`` and print the average loss and accuracy.

    Returns
    -------
    float
        Top-1 accuracy in ``[0, 1]``.
    '''
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # The criteria built by get_data() use the default 'mean' reduction,
            # so weight each batch loss by its size to get a true per-sample
            # average once divided by the dataset size below.
            test_loss += criterion(output, target).item() * data.size(0)
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(val_loader.dataset)
    test_loss /= num_samples
    accuracy = correct / num_samples

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, num_samples, 100. * accuracy))

    return accuracy


def get_trained_model_optimizer(args, device, train_loader, val_loader, criterion):
    '''
    Return a trained model together with the optimizer attached to it.

    When ``args.load_pretrained_model`` is set, the weights are loaded from
    ``args.pretrained_model_dir`` and a low learning-rate optimizer is created.
    Otherwise the model is trained for ``args.pretrain_epochs`` epochs, the
    weights of the best epoch (by validation accuracy) are restored and, if
    ``args.save_model`` is set, written to ``model_trained.pth`` in
    ``args.experiment_data_dir``.

    Raises
    ------
    ValueError
        If ``args.model`` is not recognized.
    '''
    model = _build_model(args.model, device)

    if args.load_pretrained_model:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(args.pretrained_model_dir, map_location=device))
        if args.model == 'LeNet':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1e-4)
        else:
            optimizer = _sgd(model, 1e-4)
        return model, optimizer

    if args.model == 'LeNet':
        optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    else:
        optimizer = _sgd(model, _SGD_BASE_LR[args.model])
        scheduler = _multistep(optimizer, args.pretrain_epochs)

    best_acc = 0
    best_epoch = 0
    # state_dict() returns references to the live parameters, so a deep copy is
    # required to actually snapshot the best weights. Starting from the initial
    # weights also keeps best_state defined when no epoch improves on 0.
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(args.pretrain_epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = test(model, device, criterion, val_loader)
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch
            best_state = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)
    print('Best acc:', best_acc)
    print('Best epoch:', best_epoch)

    if args.save_model:
        torch.save(best_state, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
        print('Model trained saved to %s' % args.experiment_data_dir)

    return model, optimizer


def get_dummy_input(args, device):
    '''
    Return a random input batch of ``args.test_batch_size`` samples on ``device``.

    NOTE(review): 'imagenet' yields 32x32 inputs here while get_input_size()
    reports 256x256 for it -- confirm which one is intended.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not recognized.
    '''
    if args.dataset == 'mnist':
        shape = [args.test_batch_size, 1, 28, 28]
    elif args.dataset in ['cifar10', 'imagenet']:
        shape = [args.test_batch_size, 3, 32, 32]
    else:
        raise ValueError("dataset not recognized")
    return torch.randn(shape).to(device)


def get_input_size(dataset):
    '''
    Return the single-sample input shape ``(1, C, H, W)`` used to count FLOPs.

    Raises
    ------
    ValueError
        If ``dataset`` is not recognized.
    '''
    if dataset == 'mnist':
        return (1, 1, 28, 28)
    if dataset == 'cifar10':
        return (1, 3, 32, 32)
    if dataset == 'imagenet':
        return (1, 3, 256, 256)
    raise ValueError("dataset not recognized")


def _fine_tune(args, model, device, train_loader, criterion, evaluator):
    '''Fine-tune ``model`` for ``args.fine_tune_epochs`` epochs and return the best accuracy.'''
    if args.dataset == 'mnist':
        optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    elif args.dataset == 'cifar10' and args.model in _SGD_BASE_LR:
        optimizer = _sgd(model, _SGD_BASE_LR[args.model])
        scheduler = _multistep(optimizer, args.fine_tune_epochs)
    else:
        raise ValueError(
            "fine-tuning is not implemented for dataset '%s' with model '%s'"
            % (args.dataset, args.model))

    best_acc = 0
    for epoch in range(args.fine_tune_epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        acc = evaluator(model)
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))
    return best_acc


def main(args):
    '''
    Run the whole example: train (or load) a model, prune it with the selected
    pruner, optionally speed it up and fine-tune it, and dump the collected
    FLOPs / parameter counts / accuracies to ``result.json`` in
    ``args.experiment_data_dir``.
    '''
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer, epoch)

    def trainer(model, optimizer, criterion, epoch):
        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2', 'fpgm']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']
    else:
        raise ValueError(
            "base_algo '%s' not supported, expected level, l1, l2 or fpgm" % args.base_algo)

    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types
    }]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'FPGMPruner':
        pruner = FPGMPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        # --sparsity-per-iteration was parsed but never forwarded before.
        pruner = NetAdaptPruner(model, config_list, short_term_fine_tuner=short_term_fine_tuner, evaluator=evaluator,
                                base_algo=args.base_algo, sparsity_per_iteration=args.sparsity_per_iteration,
                                experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2', 'fpgm']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, epochs_per_iteration=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input,
            num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_epochs_per_iteration=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError(
            "Pruner not supported.")

    # Pruner.compress() returns the masked model
    # but for AutoCompressPruner, Pruner.compress() returns directly the pruned model
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    # The speedup step below reads the exported weights and masks back from
    # disk, so they must be written even when --save-model is off.
    speedup_from_masks = args.speedup and args.pruner != 'AutoCompressPruner'
    if args.save_model or speedup_from_masks:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speedup
    if args.speedup:
        if speedup_from_masks:
            # Rebuild a fresh, unwrapped model and apply the exported masks to it.
            model = _build_model(args.model, device)
            model.load_state_dict(
                torch.load(os.path.join(args.experiment_data_dir, 'model_masked.pth'), map_location=device))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speedup model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speedup.pth'))
            print('Speedup model saved to %s' % args.experiment_data_dir)
        flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    if args.fine_tune:
        best_acc = _fine_tune(args, model, device, train_loader, criterion, evaluator)
        # Reported only when fine-tuning actually ran; best_acc does not exist otherwise.
        print('Evaluation result (fine tuned): %s' % best_acc)
        print('Fined tuned model saved to %s' % args.experiment_data_dir)
        result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w') as f:
        json.dump(result, f)


if __name__ == '__main__':
    def str2bool(s):
        '''Parse a command-line boolean such as "yes"/"no", "true"/"false", "1"/"0".'''
        if isinstance(s, bool):
            return s
        if s.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if s.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected.')

    parser = argparse.ArgumentParser(description='PyTorch Example for SimulatedAnnealingPruner')

    # dataset and model
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset to use, mnist, cifar10 or imagenet')
    parser.add_argument('--data-dir', type=str, default='./data/',
                        help='dataset directory')
    parser.add_argument('--model', type=str, default='vgg16',
                        help='model to use, LeNet, vgg16, resnet18 or resnet50')
    parser.add_argument('--load-pretrained-model', type=str2bool, default=False,
                        help='whether to load pretrained model')
    parser.add_argument('--pretrained-model-dir', type=str, default='./',
                        help='path to pretrained model')
    parser.add_argument('--pretrain-epochs', type=int, default=100,
                        help='number of epochs to pretrain the model')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=64,
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--fine-tune', type=str2bool, default=True,
                        help='whether to fine-tune the pruned model')
    parser.add_argument('--fine-tune-epochs', type=int, default=5,
                        help='epochs to fine tune')
    parser.add_argument('--experiment-data-dir', type=str, default='./experiment_data',
                        help='For saving experiment data')

    # pruner
    parser.add_argument('--pruner', type=str, default='SimulatedAnnealingPruner',
                        help='pruner to use')
    parser.add_argument('--base-algo', type=str, default='l1',
                        help='base pruning algorithm. level, l1, l2, or fpgm')
    parser.add_argument('--sparsity', type=float, default=0.1,
                        help='target overall target sparsity')
    # param for SimulatedAnnealingPruner
    parser.add_argument('--cool-down-rate', type=float, default=0.9,
                        help='cool down rate')
    # param for NetAdaptPruner
    parser.add_argument('--sparsity-per-iteration', type=float, default=0.05,
                        help='sparsity_per_iteration of NetAdaptPruner')

    # speedup
    parser.add_argument('--speedup', type=str2bool, default=False,
                        help='Whether to speedup the pruned model')

    # others
    parser.add_argument('--log-interval', type=int, default=200,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', type=str2bool, default=True,
                        help='For Saving the current Model')

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.experiment_data_dir, exist_ok=True)

    main(args)