"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ad4a68f0102a7bb4375129176fbb36c17b0bc539"
Unverified Commit d5036857 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix amc example (#2976)

parent 143ac285
...@@ -41,6 +41,7 @@ Pruning algorithms compress the original network by removing redundant weights o ...@@ -41,6 +41,7 @@ Pruning algorithms compress the original network by removing redundant weights o
| [NetAdapt Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#netadapt-pruner) | Automatically simplify a pretrained network to meet the resource budget by iterative pruning [Reference Paper](https://arxiv.org/abs/1804.03230) | | [NetAdapt Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#netadapt-pruner) | Automatically simplify a pretrained network to meet the resource budget by iterative pruning [Reference Paper](https://arxiv.org/abs/1804.03230) |
| [SimulatedAnnealing Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#simulatedannealing-pruner) | Automatic pruning with a guided heuristic search method, Simulated Annealing algorithm [Reference Paper](https://arxiv.org/abs/1907.03141) | | [SimulatedAnnealing Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#simulatedannealing-pruner) | Automatic pruning with a guided heuristic search method, Simulated Annealing algorithm [Reference Paper](https://arxiv.org/abs/1907.03141) |
| [AutoCompress Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#autocompress-pruner) | Automatic pruning by iteratively call SimulatedAnnealing Pruner and ADMM Pruner [Reference Paper](https://arxiv.org/abs/1907.03141) | | [AutoCompress Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#autocompress-pruner) | Automatic pruning by iteratively call SimulatedAnnealing Pruner and ADMM Pruner [Reference Paper](https://arxiv.org/abs/1907.03141) |
| [AMC Pruner](https://nni.readthedocs.io/en/latest/Compression/Pruner.html#amc-pruner) | AMC: AutoML for Model Compression and Acceleration on Mobile Devices [Reference Paper](https://arxiv.org/pdf/1802.03494.pdf) |
You can refer to this [benchmark](https://github.com/microsoft/nni/tree/master/docs/en_US/CommunitySharings/ModelCompressionComparison.md) for the performance of these pruners on some benchmark problems. You can refer to this [benchmark](https://github.com/microsoft/nni/tree/master/docs/en_US/CommunitySharings/ModelCompressionComparison.md) for the performance of these pruners on some benchmark problems.
......
...@@ -20,7 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a ...@@ -20,7 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
* [NetAdapt Pruner](#netadapt-pruner) * [NetAdapt Pruner](#netadapt-pruner)
* [SimulatedAnnealing Pruner](#simulatedannealing-pruner) * [SimulatedAnnealing Pruner](#simulatedannealing-pruner)
* [AutoCompress Pruner](#autocompress-pruner) * [AutoCompress Pruner](#autocompress-pruner)
* [AutoML for Model Compression Pruner](#automl-for-model-compression-pruner) * [AMC Pruner](#amc-pruner)
* [Sensitivity Pruner](#sensitivity-pruner) * [Sensitivity Pruner](#sensitivity-pruner)
**Others** **Others**
...@@ -495,9 +495,9 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod ...@@ -495,9 +495,9 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AutoCompressPruner .. autoclass:: nni.compression.torch.AutoCompressPruner
``` ```
## AutoML for Model Compression Pruner ## AMC Pruner
AutoML for Model Compression Pruner (AMCPruner) leverages reinforcement learning to provide the model compression policy. AMC pruner leverages reinforcement learning to provide the model compression policy.
This learning-based compression policy outperforms conventional rule-based compression policy by having higher compression ratio, This learning-based compression policy outperforms conventional rule-based compression policy by having higher compression ratio,
better preserving the accuracy and freeing human labor. better preserving the accuracy and freeing human labor.
......
# AMCPruner Example
This example shows us how to use AMCPruner example.
## Step 1: train a model for pruning
Run following command to train a mobilenetv2 model:
```bash
python3 amc_train.py --model_type mobilenetv2 --n_epochs 50
```
Once finished, saved checkpoint file can be found at:
```
logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
## Pruning with AMCPruner
Run following command to prune the trained model:
```bash
python3 amc_search.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_train-run1/ckpt.best.pth
```
Once finished, pruned model and mask can be found at:
```
logs/mobilenetv2_cifar10_r0.5_search-run2
```
## Finetune pruned model
Run `amc_train.py` again with `--ckpt` and `--mask` to speedup and finetune the pruned model:
```bash
python3 amc_train.py --model_type mobilenetv2 --ckpt logs/mobilenetv2_cifar10_r0.5_search-run2/best_model.pth --mask logs/mobilenetv2_cifar10_r0.5_search-run2/best_mask.pth --n_epoch 100
```
...@@ -20,7 +20,7 @@ def parse_args(): ...@@ -20,7 +20,7 @@ def parse_args():
help='model to prune') help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size') parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path') parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model') parser.add_argument('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity') parser.add_argument('--lbound', default=0.2, type=float, help='minimum sparsity')
parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity') parser.add_argument('--rbound', default=1., type=float, help='maximum sparsity')
......
...@@ -13,6 +13,7 @@ import torch ...@@ -13,6 +13,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from torchvision.models import resnet
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
...@@ -27,7 +28,9 @@ from mobilenet_v2 import MobileNetV2 ...@@ -27,7 +28,9 @@ from mobilenet_v2 import MobileNetV2
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script') parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train') parser.add_argument('--model_type', default='mobilenet', type=str,
choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train') parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate') parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use') parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
...@@ -62,17 +65,21 @@ def get_model(args): ...@@ -62,17 +65,21 @@ def get_model(args):
net = MobileNet(n_class=n_class) net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2': elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class) net = MobileNetV2(n_class=n_class)
elif args.model_type.startswith('resnet'):
net = resnet.__dict__[args.model_type](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else: else:
raise NotImplementedError raise NotImplementedError
if args.ckpt_path is not None: if args.ckpt_path is not None:
# the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py # the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path)) print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
net.load_state_dict(torch.load(args.ckpt_path)) net.load_state_dict(torch.load(args.ckpt_path, torch.device('cpu')))
if args.mask_path is not None: if args.mask_path is not None:
SZ = 224 if args.dataset == 'imagenet' else 32 SZ = 224 if args.dataset == 'imagenet' else 32
data = torch.randn(2, 3, SZ, SZ) data = torch.randn(2, 3, SZ, SZ)
ms = ModelSpeedup(net, data, args.mask_path) ms = ModelSpeedup(net, data, args.mask_path, torch.device('cpu'))
ms.speedup_model() ms.speedup_model()
net.to(args.device) net.to(args.device)
...@@ -179,11 +186,11 @@ def adjust_learning_rate(optimizer, epoch): ...@@ -179,11 +186,11 @@ def adjust_learning_rate(optimizer, epoch):
return lr return lr
def save_checkpoint(state, is_best, checkpoint_dir='.'): def save_checkpoint(state, is_best, checkpoint_dir='.'):
filename = os.path.join(checkpoint_dir, 'ckpt.pth.tar') filename = os.path.join(checkpoint_dir, 'ckpt.pth')
print('=> Saving checkpoint to {}'.format(filename)) print('=> Saving checkpoint to {}'.format(filename))
torch.save(state, filename) torch.save(state, filename)
if is_best: if is_best:
shutil.copyfile(filename, filename.replace('.pth.tar', '.best.pth.tar')) shutil.copyfile(filename, filename.replace('.pth', '.best.pth'))
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment