Unverified Commit 41312de5 authored by Guoxin, committed by GitHub

Compression doc structure refactor (#2676)



* init sapruner

* separate sapruners from other one-shot pruners

* update

* fix model params issue

* make the process runnable

* show evaluation result in example

* sort the sparsities and scale it

* fix rescale issue

* fix scale issue; add pruning history

* record the actual total sparsity

* fix sparsity 0/1 problem

* revert useless modif

* revert useless modif

* fix 0 pruning weights problem

* save pruning history in csv file

* fix typo

* remove check perm in Makefile

* use os path

* save config list in json format

* update analyze py; update docker

* update

* update analyze

* update log info in compressor

* init NetAdapt Pruner

* refine examples

* update

* fine tune

* update

* fix quote issue

* add code for imagenet integrity

* update

* use datasets.ImageNet

* update

* update

* add channel pruning in SAPruner; refine example

* update net_adapt pruner; add dependency constraint in sapruner(beta)

* update

* update

* update

* fix zero division problem

* fix typo

* update

* fix naive issue of NetAdaptPruner

* fix data issue for no-dependency modules

* add cifar10 vgg16 example

* update

* update

* fix folder creation issue; change lr for vgg exp

* update

* add save model arg

* fix model copy issue

* init related weights calc

* update analyze file

* NetAdaptPruner: use fine-tuned weights after each iteration; fix modules_wrapper iteration issue

* consider channel/filter cross pruning

* NetAdapt: consider previous op when calc total sparsity

* update

* use customized vgg

* add performances comparison plt

* fix netadaptPruner mask copy issue

* add resnet18 example

* fix example issue

* update experiment data

* fix bool arg parsing issue

* update

* init ADMMPruner

* ADMMPruner: update

* ADMMPruner: finish v1.0

* ADMMPruner: refine

* update

* AutoCompress init

* AutoCompress: update

* AutoCompressPruner: fix issues:

* add test for auto pruners

* add doc for auto pruners

* fix link in md

* remove irrelevant files

* Clean code

* code clean

* fix pylint issue

* fix pylint issue

* rename admm & autoCompress param

* use abs link in doc

* reorder import to fix import issue: autocompress relies on speedup

* refine doc

* NetAdaptPruner: decay pruning step

* take changes from testing branch

* refine

* fix typo

* ADMMPruner: check base_algo together with config schema

* fix broken link

* doc refine

* ADMM:refine

* refine doc

* refine doc

* refine doc

* refine doc

* refine doc

* refine doc

* update

* update

* refactor AGP doc

* update

* fix optimizer issue

* fix comments: typo, rename AGP_Pruner

* fix torch.nn.Module issue; refine SA docstring

* fix typo
Co-authored-by: Yuge Zhang <scottyugochang@gmail.com>
parent cfda8c36
...@@ -84,7 +84,7 @@ config_list_agp = [{'initial_sparsity': 0, 'final_sparsity': conv0_sparsity,
                   {'initial_sparsity': 0, 'final_sparsity': conv1_sparsity,
                    'start_epoch': 0, 'end_epoch': 3,
                    'frequency': 1, 'op_name': 'conv1'},]
PRUNERS = {'level': LevelPruner(model, config_list_level), 'agp': AGPPruner(model, config_list_agp)}
pruner = PRUNERS[params['prune_method']['_name']]
pruner.compress()
... # fine tuning
......
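(A hypothetical follow-up to the snippet above, not part of this diff: after `compress()` and fine-tuning, the pruned weights and masks are typically exported; the file names here are placeholders.)

```python
# Sketch: export the fine-tuned weights together with the computed masks.
# 'model.pth' and 'mask.pth' are placeholder paths.
pruner.export_model(model_path='model.pth', mask_path='mask.pth')
```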
This diff is collapsed.
docs/img/agp_pruner.png (8.38 KB → 18.4 KB)
...@@ -22,7 +22,7 @@ configure_list = [{
    'frequency': 1,
    'op_types': ['default']
}]
pruner = AGPPruner(configure_list)
```
When ```pruner(model)``` is called, your model is injected with masks as embedded operations. For example, if a layer takes a weight as input, an operation is inserted between the weight and the layer; this operation takes the weight as input and outputs a new weight with the mask applied. The masks are thus applied whenever computation passes through these operations, so you can fine-tune your model **without** any modifications.
......
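To make the wrapper behaviour concrete, here is a hedged sketch of pruning and then fine-tuning with the renamed `AGPPruner`; `MyModel` and `train_one_epoch` are placeholders, and the config values are only illustrative.

```python
import torch
from nni.compression.torch import AGPPruner

model = MyModel()  # placeholder torch.nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

config_list = [{
    'initial_sparsity': 0.0,
    'final_sparsity': 0.8,
    'start_epoch': 0,
    'end_epoch': 9,
    'frequency': 1,
    'op_types': ['default'],
}]

pruner = AGPPruner(model, config_list, optimizer)
pruner.compress()

# Fine-tune as usual: the masks live inside the wrapped layers,
# so the training loop itself does not change.
for epoch in range(10):
    train_one_epoch(model, optimizer)   # placeholder training function
    pruner.update_epoch(epoch)          # advance AGP's sparsity schedule
```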
...@@ -10,7 +10,7 @@ from torchvision import datasets, transforms
from models.cifar10.vgg import VGG
import nni
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
    L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner
prune_config = {
    'level': {
...@@ -25,7 +25,7 @@ prune_config = {
    'agp': {
        'dataset_name': 'mnist',
        'model_name': 'naive',
        'pruner_class': AGPPruner,
        'config_list': [{
            'initial_sparsity': 0.,
            'final_sparsity': 0.8,
......
...@@ -6,17 +6,23 @@ import numpy as np
import tensorflow as tf
from .compressor import Pruner
__all__ = ['LevelPruner', 'AGPPruner', 'FPGMPruner']
_logger = logging.getLogger(__name__)
class LevelPruner(Pruner):
    """
    Parameters
    ----------
    model : tensorflow model
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The sparsity that the operations should be compressed to.
            - op_types : Operation types to prune.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        self.mask_list = {}
        self.if_init_list = {}
...@@ -34,24 +40,22 @@ class LevelPruner(Pruner):
        return mask
class AGPPruner(Pruner):
    """
    An automated gradual pruning algorithm that prunes the smallest-magnitude weights
    to achieve a preset level of network sparsity (Zhu and Gupta, 2017, https://arxiv.org/pdf/1710.01878.pdf).
    Parameters
    ----------
    model : tensorflow model
        Model to be pruned.
    config_list : list
        Supported keys:
            - initial_sparsity : The sparsity at which the compressor starts compressing.
            - final_sparsity : The sparsity at which the compressor finishes compressing (should be >= initial_sparsity).
            - start_epoch : The epoch at which the compressor starts compressing, 0 by default.
            - end_epoch : The epoch at which the compressor finishes compressing.
            - frequency : Compress once every *frequency* epochs, 1 by default.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        self.mask_list = {}
        self.if_init_list = {}
...@@ -102,23 +106,19 @@ class AGPPruner(Pruner):
        for k in self.if_init_list:
            self.if_init_list[k] = True
class FPGMPruner(Pruner):
    """
    A filter pruner via geometric median
    ("Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration", https://arxiv.org/pdf/1811.00250.pdf).
    Parameters
    ----------
    model : tensorflow model
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        self.mask_dict = {}
        self.assign_handler = []
......
...@@ -15,58 +15,50 @@ _logger = logging.getLogger(__name__)
class ADMMPruner(OneshotPruner):
    """
    A Pytorch implementation of the ADMM Pruner algorithm.
    Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique
    that decomposes the original nonconvex problem into two subproblems which can be solved iteratively.
    In the weight pruning problem, these two subproblems are solved via 1) gradient descent and 2) Euclidean projection, respectively.
    This solution framework applies both to non-structured and to different variations of structured pruning schemes.
    For more details, please refer to the paper: https://arxiv.org/abs/1804.03294.
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        List of pruning configs.
    trainer : function
        Function used for the first subproblem.
        Users should write this function as a normal function to train the Pytorch model
        and include `model, optimizer, criterion, epoch, callback` as function arguments.
        Here `callback` acts as an L2 regularizer, as presented in formula (7) of the original paper.
        The logic of `callback` is implemented inside the Pruner;
        users are only required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
        Example::
            def trainer(model, criterion, optimizer, epoch, callback):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                train_loader = ...
                model.train()
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    # callback should be inserted between loss.backward() and optimizer.step()
                    if callback:
                        callback()
                    optimizer.step()
    num_iterations : int
        Total number of iterations.
    training_epochs : int
        Training epochs of the first subproblem.
    row : float
        Penalty parameter for ADMM training.
    base_algo : str
        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    """
    def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
        self._base_algo = base_algo
        super().__init__(model, config_list)
...@@ -83,7 +75,7 @@ class ADMMPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
......
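For orientation, a minimal usage sketch of the class above; `model` and `trainer` are placeholders, `trainer` is assumed to follow the `(model, criterion, optimizer, epoch, callback)` signature from the docstring, and the sparsity value is only illustrative.

```python
from nni.compression.torch import ADMMPruner

config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]  # illustrative target

pruner = ADMMPruner(model, config_list, trainer=trainer,
                    num_iterations=30, training_epochs=5)
model = pruner.compress()
```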
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
An automated gradual pruning algorithm that prunes the smallest-magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices.
"""
import logging
import torch
from schema import And, Optional
...@@ -8,34 +16,31 @@ from .constants import MASKER_DICT
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner
__all__ = ['AGPPruner']
logger = logging.getLogger('torch pruner')
class AGPPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        Supported keys:
            - initial_sparsity : The sparsity at which the compressor starts compressing.
            - final_sparsity : The sparsity at which the compressor finishes compressing (should be >= initial_sparsity).
            - start_epoch : The epoch at which the compressor starts compressing, 0 by default.
            - end_epoch : The epoch at which the compressor finishes compressing.
            - frequency : Compress once every *frequency* epochs, 1 by default.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model.
    pruning_algorithm : str
        Algorithm used to prune the model,
        chosen from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, `level` by default.
    """
    def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
        super().__init__(model, config_list, optimizer)
        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass the optimizer of the model to it"
        self.masker = MASKER_DICT[pruning_algorithm](model, self)
...@@ -47,7 +52,7 @@ class AGPPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
......
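For reference, the gradual sparsity schedule that AGP follows comes from Zhu and Gupta (2017); the sketch below illustrates the formula s_t = s_f + (s_i - s_f)(1 - progress)^3 and is not the exact NNI implementation (which also accounts for the update frequency). The helper name is hypothetical.

```python
def agp_target_sparsity(epoch, initial_sparsity=0.0, final_sparsity=0.8,
                        start_epoch=0, end_epoch=10):
    """Hypothetical helper: cubic sparsity schedule from the AGP paper."""
    if epoch < start_epoch:
        return initial_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3
```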
...@@ -14,7 +14,7 @@ def apply_compression_results(model, masks_file, map_location=None):
    Parameters
    ----------
    model : torch.nn.Module
        The model to be compressed
    masks_file : str
        The path of the mask file
......
...@@ -21,14 +21,83 @@ _logger = logging.getLogger(__name__)
class AutoCompressPruner(Pruner):
    """
    A Pytorch implementation of the AutoCompress pruning algorithm.
    In each round, AutoCompressPruner prunes the model by the same sparsity to achieve the overall target sparsity:
        1. Generate a sparsity distribution using SimulatedAnnealingPruner.
        2. Perform ADMM-based structured pruning to generate the pruning result for the next round.
           Here we use 'speedup' to perform real pruning.
    For more details, please refer to the paper: https://arxiv.org/abs/1907.03141.
    Parameters
    ----------
    model : pytorch model
        The model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target overall sparsity.
            - op_types : The operation type to prune.
    trainer : function
        Function used for the first subproblem of ADMM Pruner.
        Users should write this function as a normal function to train the Pytorch model
        and include `model, optimizer, criterion, epoch, callback` as function arguments.
        Here `callback` acts as an L2 regularizer, as presented in formula (7) of the original paper.
        The logic of `callback` is implemented inside the Pruner;
        users are only required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
        Example::
            def trainer(model, criterion, optimizer, epoch, callback):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                train_loader = ...
                model.train()
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    # callback should be inserted between loss.backward() and optimizer.step()
                    if callback:
                        callback()
                    optimizer.step()
    evaluator : function
        Function to evaluate the pruned model.
        This function should include `model` as the only parameter, and return a scalar value.
        Example::
            def evaluator(model):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                val_loader = ...
                model.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in val_loader:
                        data, target = data.to(device), target.to(device)
                        output = model(data)
                        # get the index of the max log-probability
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                accuracy = correct / len(val_loader.dataset)
                return accuracy
    dummy_input : pytorch tensor
        The dummy input for ```jit.trace```; users should put it on the right device before passing it in.
    num_iterations : int
        Number of overall iterations.
    optimize_mode : str
        Optimize mode, `maximize` or `minimize`, by default `maximize`.
    base_algo : str
        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    start_temperature : float
        Start temperature of the simulated annealing process.
    stop_temperature : float
        Stop temperature of the simulated annealing process.
    cool_down_rate : float
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude to the sparsities. The magnitude decreases with the current temperature.
    admm_num_iterations : int
        Number of iterations of ADMM Pruner.
    admm_training_epochs : int
        Training epochs of the first optimization subproblem of ADMMPruner.
    row : float
        Penalty parameter for ADMM training.
    experiment_data_dir : string
        Path to store temporary experiment data.
    """
    def __init__(self, model, config_list, trainer, evaluator, dummy_input,
...@@ -38,83 +107,6 @@ class AutoCompressPruner(Pruner):
                 # ADMM related
                 admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
                 experiment_data_dir='./'):
        # original model
        self._model_to_prune = model
        self._base_algo = base_algo
...@@ -147,7 +139,7 @@ class AutoCompressPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
......
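A hedged usage sketch of the class above; `model`, `trainer`, `evaluator` and `dummy_input` are placeholders that follow the signatures described in the docstring, and the sparsity and iteration counts are only illustrative.

```python
from nni.compression.torch import AutoCompressPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # illustrative overall target

pruner = AutoCompressPruner(
    model, config_list, trainer=trainer, evaluator=evaluator,
    dummy_input=dummy_input,            # e.g. a tensor already placed on the right device
    num_iterations=3,
    optimize_mode='maximize',
    base_algo='l1',
    experiment_data_dir='./experiment_data')
model = pruner.compress()
```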
...@@ -13,33 +13,22 @@ logger = logging.getLogger('torch pruner')
class LotteryTicketPruner(Pruner):
    """
    A Pytorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",
    following the NNI model compression interface.
        1. Randomly initialize a neural network f(x; theta_0) (where theta_0 follows D_theta).
        2. Train the network for j iterations, arriving at parameters theta_j.
        3. Prune p% of the parameters in theta_j, creating a mask m.
        4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x; m*theta_0).
        5. Repeat steps 2, 3, and 4.
    Parameters
    ----------
    model : pytorch model
        The model to be pruned
    config_list : list
        Supported keys:
            - prune_iterations : The number of rounds for the iterative pruning.
            - sparsity : The final sparsity when the compression is done.
    optimizer : pytorch optimizer
        The optimizer for the model
    lr_scheduler : pytorch lr scheduler
        The lr scheduler for the model, if used
    reset_weights : bool
        Whether to reset weights and optimizer at the beginning of each round.
    """
    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
        # save init weights and optimizer
        self.reset_weights = reset_weights
        if self.reset_weights:
...@@ -60,7 +49,7 @@ class LotteryTicketPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
......
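A rough usage sketch for the class above, following the iterative train-prune-reset loop described in the docstring; `model`, `optimizer` and `train_one_epoch` are placeholders, and the helper methods (`get_prune_iterations`, `prune_iteration_start`) may differ between NNI versions.

```python
from nni.compression.torch import LotteryTicketPruner

config_list = [{'prune_iterations': 5, 'sparsity': 0.8, 'op_types': ['default']}]
pruner = LotteryTicketPruner(model, config_list, optimizer)
pruner.compress()

# Each round prunes further and, if reset_weights is True,
# rewinds the surviving weights to their initial values.
for _ in pruner.get_prune_iterations():
    pruner.prune_iteration_start()
    for epoch in range(10):                  # illustrative epoch count
        train_one_epoch(model, optimizer)    # placeholder training loop
```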
...@@ -21,80 +21,69 @@ _logger = logging.getLogger(__name__)
class NetAdaptPruner(Pruner):
    """
    A Pytorch implementation of the NetAdapt compression algorithm.
    The pruning procedure can be described as follows:
        While Res_i > Bud:
            1. Con = Res_i - delta_Res
            2. For every layer:
                   Choose the number of filters to prune
                   Choose which filters to prune
                   Short-term fine-tune the pruned model
            3. Pick the best layer to prune
        Long-term fine-tune
    For the details of this algorithm, please refer to the paper: https://arxiv.org/abs/1804.03230
    Parameters
    ----------
    model : pytorch model
        The model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target overall sparsity.
            - op_types : The operation type to prune.
    short_term_fine_tuner : function
        Function to short-term fine-tune the masked model.
        This function should include `model` as the only parameter,
        and fine-tune the model for a short term after each pruning iteration.
        Example::
            def short_term_fine_tuner(model, epoch=3):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                train_loader = ...
                criterion = torch.nn.CrossEntropyLoss()
                optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
                model.train()
                for _ in range(epoch):
                    for batch_idx, (data, target) in enumerate(train_loader):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()
    evaluator : function
        Function to evaluate the masked model.
        This function should include `model` as the only parameter, and return a scalar value.
        Example::
            def evaluator(model):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                val_loader = ...
                model.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in val_loader:
                        data, target = data.to(device), target.to(device)
                        output = model(data)
                        # get the index of the max log-probability
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                accuracy = correct / len(val_loader.dataset)
                return accuracy
    optimize_mode : str
        Optimize mode, `maximize` or `minimize`, by default `maximize`.
    base_algo : str
        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    sparsity_per_iteration : float
        Sparsity to prune in each iteration.
    experiment_data_dir : str
        Path to save experiment data,
        including the config_list generated for the base pruning algorithm and the performance of the pruned model.
    """
    def __init__(self, model, config_list, short_term_fine_tuner, evaluator,
                 optimize_mode='maximize', base_algo='l1', sparsity_per_iteration=0.05, experiment_data_dir='./'):
        # models used for iterative pruning and evaluation
        self._model_to_prune = copy.deepcopy(model)
        self._base_algo = base_algo
...@@ -124,7 +113,7 @@ class NetAdaptPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
......
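A hedged usage sketch of the class above; `model`, `short_term_fine_tuner` and `evaluator` are placeholders matching the signatures in the docstring, and the sparsity values are only illustrative.

```python
from nni.compression.torch import NetAdaptPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # illustrative overall target

pruner = NetAdaptPruner(
    model, config_list,
    short_term_fine_tuner=short_term_fine_tuner,
    evaluator=evaluator,
    base_algo='l1',
    sparsity_per_iteration=0.05,
    experiment_data_dir='./experiment_data')
model = pruner.compress()
```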
...@@ -21,7 +21,7 @@ class OneshotPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
...@@ -41,7 +41,7 @@ class OneshotPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
...@@ -85,12 +85,32 @@ class OneshotPruner(Pruner):
        return None
class LevelPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The sparsity that the operations should be compressed to.
            - op_types : Operation types to prune.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='level')
class SlimPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The sparsity that the operations should be compressed to.
            - op_types : Only BatchNorm2d is supported in Slim Pruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='slim')
    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
...@@ -118,27 +138,87 @@ class _StructuredFilterPruner(OneshotPruner):
        schema.validate(config_list)
class L1FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The sparsity that the operations should be compressed to.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='l1')
class L2FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The sparsity that the operations should be compressed to.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='l2')
class FPGMPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The sparsity that the operations should be compressed to.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='fpgm')
class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
    """
    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer, statistics_batch_num=statistics_batch_num)
class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
    """
    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, \
                         activation=activation, statistics_batch_num=statistics_batch_num)
class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
    """
    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, \
                         activation=activation, statistics_batch_num=statistics_batch_num)
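To illustrate the new one-shot signatures above (no optimizer argument for `LevelPruner`, `SlimPruner`, `L1FilterPruner`, `L2FilterPruner` and `FPGMPruner`), a minimal sketch; `model` is a placeholder and the output paths are hypothetical.

```python
from nni.compression.torch import L1FilterPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # illustrative
pruner = L1FilterPruner(model, config_list)                # no optimizer needed after this refactor
model = pruner.compress()
pruner.export_model(model_path='pruned.pth', mask_path='mask.pth')  # placeholder paths
```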
...@@ -22,62 +22,56 @@ _logger = logging.getLogger(__name__)
class SimulatedAnnealingPruner(Pruner):
    """
    A Pytorch implementation of the Simulated Annealing compression algorithm.
    - Randomly initialize a pruning rate distribution (sparsities).
    - While current_temperature < stop_temperature:
        1. Generate a perturbation of the current distribution.
        2. Perform a fast evaluation on the perturbed distribution.
        3. Accept the perturbation according to the performance and probability; if not accepted, return to step 1.
        4. Cool down: current_temperature <- current_temperature * cool_down_rate.
    Parameters
    ----------
    model : pytorch model
        The model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target overall sparsity.
            - op_types : The operation type to prune.
    evaluator : function
        Function to evaluate the pruned model.
        This function should include `model` as the only parameter, and return a scalar value.
        Example::
            def evaluator(model):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                val_loader = ...
                model.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in val_loader:
                        data, target = data.to(device), target.to(device)
                        output = model(data)
                        # get the index of the max log-probability
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                accuracy = correct / len(val_loader.dataset)
                return accuracy
    optimize_mode : str
        Optimize mode, `maximize` or `minimize`, by default `maximize`.
    base_algo : str
        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    start_temperature : float
        Start temperature of the simulated annealing process.
    stop_temperature : float
        Stop temperature of the simulated annealing process.
    cool_down_rate : float
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude to the sparsities. The magnitude decreases with the current temperature.
    experiment_data_dir : string
        Path to save experiment data,
        including the config_list generated for the base pruning algorithm, the performance of the pruned model and the pruning history.
    """
    def __init__(self, model, config_list, evaluator, optimize_mode='maximize', base_algo='l1',
                 start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35, experiment_data_dir='./'):
        # original model
        self._model_to_prune = copy.deepcopy(model)
        self._base_algo = base_algo
...@@ -114,7 +108,7 @@ class SimulatedAnnealingPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List of pruning configs
......
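A hedged usage sketch of the class above; `model` and `evaluator` are placeholders matching the docstring, and the annealing parameters simply repeat the defaults.

```python
from nni.compression.torch import SimulatedAnnealingPruner

config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # illustrative overall target

pruner = SimulatedAnnealingPruner(
    model, config_list, evaluator=evaluator,
    optimize_mode='maximize', base_algo='l1',
    start_temperature=100, stop_temperature=20,
    cool_down_rate=0.9, perturbation_magnitude=0.35,
    experiment_data_dir='./experiment_data')
model = pruner.compress()
```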
...@@ -153,7 +153,7 @@ class QAT_Quantizer(Quantizer):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list of dict
        List of configurations
...@@ -179,7 +179,7 @@ class QAT_Quantizer(Quantizer):
    ----------
    bits : int
        quantization bits length
    op : torch.nn.Module
        target module
    real_val : float
        real value to be quantized
...@@ -271,7 +271,7 @@ class DoReFaQuantizer(Quantizer):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list of dict
        List of configurations
...@@ -322,7 +322,7 @@ class BNNQuantizer(Quantizer):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list of dict
        List of configurations
......
...@@ -88,9 +88,8 @@ class CompressorTestCase(TestCase):
    def test_torch_level_pruner(self):
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_compressor.LevelPruner(model, configure_list).compress()
    @tf2
    def test_tf_level_pruner(self):
...@@ -129,7 +128,7 @@ class CompressorTestCase(TestCase):
        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_compressor.FPGMPruner(model, config_list)
        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
...@@ -315,7 +314,7 @@ class CompressorTestCase(TestCase):
    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_compressor.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', \
             'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]
        bad_configs = [
...@@ -337,11 +336,10 @@ class CompressorTestCase(TestCase):
            ]
        ]
        model = TorchModel()
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    pruner_class(model, config_list)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
......
...@@ -8,7 +8,7 @@ import torch.nn.functional as F
import math
from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
    L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \
    TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, AutoCompressPruner
def validate_sparsity(wrapper, sparsity, bias=False):
...@@ -33,7 +33,7 @@ prune_config = {
        ]
    },
    'agp': {
        'pruner_class': AGPPruner,
        'config_list': [{
            'initial_sparsity': 0.,
            'final_sparsity': 0.8,
...@@ -192,7 +192,9 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'])
        elif pruner_name == 'autocompress':
            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x)
        elif pruner_name in ['level', 'slim', 'fpgm', 'l1', 'l2']:
            pruner = prune_config[pruner_name]['pruner_class'](model, config_list)
        elif pruner_name in ['agp', 'taylorfo', 'mean_activation', 'apoz']:
            pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
        pruner.compress()
...@@ -225,7 +227,7 @@ def test_agp(pruning_algorithm):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    config_list = prune_config['agp']['config_list']
    pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm)
    pruner.compress()
    x = torch.randn(2, 1, 28, 28)
......