Unverified Commit 92f6754e authored by colorjam, committed by GitHub

[Model Compression] Update api of iterative pruners (#3507)

parent 26f47727
@@ -34,7 +34,7 @@ Weight Masker

.. autoclass:: nni.algorithms.compression.pytorch.pruning.weight_masker.WeightMasker
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.structured_pruning.StructuredWeightMasker
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.structured_pruning_masker.StructuredWeightMasker
   :members:

@@ -43,40 +43,40 @@ Pruners

.. autoclass:: nni.algorithms.compression.pytorch.pruning.sensitivity_pruner.SensitivityPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.OneshotPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.OneshotPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.LevelPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.LevelPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.SlimPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.L1FilterPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.L1FilterPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.L2FilterPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.L2FilterPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.FPGMPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.FPGMPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.IterativePruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.TaylorFOWeightFilterPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.SlimPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.ActivationAPoZRankFilterPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.TaylorFOWeightFilterPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.ActivationMeanRankFilterPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ActivationAPoZRankFilterPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ActivationMeanRankFilterPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.agp.AGPPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.AGPPruner
   :members:

-.. autoclass:: nni.algorithms.compression.pytorch.pruning.admm_pruner.ADMMPruner
+.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ADMMPruner
   :members:

.. autoclass:: nni.algorithms.compression.pytorch.pruning.auto_compress_pruner.AutoCompressPruner

@@ -88,6 +88,9 @@ Pruners

.. autoclass:: nni.algorithms.compression.pytorch.pruning.simulated_annealing_pruner.SimulatedAnnealingPruner
   :members:

+.. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner
+   :members:

Quantizers
^^^^^^^^^^
@@ -28,7 +28,7 @@ An implementation of ``weight masker`` may look like this:

        # mask = ...
        return {'weight_mask': mask}

-You can reference NNI-provided :githublink:`weight masker <nni/algorithms/compression/pytorch/pruning/structured_pruning.py>` implementations to implement your own weight masker.
+You can reference NNI-provided :githublink:`weight masker <nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py>` implementations to implement your own weight masker.
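For concreteness, here is a hedged sketch of a complete custom masker built on the ``WeightMasker`` base class shown above; ``MyMasker`` and its magnitude criterion are illustrative, not part of NNI:

.. code-block:: python

    import torch
    from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker

    class MyMasker(WeightMasker):
        def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
            weight = wrapper.module.weight.data
            # number of weights to prune in this layer
            k = int(weight.numel() * sparsity)
            if k == 0:
                return {'weight_mask': torch.ones_like(weight)}
            # zero out the k weights with the smallest absolute value
            threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()
            mask = torch.gt(weight.abs(), threshold).type_as(weight)
            return {'weight_mask': mask}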
A basic ``pruner`` looks like this:

@@ -52,7 +52,7 @@ A basic ``pruner`` looks like this:

        wrapper.if_calculated = True
        return masks

-Reference NNI-provided :githublink:`pruner <nni/algorithms/compression/pytorch/pruning/one_shot.py>` implementations to implement your own pruner class.
+Reference NNI-provided :githublink:`pruner <nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py>` implementations to implement your own pruner class.
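A hedged sketch of such a pruner, wiring the masker above into the ``Pruner`` interface (``MyPruner`` is illustrative; the ``if_calculated`` flag and ``calc_mask`` hook follow the fragment above):

.. code-block:: python

    from nni.compression.pytorch.compressor import Pruner

    class MyPruner(Pruner):
        def __init__(self, model, config_list):
            super().__init__(model, config_list)
            self.masker = MyMasker(model, self)

        def calc_mask(self, wrapper, wrapper_idx=None):
            if wrapper.if_calculated:
                # one-shot pruner: the mask is computed only once per layer
                return None
            masks = self.masker.calc_mask(sparsity=wrapper.config['sparsity'],
                                          wrapper=wrapper, wrapper_idx=wrapper_idx)
            wrapper.if_calculated = True
            return masks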
----
@@ -14,10 +14,19 @@ NNI provides a model compression toolkit to help user compress and speed up thei

* Provide friendly and easy-to-use compression utilities for users to dive into the compression process and results.
* Concise interface for users to customize their own compression algorithms.

+Compression Pipeline
+--------------------
+
+.. image:: ../../img/compression_flow.jpg
+   :target: ../../img/compression_flow.jpg
+   :alt:
+
+The overall compression pipeline in NNI. For compressing a pretrained model, pruning and quantization can be used alone or in combination.
+
.. note::
    NNI compression algorithms only simulate compression; the NNI speedup tool truly compresses the model and reduces latency. To obtain a truly compact model, users should conduct `model speedup <./ModelSpeedup.rst>`__. The interface and APIs are unified for both PyTorch and TensorFlow; currently only the PyTorch version is supported, and the TensorFlow version will be supported in the future.
Supported Algorithms
--------------------

@@ -26,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms.

Pruning Algorithms
^^^^^^^^^^^^^^^^^^

-Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and address the over-fitting issue.
+Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue.

.. list-table::
   :header-rows: 1

@@ -96,6 +105,7 @@ Model Speedup

The final goal of model compression is to reduce inference latency and model size. However, existing model compression algorithms mainly use simulation to check the performance (e.g., accuracy) of the compressed model: pruning algorithms use masks, and quantization algorithms still store quantized values in float32. Given the output masks and quantization bits produced by those algorithms, NNI can truly speed up the model. The detailed tutorial of masked model speedup can be found `here <./ModelSpeedup.rst>`__, and the detailed tutorial of mixed-precision quantization model speedup can be found `here <./QuantizationSpeedup.rst>`__.
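A minimal sketch of the mask-then-speedup flow (the pruner choice, file paths, and input shape here are illustrative assumptions, following the examples later in this commit):

.. code-block:: python

    import torch
    from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
    from nni.compression.pytorch import ModelSpeedup

    pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
    model = pruner.compress()  # pruning is simulated: masks are applied, shapes unchanged
    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth')
    pruner._unwrap_model()     # restore plain modules before speedup

    dummy_input = torch.randn(1, 1, 28, 28)
    ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()  # truly shrinks the graph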
Compression Utilities
---------------------

@@ -110,7 +120,6 @@ NNI model compression leaves simple interface for users to customize a new compr

Reference and Feedback
----------------------

-* To `report a bug <https://github.com/microsoft/nni/issues/new?template=bug-report.rst>`__ for this feature in GitHub;
* To `file a feature or improvement request <https://github.com/microsoft/nni/issues/new?template=enhancement.rst>`__ for this feature in GitHub;
* To know more about `Feature Engineering with NNI <../FeatureEngineering/Overview.rst>`__\ ;
Supported Pruning Algorithms on NNI
===================================

-We provide several pruning algorithms that support fine-grained weight pruning and structural filter pruning. **Fine-grained Pruning** generally results in unstructured models, which need specialized hardware or software to speed up the sparse network. **Filter Pruning** achieves acceleration by removing the entire filter. Some pruning algorithms use one-shot method that prune weights at once based on an importance metric. Other pruning algorithms control the **pruning schedule** that prune weights during optimization, including some automatic pruning algorithms.
+We provide several pruning algorithms that support fine-grained weight pruning and structural filter pruning. **Fine-grained Pruning** generally results in unstructured models, which need specialized hardware or software to speed up the sparse network. **Filter Pruning** achieves acceleration by removing entire filters. Some pruning algorithms use a one-shot method that prunes weights at once based on an importance metric (the model usually needs finetuning afterwards to compensate for the loss of accuracy). Other pruning algorithms prune weights **iteratively** during optimization, controlling the pruning schedule; these include some automatic pruning algorithms. A usage sketch contrasting the two styles follows the lists below.

-**Fine-grained Pruning**
+**One-shot Pruning**

-* `Level Pruner <#level-pruner>`__
-
-**Filter Pruning**
+* `Level Pruner <#level-pruner>`__ (fine-grained pruning)

* `Slim Pruner <#slim-pruner>`__
* `FPGM Pruner <#fpgm-pruner>`__
* `L1Filter Pruner <#l1filter-pruner>`__

@@ -18,7 +14,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a

* `Activation Mean Rank Filter Pruner <#activationmeanrankfilter-pruner>`__
* `Taylor FO On Weight Pruner <#taylorfoweightfilter-pruner>`__

-**Pruning Schedule**
+**Iterative Pruning**

* `AGP Pruner <#agp-pruner>`__
* `NetAdapt Pruner <#netadapt-pruner>`__

@@ -26,10 +22,9 @@ We provide several pruning algorithms that support fine-grained weight pruning a

* `AutoCompress Pruner <#autocompress-pruner>`__
* `AMC Pruner <#amc-pruner>`__
* `Sensitivity Pruner <#sensitivity-pruner>`__
+* `ADMM Pruner <#admm-pruner>`__

**Others**

-* `ADMM Pruner <#admm-pruner>`__
* `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__
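The two styles differ in how ``compress()`` is driven. A hedged sketch (the model, data loaders, ``trainer``, ``optimizer``, and ``criterion`` are assumed to exist, as in the examples later in this commit):

.. code-block:: python

    from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, SlimPruner

    # one-shot: the mask is computed once from the trained weights
    pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
    model = pruner.compress()

    # iterative: the pruner drives training itself while sparsifying
    pruner = SlimPruner(model, [{'sparsity': 0.5, 'op_types': ['BatchNorm2d']}],
                        optimizer=optimizer, trainer=trainer, criterion=criterion,
                        sparsity_training_epochs=5)
    model = pruner.compress()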
Level Pruner
@@ -382,11 +377,7 @@ PyTorch code

from nni.algorithms.compression.pytorch.pruning import AGPPruner

config_list = [{
-    'initial_sparsity': 0,
-    'final_sparsity': 0.8,
-    'start_epoch': 0,
-    'end_epoch': 10,
-    'frequency': 1,
+    'sparsity': 0.8,
    'op_types': ['default']
}]
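Under the updated API the pruning schedule moves out of ``config_list`` and into the pruner's constructor. A hedged sketch of the new call (``trainer(model, optimizer, criterion, epoch)`` is assumed to run one training epoch, as in the NNI examples):

.. code-block:: python

    pruner = AGPPruner(model, config_list,
                       optimizer=optimizer, trainer=trainer, criterion=criterion,
                       num_iterations=10, epochs_per_iteration=1,
                       pruning_algorithm='level')
    model = pruner.compress()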
.pth
.tar.gz
data/
MNIST/
cifar-10-batches-py/
experiment_data/
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
NNI example for combined pruning and quantization to compress a model.
In this example, we show the compression process to first prune a model, then quantize the pruned model.
"""
import argparse
import os
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from models.mnist.naive import NaiveModel
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
def get_model_time_cost(model, dummy_input):
model.eval()
n_times = 100
time_list = []
for _ in range(n_times):
torch.cuda.synchronize()
tic = time.time()
_ = model(dummy_input)
torch.cuda.synchronize()
time_list.append(time.time()-tic)
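    # drop the first 10 timings as warm-up before averaging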
time_list = time_list[10:]
return sum(time_list) / len(time_list)
def train(args, model, device, train_loader, criterion, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break
def test(args, model, device, criterion, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
print('Test Loss: {:.6f} Accuracy: {}%\n'.format(
test_loss, acc))
return acc
def test_trt(engine, test_loader):
test_loss = 0
correct = 0
    time_elapsed = 0
    for data, target in test_loader:
        # engine.inference returns the batch output and its inference time;
        # avoid naming the result `time`, which would shadow the time module
        output, infer_time = engine.inference(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += infer_time
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
    print("Inference elapsed time (whole dataset): {}s".format(time_elapsed))
def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(args.experiment_data_dir, exist_ok=True)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=transform),
batch_size=64,)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=transform),
batch_size=1000)
# Step1. Model Pretraining
model = NaiveModel().to(device)
criterion = torch.nn.NLLLoss()
optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
flops, params, _ = count_flops_params(model, (1, 1, 28, 28), verbose=False)
if args.pretrained_model_dir is None:
        args.pretrained_model_dir = os.path.join(args.experiment_data_dir, 'pretrained.pth')
best_acc = 0
for epoch in range(args.pretrain_epochs):
train(args, model, device, train_loader, criterion, optimizer, epoch)
scheduler.step()
acc = test(args, model, device, criterion, test_loader)
if acc > best_acc:
best_acc = acc
state_dict = model.state_dict()
model.load_state_dict(state_dict)
torch.save(state_dict, args.pretrained_model_dir)
print(f'Model saved to {args.pretrained_model_dir}')
else:
state_dict = torch.load(args.pretrained_model_dir)
model.load_state_dict(state_dict)
best_acc = test(args, model, device, criterion, test_loader)
dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
time_cost = get_model_time_cost(model, dummy_input)
# 125.49 M, 0.85M, 93.29, 1.1012
print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}')
# Step2. Model Pruning
config_list = [{
'sparsity': args.sparsity,
'op_types': ['Conv2d']
}]
kw_args = {}
if args.dependency_aware:
dummy_input = torch.randn([1000, 1, 28, 28]).to(device)
print('Enable the dependency_aware mode')
# note that, not all pruners support the dependency_aware mode
kw_args['dependency_aware'] = True
kw_args['dummy_input'] = dummy_input
pruner = L1FilterPruner(model, config_list, **kw_args)
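    # compress() only simulates pruning with masks here; Step3 performs the real speedup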
model = pruner.compress()
pruner.get_pruned_weights()
mask_path = os.path.join(args.experiment_data_dir, 'mask.pth')
model_path = os.path.join(args.experiment_data_dir, 'pruned.pth')
pruner.export_model(model_path=model_path, mask_path=mask_path)
pruner._unwrap_model() # unwrap all modules to normal state
# Step3. Model Speedup
m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
m_speedup.speedup_model()
print('model after speedup', model)
flops, params, _ = count_flops_params(model, dummy_input, verbose=False)
acc = test(args, model, device, criterion, test_loader)
time_cost = get_model_time_cost(model, dummy_input)
print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {acc: .2f}, Time Cost: {time_cost}')
# Step4. Model Finetuning
optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
best_acc = 0
for epoch in range(args.finetune_epochs):
train(args, model, device, train_loader, criterion, optimizer, epoch)
scheduler.step()
acc = test(args, model, device, criterion, test_loader)
if acc > best_acc:
best_acc = acc
state_dict = model.state_dict()
model.load_state_dict(state_dict)
    save_path = os.path.join(args.experiment_data_dir, 'finetuned.pth')
torch.save(state_dict, save_path)
flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
time_cost = get_model_time_cost(model, dummy_input)
# FLOPs 28.48 M, #Params: 0.18M, Accuracy: 89.03, Time Cost: 1.03
print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}')
print(f'Model saved to {save_path}')
# Step5. Model Quantization via QAT
config_list = [{
'quant_types': ['weight', 'output'],
'quant_bits': {'weight': 8, 'output': 8},
'op_names': ['conv1']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu1']
}, {
'quant_types': ['weight', 'output'],
'quant_bits': {'weight': 8, 'output': 8},
'op_names': ['conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8},
'op_names': ['relu2']
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
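    # compress() wraps the configured layers with simulated quantization ops;
    # the actual quantization-aware training loop runs in Step6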
# Step6. Quantization Aware Training
best_acc = 0
for epoch in range(1):
train(args, model, device, train_loader, criterion, optimizer, epoch)
scheduler.step()
acc = test(args, model, device, criterion, test_loader)
if acc > best_acc:
best_acc = acc
state_dict = model.state_dict()
calibration_path = os.path.join(args.experiment_data_dir, 'calibration.pth')
calibration_config = quantizer.export_model(model_path, calibration_path)
print("calibration_config: ", calibration_config)
# Step7. Model Speedup
batch_size = 32
input_shape = (batch_size, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()
test_trt(engine, test_loader)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Example for model compression')
# dataset and model
# parser.add_argument('--dataset', type=str, default='mnist',
# help='dataset to use, mnist, cifar10 or imagenet')
# parser.add_argument('--data-dir', type=str, default='./data/',
# help='dataset directory')
parser.add_argument('--pretrained-model-dir', type=str, default=None,
help='path to pretrained model')
parser.add_argument('--pretrain-epochs', type=int, default=10,
help='number of epochs to pretrain the model')
parser.add_argument('--pretrain-lr', type=float, default=1.0,
help='learning rate to pretrain the model')
parser.add_argument('--experiment-data-dir', type=str, default='./experiment_data',
help='For saving output checkpoints')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
# parser.add_argument('--multi-gpu', action='store_true', default=False,
# help='run on mulitple gpus')
# parser.add_argument('--test-only', action='store_true', default=False,
# help='run test only')
# pruner
# parser.add_argument('--pruner', type=str, default='l1filter',
# choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
# 'fpgm', 'mean_activation', 'apoz', 'admm'],
# help='pruner to use')
parser.add_argument('--sparsity', type=float, default=0.5,
                        help='overall target sparsity')
parser.add_argument('--dependency-aware', action='store_true', default=False,
help='toggle dependency aware mode')
# finetuning
parser.add_argument('--finetune-epochs', type=int, default=5,
help='epochs to fine tune')
# parser.add_argument('--kd', action='store_true', default=False,
# help='quickly check a single pass')
# parser.add_argument('--kd_T', type=float, default=4,
# help='temperature for KD distillation')
# parser.add_argument('--finetune-lr', type=float, default=0.5,
# help='learning rate to finetune the model')
# speedup
# parser.add_argument('--speed-up', action='store_true', default=False,
# help='whether to speed-up the pruned model')
# parser.add_argument('--nni', action='store_true', default=False,
# help="whether to tune the pruners using NNi tuners")
args = parser.parse_args()
main(args)
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
class NaiveModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.max_pool1(x)
x = self.relu2(self.conv2(x))
x = self.max_pool2(x)
x = x.view(-1, x.size()[1:].numel())
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
@@ -12,7 +12,7 @@ from nni.algorithms.compression.pytorch.pruning import AMCPruner

from data import get_split_dataset
from utils import AverageMeter, accuracy

-sys.path.append('../models')
+sys.path.append('../../models')

def parse_args():
    parser = argparse.ArgumentParser(description='AMC search script')

@@ -22,7 +22,7 @@ from nni.compression.pytorch import ModelSpeedup

from data import get_dataset
from utils import AverageMeter, accuracy, progress_bar

-sys.path.append('../models')
+sys.path.append('../../models')

from mobilenet import MobileNet
from mobilenet_v2 import MobileNetV2
@@ -13,14 +13,16 @@ import torch

from torch.optim.lr_scheduler import StepLR, MultiStepLR
from torchvision import datasets, transforms

-from models.mnist.lenet import LeNet
-from models.cifar10.vgg import VGG
-from models.cifar10.resnet import ResNet18, ResNet50
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, L2FilterPruner, FPGMPruner
from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner, ADMMPruner, NetAdaptPruner, AutoCompressPruner
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.utils.counter import count_flops_params

+import sys
+sys.path.append('../models')
+from mnist.lenet import LeNet
+from cifar10.vgg import VGG
+from cifar10.resnet import ResNet18, ResNet50

def get_data(dataset, data_dir, batch_size, test_batch_size):
    '''

@@ -67,7 +69,7 @@ def get_data(dataset, data_dir, batch_size, test_batch_size):

    return train_loader, val_loader, criterion

-def train(args, model, device, train_loader, criterion, optimizer, epoch, callback=None):
+def train(args, model, device, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

@@ -75,9 +77,6 @@ def train(args, model, device, train_loader, criterion, optimizer, epoch, callba

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
-        # callback should be inserted between loss.backward() and optimizer.step()
-        if callback:
-            callback()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(

@@ -198,8 +197,8 @@ def main(args):

    for epoch in range(epochs):
        train(args, model, device, train_loader, criterion, optimizer, epoch)

-    def trainer(model, optimizer, criterion, epoch, callback):
-        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch, callback=callback)
+    def trainer(model, optimizer, criterion, epoch):
+        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

@@ -264,7 +263,7 @@ def main(args):

        }]
        else:
            raise ValueError('Example only implemented for LeNet.')
-        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, training_epochs=2)
+        pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, epochs_per_iteration=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model, config_list, evaluator=evaluator, base_algo=args.base_algo,

@@ -273,7 +272,7 @@ def main(args):

        pruner = AutoCompressPruner(
            model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input,
            num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo,
-            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_training_epochs=5,
+            cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_epochs_per_iteration=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError(
@@ -12,25 +12,24 @@ import logging

import argparse
import os
-import time
+import sys
import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from torchvision import datasets, transforms

-from models.mnist.lenet import LeNet
-from models.cifar10.vgg import VGG
+sys.path.append('../models')
+from mnist.lenet import LeNet
+from cifar10.vgg import VGG

from nni.compression.pytorch.utils.counter import count_flops_params
import nni
-from nni.compression.pytorch import apply_compression_results, ModelSpeedup
+from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import (
    LevelPruner,
    SlimPruner,
    FPGMPruner,
+    TaylorFOWeightFilterPruner,
    L1FilterPruner,
    L2FilterPruner,
    AGPPruner,

@@ -38,7 +37,6 @@ from nni.algorithms.compression.pytorch.pruning import (

    ActivationAPoZRankFilterPruner
)

_logger = logging.getLogger('mnist_example')
_logger.setLevel(logging.INFO)

@@ -50,7 +48,8 @@ str2pruner = {

    'agp': AGPPruner,
    'fpgm': FPGMPruner,
    'mean_activation': ActivationMeanRankFilterPruner,
-    'apoz': ActivationAPoZRankFilterPruner
+    'apoz': ActivationAPoZRankFilterPruner,
+    'taylorfo': TaylorFOWeightFilterPruner
}

def get_dummy_input(args, device):

@@ -60,53 +59,6 @@ def get_dummy_input(args, device):

        dummy_input = torch.randn([args.test_batch_size, 3, 32, 32]).to(device)
    return dummy_input
-def get_pruner(model, pruner_name, device, optimizer=None, dependency_aware=False):
-    pruner_cls = str2pruner[pruner_name]
-
-    if pruner_name == 'level':
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['default']
-        }]
-    elif pruner_name in ['l1filter', 'mean_activation', 'apoz']:
-        # Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
-        # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['Conv2d'],
-            'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
-        }]
-    elif pruner_name == 'slim':
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['BatchNorm2d'],
-        }]
-    elif pruner_name == 'agp':
-        config_list = [{
-            'initial_sparsity': 0.,
-            'final_sparsity': 0.8,
-            'start_epoch': 0,
-            'end_epoch': 10,
-            'frequency': 1,
-            'op_types': ['Conv2d']
-        }]
-    else:
-        config_list = [{
-            'sparsity': args.sparsity,
-            'op_types': ['Conv2d']
-        }]
-
-    kw_args = {}
-    if dependency_aware:
-        dummy_input = get_dummy_input(args, device)
-        print('Enable the dependency_aware mode')
-        # note that, not all pruners support the dependency_aware mode
-        kw_args['dependency_aware'] = True
-        kw_args['dummy_input'] = dummy_input
-
-    pruner = pruner_cls(model, config_list, optimizer, **kw_args)
-    return pruner
def get_data(dataset, data_dir, batch_size, test_batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {

@@ -174,7 +126,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite

        print('start pre-training...')
        best_acc = 0
        for epoch in range(args.pretrain_epochs):
-            train(args, model, device, train_loader, criterion, optimizer, epoch, sparse_bn=True if args.pruner == 'slim' else False)
+            train(args, model, device, train_loader, criterion, optimizer, epoch)
            scheduler.step()
            acc = test(args, model, device, criterion, test_loader)
            if acc > best_acc:

@@ -198,12 +150,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite

    print('Pretrained model acc:', best_acc)
    return model, optimizer, scheduler

-def updateBN(model):
-    for m in model.modules():
-        if isinstance(m, nn.BatchNorm2d):
-            m.weight.grad.data.add_(0.0001 * torch.sign(m.weight.data))
-
-def train(args, model, device, train_loader, criterion, optimizer, epoch, sparse_bn=False):
+def train(args, model, device, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

@@ -211,11 +158,6 @@ def train(args, model, device, train_loader, criterion, optimizer, epoch, sparse

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
-        if sparse_bn:
-            # L1 regularization on BN layer
-            updateBN(model)
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
@@ -256,64 +198,99 @@ def main(args):

    flops, params, results = count_flops_params(model, dummy_input)
    print(f"FLOPs: {flops}, params: {params}")

-    print('start pruning...')
+    print(f'start {args.pruner} pruning...')
+
+    def trainer(model, optimizer, criterion, epoch):
+        return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch)
+
+    pruner_cls = str2pruner[args.pruner]
+
+    kw_args = {}
+    config_list = [{
+        'sparsity': args.sparsity,
+        'op_types': ['Conv2d']
+    }]
+
+    if args.pruner == 'level':
+        config_list = [{
+            'sparsity': args.sparsity,
+            'op_types': ['default']
+        }]
+    else:
+        if args.dependency_aware:
+            dummy_input = get_dummy_input(args, device)
+            print('Enable the dependency_aware mode')
+            # note that, not all pruners support the dependency_aware mode
+            kw_args['dependency_aware'] = True
+            kw_args['dummy_input'] = dummy_input
+        if args.pruner not in ('l1filter', 'l2filter', 'fpgm'):
+            # these settings only work for training-aware pruners
+            kw_args['trainer'] = trainer
+            kw_args['optimizer'] = optimizer
+            kw_args['criterion'] = criterion
+
+        if args.pruner in ('slim', 'mean_activation', 'apoz', 'taylorfo'):
+            kw_args['sparsity_training_epochs'] = 5
+
+        if args.pruner == 'agp':
+            kw_args['pruning_algorithm'] = 'l1'
+            kw_args['num_iterations'] = 5
+            kw_args['epochs_per_iteration'] = 1
+
+        # Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
+        # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
+        if args.pruner == 'slim':
+            config_list = [{
+                'sparsity': args.sparsity,
+                'op_types': ['BatchNorm2d'],
+            }]
+        else:
+            config_list = [{
+                'sparsity': args.sparsity,
+                'op_types': ['Conv2d'],
+                'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
+            }]
+
+    pruner = pruner_cls(model, config_list, **kw_args)
+
+    # Pruner.compress() returns the masked model
+    model = pruner.compress()
+    pruner.get_pruned_weights()
+
+    # export the pruned model masks for model speedup
    model_path = os.path.join(args.experiment_data_dir, 'pruned_{}_{}_{}.pth'.format(
        args.model, args.dataset, args.pruner))
    mask_path = os.path.join(args.experiment_data_dir, 'mask_{}_{}_{}.pth'.format(
        args.model, args.dataset, args.pruner))
+    pruner.export_model(model_path=model_path, mask_path=mask_path)

-    pruner = get_pruner(model, args.pruner, device, optimizer, args.dependency_aware)
-    model = pruner.compress()
-
-    if args.multi_gpu and torch.cuda.device_count() > 1:
-        model = nn.DataParallel(model)

    if args.test_only:
        test(args, model, device, criterion, test_loader)
+    # Unwrap all modules to normal state
+    pruner._unwrap_model()
+    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
+    m_speedup.speedup_model()

+    print('start finetuning...')
    best_top1 = 0
+    save_path = os.path.join(args.experiment_data_dir, 'finetuned.pth')
    for epoch in range(args.fine_tune_epochs):
-        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(args, model, device, train_loader, criterion, optimizer, epoch)
        scheduler.step()
        top1 = test(args, model, device, criterion, test_loader)
        if top1 > best_top1:
            best_top1 = top1
-            # Export the best model, 'model_path' stores state_dict of the pruned model,
-            # mask_path stores mask_dict of the pruned model
-            pruner.export_model(model_path=model_path, mask_path=mask_path)
+            torch.save(model.state_dict(), save_path)
+
+    flops, params, results = count_flops_params(model, dummy_input)
+    print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_top1: .2f}')

    if args.nni:
        nni.report_final_result(best_top1)

-    if args.speed_up:
-        # reload the best checkpoint for speed-up
-        args.pretrained_model_dir = model_path
-        model, _, _ = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion)
-
-        model.eval()
-        apply_compression_results(model, mask_path, device)
-
-        # test model speed
-        start = time.time()
-        for _ in range(32):
-            use_mask_out = model(dummy_input)
-        print('elapsed time when use mask: ', time.time() - start)
-
-        m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
-        m_speedup.speedup_model()
-        flops, params, results = count_flops_params(model, dummy_input)
-        print(f"FLOPs: {flops}, params: {params}")
-
-        start = time.time()
-        for _ in range(32):
-            use_speedup_out = model(dummy_input)
-        print('elapsed time when use speedup: ', time.time() - start)
-
-        top1 = test(args, model, device, criterion, test_loader)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Example for model compression')

@@ -352,17 +329,13 @@ if __name__ == '__main__':

                        help='toggle dependency aware mode')
    parser.add_argument('--pruner', type=str, default='l1filter',
                        choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
-                                 'fpgm', 'mean_activation', 'apoz'],
+                                 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
                        help='pruner to use')

    # fine-tuning
    parser.add_argument('--fine-tune-epochs', type=int, default=160,
                        help='epochs to fine tune')
-
-    # speed-up
-    parser.add_argument('--speed-up', action='store_true', default=False,
-                        help='whether to speed-up the pruned model')

    parser.add_argument('--nni', action='store_true', default=False,
                        help="whether to tune the pruners using NNI tuners")
@@ -20,8 +20,11 @@ from nni.compression.pytorch import ModelSpeedup

from torch.optim.lr_scheduler import MultiStepLR, StepLR
from torchvision import datasets, transforms
from basic_pruners_torch import get_data

-from models.cifar10.vgg import VGG
-from models.mnist.lenet import LeNet
+import sys
+sys.path.append('../models')
+from cifar10.vgg import VGG
+from mnist.lenet import LeNet

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
@@ -20,7 +20,7 @@ class fc1(nn.Module):

    def __init__(self, num_classes=10):
        super(fc1, self).__init__()
        self.classifier = nn.Sequential(
-            nn.Linear(28*28, 300),
+            nn.Linear(28 * 28, 300),
            nn.ReLU(inplace=True),
            nn.Linear(300, 100),
            nn.ReLU(inplace=True),
@@ -5,8 +5,12 @@ import torch

import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

-from models.cifar10.vgg import VGG
-from models.mnist.lenet import LeNet
+import sys
+sys.path.append('../models')
+from cifar10.vgg import VGG
+from mnist.lenet import LeNet

from nni.compression.pytorch import apply_compression_results, ModelSpeedup

torch.manual_seed(0)