Unverified commit d5036857, authored by chicm-ms and committed by GitHub

Fix amc example (#2976)

parent 143ac285
@@ -41,6 +41,7 @@ Pruning algorithms compress the original network by removing redundant weights o
| [NetAdapt Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#netadapt-pruner) | Automatically simplify a pretrained network to meet the resource budget by iterative pruning [Reference Paper](https://arxiv.org/abs/1804.03230) |
| [SimulatedAnnealing Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#simulatedannealing-pruner) | Automatic pruning with a guided heuristic search method, Simulated Annealing algorithm [Reference Paper](https://arxiv.org/abs/1907.03141) |
| [AutoCompress Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#autocompress-pruner) | Automatic pruning by iteratively calling SimulatedAnnealing Pruner and ADMM Pruner [Reference Paper](https://arxiv.org/abs/1907.03141) |
+| [AMC Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#amc-pruner) | AMC: AutoML for Model Compression and Acceleration on Mobile Devices [Reference Paper](https://arxiv.org/pdf/1802.03494.pdf) |
You can refer to this [benchmark](https://github.com/microsoft/nni/tree/master/docs/en_US/CommunitySharings/ModelCompressionComparison.md) for the performance of these pruners on some benchmark problems.
......
@@ -20,7 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
* [NetAdapt Pruner](#netadapt-pruner)
* [SimulatedAnnealing Pruner](#simulatedannealing-pruner)
* [AutoCompress Pruner](#autocompress-pruner)
-* [AutoML for Model Compression Pruner](#automl-for-model-compression-pruner)
+* [AMC Pruner](#amc-pruner)
* [Sensitivity Pruner](#sensitivity-pruner)
**Others**
@@ -495,9 +495,9 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AutoCompressPruner
```
-## AutoML for Model Compression Pruner
+## AMC Pruner

-AutoML for Model Compression Pruner (AMCPruner) leverages reinforcement learning to provide the model compression policy.
+AMC pruner leverages reinforcement learning to provide the model compression policy.
This learning-based compression policy outperforms conventional rule-based compression policies by achieving a higher compression ratio, better preserving accuracy, and reducing human labor.
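For orientation, here is a minimal usage sketch. It assumes the class is exposed as `nni.compression.torch.AMCPruner` and follows the usual NNI pruner pattern of a model plus a `config_list`; the evaluation callback and validation loader are user-supplied placeholders, and the exact argument list should be taken from the example's `amc_search.py`.

```python
# Sketch only: names other than AMCPruner/compress are placeholders or assumptions.
from nni.compression.torch import AMCPruner

config_list = [{'op_types': ['Conv2d', 'Linear']}]   # layer types the pruner may compress

# `model` is the torch.nn.Module to compress, `val_loader` an ordinary DataLoader,
# and `evaluator(model)` returns the validation accuracy used as the RL reward.
pruner = AMCPruner(model, config_list, evaluator, val_loader)
pruner.compress()
```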
......
# AMCPruner Example
This example shows how to use AMCPruner.
## Step 1: train a model for pruning
Run the following command to train a MobileNetV2 model:
```bash
python3 amc_train.py --model_type mobilenetv2 --n_epochs 50
```
Once training finishes, the saved checkpoint file can be found at:
```
logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
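If you want a quick, optional sanity check that the checkpoint loads (not part of the example scripts), something like the following should work, assuming the CIFAR-10 setting (10 classes) and the `MobileNetV2` class from the example's `mobilenet_v2.py`:

```python
# Optional sanity check; the path comes from the training step above.
import torch
from mobilenet_v2 import MobileNetV2   # model definition shipped with this example

net = MobileNetV2(n_class=10)          # CIFAR-10 has 10 classes
state = torch.load('logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth', map_location='cpu')
net.load_state_dict(state)             # the checkpoint is a plain state_dict
print('loaded', sum(p.numel() for p in net.parameters()), 'parameters')
```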
## Step 2: prune the trained model with AMCPruner
Run the following command to prune the trained model:
```bash
python3 amc_search.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
Once the search finishes, the pruned model and mask can be found at:
```
logs/mobilenetv2_cifar10_r0.5_search-run2
```
## Step 3: fine-tune the pruned model
Run `amc_train.py` again with `--ckpt` and `--mask` to speed up and fine-tune the pruned model:
```bash
python3 amc_train.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_r0.5_search-run2/best_model.pth --mask logs/mobilenetv2_cifar10_r0.5_search-run2/best_mask.pth --n_epoch 100
```
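Internally, `amc_train.py` first loads the pruned weights and then applies `ModelSpeedup` with the mask, so the network is physically shrunk before fine-tuning starts (see the `get_model` changes further down). A rough, illustrative equivalent, assuming the 32x32 CIFAR-10 input size used by `get_model()` and that `ModelSpeedup` is exposed under `nni.compression.torch`:

```python
# Illustrative equivalent of what --ckpt/--mask trigger inside amc_train.py.
import torch
from nni.compression.torch import ModelSpeedup
from mobilenet_v2 import MobileNetV2

run_dir = 'logs/mobilenetv2_cifar10_r0.5_search-run2'
net = MobileNetV2(n_class=10)
net.load_state_dict(torch.load(run_dir + '/best_model.pth', map_location='cpu'))

dummy_input = torch.randn(2, 3, 32, 32)   # CIFAR-10-sized batch, as in get_model()
ModelSpeedup(net, dummy_input, run_dir + '/best_mask.pth', torch.device('cpu')).speedup_model()
# `net` now has the smaller, pruned architecture and is ready to be fine-tuned.
```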
@@ -20,7 +20,7 @@ def parse_args():
                        help='model to prune')
    parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
    parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
-    parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
+    parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
    parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
    parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
    parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
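The `--data_root` default now matches the conventional torchvision layout. Assuming the example's data loader wraps `torchvision.datasets.CIFAR10` (an assumption, shown only for orientation), the dataset ends up under `./data` like this:

```python
# Assumption: the example uses torchvision's CIFAR-10; this only illustrates
# where the new --data_root default points.
from torchvision import datasets, transforms

train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                             transform=transforms.ToTensor())
print(len(train_set), 'training images downloaded under ./data')
```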
......
@@ -13,6 +13,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
+from torchvision.models import resnet
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
@@ -27,7 +28,9 @@ from mobilenet_v2 import MobileNetV2
def parse_args():
    parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
-    parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train')
+    parser.add_argument('--model_type', default='mobilenet', type=str,
+                        choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
+                        help='name of the model to train')
    parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
    parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
    parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
@@ -62,17 +65,21 @@ def get_model(args):
        net = MobileNet(n_class=n_class)
    elif args.model_type == 'mobilenetv2':
        net = MobileNetV2(n_class=n_class)
+    elif args.model_type.startswith('resnet'):
+        net = resnet.__dict__[args.model_type](pretrained=True)
+        in_features = net.fc.in_features
+        net.fc = nn.Linear(in_features, n_class)
    else:
        raise NotImplementedError

    if args.ckpt_path is not None:
        # the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
        print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
-        net.load_state_dict(torch.load(args.ckpt_path))
+        net.load_state_dict(torch.load(args.ckpt_path, torch.device('cpu')))
        if args.mask_path is not None:
            SZ = 224 if args.dataset == 'imagenet' else 32
            data = torch.randn(2, 3, SZ, SZ)
-            ms = ModelSpeedup(net, data, args.mask_path)
+            ms = ModelSpeedup(net, data, args.mask_path, torch.device('cpu'))
            ms.speedup_model()

    net.to(args.device)
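The new resnet branch reuses torchvision's ImageNet-pretrained weights and only swaps the classifier head. A standalone illustration of that pattern (the `resnet18` and 10-class values are just examples):

```python
# Standalone illustration of the classifier swap done in the resnet branch above.
import torch.nn as nn
from torchvision.models import resnet

net = resnet.__dict__['resnet18'](pretrained=True)   # same dict lookup the example uses
net.fc = nn.Linear(net.fc.in_features, 10)           # fresh 10-way head, e.g. for CIFAR-10
```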
@@ -179,11 +186,11 @@ def adjust_learning_rate(optimizer, epoch):
    return lr

def save_checkpoint(state, is_best, checkpoint_dir='.'):
-    filename = os.path.join(checkpoint_dir, 'ckpt.pth.tar')
+    filename = os.path.join(checkpoint_dir, 'ckpt.pth')
    print('=> Saving checkpoint to {}'.format(filename))
    torch.save(state, filename)
    if is_best:
-        shutil.copyfile(filename, filename.replace('.pth.tar', '.best.pth.tar'))
+        shutil.copyfile(filename, filename.replace('.pth', '.best.pth'))

if __name__ == '__main__':
    args = parse_args()
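For context, the renamed files line up with the README: a call like the one below (illustrative only; it assumes `save_checkpoint` from above and the example's `MobileNetV2` are in scope, and skips the real accuracy comparison) writes `ckpt.pth` and, because `is_best=True`, also `ckpt.best.pth`, the file the search step consumes.

```python
# Illustrative call; the real script passes the best-accuracy flag it tracks during training.
from mobilenet_v2 import MobileNetV2

net = MobileNetV2(n_class=10)
save_checkpoint(net.state_dict(), is_best=True,
                checkpoint_dir='logs/mobilenetv2_cifar10_train-run1')
```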
......