Unverified commit 41312de5, authored by Guoxin, committed by GitHub

Compression doc structure refactor (#2676)



* init sapruner

* seperate sapruners from other one-shot pruners

* update

* fix model params issue

* make the process runnable

* show evaluation result in example

* sort the sparsities and scale it

* fix rescale issue

* fix scale issue; add pruning history

* record the actual total sparsity

* fix sparsity 0/1 problem

* revert useless modif

* revert useless modif

* fix 0 pruning weights problem

* save pruning history in csv file

* fix typo

* remove check perm in Makefile

* use os path

* save config list in json format

* update analyze py; update docker

* update

* update analyze

* update log info in compressor

* init NetAdapt Pruner

* refine examples

* update

* fine tune

* update

* fix quote issue

* add code for imagenet  integrity

* update

* use datasets.ImageNet

* update

* update

* add channel pruning in SAPruner; refine example

* update net_adapt pruner; add dependency constraint in sapruner(beta)

* update

* update

* update

* fix zero division problem

* fix typo

* update

* fix naive issue of NetAdaptPruner

* fix data issue for no-dependency modules

* add cifar10 vgg16 examplel

* update

* update

* fix folder creation issue; change lr for vgg exp

* update

* add save model arg

* fix model copy issue

* init related weights calc

* update analyze file

* NetAdaptPruner: use fine-tuned weights after each iteration; fix modules_wrapper iteration issue

* consider channel/filter cross pruning

* NetAdapt: consider previous op when calc total sparsity

* update

* use customized vgg

* add performances comparison plt

* fix netadaptPruner mask copy issue

* add resnet18 example

* fix example issue

* update experiment data

* fix bool arg parsing issue

* update

* init ADMMPruner

* ADMMPruner: update

* ADMMPruner: finish v1.0

* ADMMPruner: refine

* update

* AutoCompress init

* AutoCompress: update

* AutoCompressPruner: fix issues:

* add test for auto pruners

* add doc for auto pruners

* fix link in md

* remove irrelevant files

* Clean code

* code clean

* fix pylint issue

* fix pylint issue

* rename admm & autoCompress param

* use abs link in doc

* reorder import to fix import issue: autocompress relies on speedup

* refine doc

* NetAdaptPruner: decay pruning step

* take changes from testing branch

* refine

* fix typo

* ADMMPruenr: check base_algo together with config schema

* fix broken link

* doc refine

* ADMM:refine

* refine doc

* refine doc

* refince doc

* refine doc

* refine doc

* refine doc

* update

* update

* refactor AGP doc

* update

* fix optimizer issue

* fix comments: typo, rename AGP_Pruner

* fix torch.nn.Module issue; refine SA docstring

* fix typo
Co-authored-by: Yuge Zhang <scottyugochang@gmail.com>
parent cfda8c36
@@ -84,7 +84,7 @@ config_list_agp = [{'initial_sparsity': 0, 'final_sparsity': conv0_sparsity,
 {'initial_sparsity': 0, 'final_sparsity': conv1_sparsity,
  'start_epoch': 0, 'end_epoch': 3,
  'frequency': 1, 'op_name': 'conv1' },]
-PRUNERS = {'level':LevelPruner(model, config_list_level), 'agp':AGP_Pruner(model, config_list_agp)}
+PRUNERS = {'level':LevelPruner(model, config_list_level), 'agp':AGPPruner(model, config_list_agp)}
 pruner = PRUNERS(params['prune_method']['_name'])
 pruner.compress()
 ... # fine tuning
...
(diff collapsed)
docs/img/agp_pruner.png (image replaced: 8.38 KB → 18.4 KB)
@@ -22,7 +22,7 @@ configure_list = [{
     'frequency': 1,
     'op_types': ['default']
 }]
-pruner = AGP_Pruner(configure_list)
+pruner = AGPPruner(configure_list)
 ```
 When ```pruner(model)``` is called, your model is injected with masks as embedded operations. For example, a layer takes a weight as input, we will insert an operation between the weight and the layer, this operation takes the weight as input and outputs a new weight applied by the mask. Thus, the masks are applied at any time the computation goes through the operations. You can fine-tune your model **without** any modifications.
...
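The mask-injection behavior described in the context line above can be sketched framework-free. This is a toy illustration, not NNI's actual wrapper class; the name `MaskedLayer` and its methods are invented here for exposition:

```python
class MaskedLayer:
    """Toy stand-in for a mask-wrapped layer: the weight is re-masked on every forward."""
    def __init__(self, weight):
        self.weight = list(weight)          # original (still trainable) weight
        self.mask = [1.0] * len(weight)     # 1.0 = keep, 0.0 = pruned

    def forward(self, x):
        # The mask is applied each time computation passes through the layer,
        # which is why fine-tuning needs no change to the training loop.
        effective = [w * m for w, m in zip(self.weight, self.mask)]
        return sum(xi * wi for xi, wi in zip(x, effective))

layer = MaskedLayer([0.5, -2.0, 0.1])
layer.mask = [1.0, 1.0, 0.0]                # prune the smallest-magnitude weight
print(layer.forward([1.0, 1.0, 1.0]))       # -1.5: the pruned weight contributes nothing
```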
@@ -10,7 +10,7 @@ from torchvision import datasets, transforms
 from models.cifar10.vgg import VGG
 import nni
 from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
-    L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner
+    L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner
 prune_config = {
     'level': {
@@ -25,7 +25,7 @@ prune_config = {
     'agp': {
         'dataset_name': 'mnist',
         'model_name': 'naive',
-        'pruner_class': AGP_Pruner,
+        'pruner_class': AGPPruner,
         'config_list': [{
             'initial_sparsity': 0.,
             'final_sparsity': 0.8,
...
@@ -6,17 +6,23 @@ import numpy as np
 import tensorflow as tf
 from .compressor import Pruner
-__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner']
+__all__ = ['LevelPruner', 'AGPPruner', 'FPGMPruner']
 _logger = logging.getLogger(__name__)
 class LevelPruner(Pruner):
-    def __init__(self, model, config_list):
-        """
-        config_list: supported keys:
-            - sparsity
-        """
+    """
+    Parameters
+    ----------
+    model : tensorflow model
+        Model to be pruned
+    config_list : list
+        Supported keys:
+        - sparsity : This is to specify the sparsity operations to be compressed to.
+        - op_types : Operation types to prune.
+    """
+    def __init__(self, model, config_list):
         super().__init__(model, config_list)
         self.mask_list = {}
         self.if_init_list = {}
@@ -34,24 +40,22 @@ class LevelPruner(Pruner):
         return mask
-class AGP_Pruner(Pruner):
-    """An automated gradual pruning algorithm that prunes the smallest magnitude
-    weights to achieve a preset level of network sparsity.
-    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
-    efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
-    Learning of Phones and other Consumer Devices,
-    https://arxiv.org/pdf/1710.01878.pdf
-    """
+class AGPPruner(Pruner):
+    """
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to be pruned.
+    config_list : list
+        Supported keys:
+        - initial_sparsity: This is to specify the sparsity when compressor starts to compress.
+        - final_sparsity: This is to specify the sparsity when compressor finishes to compress.
+        - start_epoch: This is to specify the epoch number when compressor starts to compress, default start from epoch 0.
+        - end_epoch: This is to specify the epoch number when compressor finishes to compress.
+        - frequency: This is to specify every *frequency* number epochs compressor compress once, default frequency=1.
+    """
     def __init__(self, model, config_list):
-        """
-        config_list: supported keys:
-            - initial_sparsity
-            - final_sparsity: you should make sure initial_sparsity <= final_sparsity
-            - start_epoch: start epoch numer begin update mask
-            - end_epoch: end epoch number stop update mask
-            - frequency: if you want update every 2 epoch, you can set it 2
-        """
         super().__init__(model, config_list)
         self.mask_list = {}
         self.if_init_list = {}
@@ -102,23 +106,19 @@ class AGP_Pruner(Pruner):
         for k in self.if_init_list:
             self.if_init_list[k] = True
-class FPGMPruner(Pruner):
-    """
-    A filter pruner via geometric median.
-    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
-    https://arxiv.org/pdf/1811.00250.pdf
-    """
-    def __init__(self, model, config_list):
-        """
-        Parameters
-        ----------
-        model : pytorch model
-            the model user wants to compress
-        config_list: list
-            support key for each list item:
-            - sparsity: percentage of convolutional filters to be pruned.
-        """
+class FPGMPruner(Pruner):
+    """
+    Parameters
+    ----------
+    model : tensorflow model
+        Model to be pruned
+    config_list : list
+        Supported keys:
+        - sparsity : percentage of convolutional filters to be pruned.
+        - op_types : Only Conv2d is supported in FPGM Pruner.
+    """
+    def __init__(self, model, config_list):
         super().__init__(model, config_list)
         self.mask_dict = {}
         self.assign_handler = []
...
@@ -15,24 +15,14 @@ _logger = logging.getLogger(__name__)
 class ADMMPruner(OneshotPruner):
     """
-    This is a Pytorch implementation of ADMM Pruner algorithm.
+    A Pytorch implementation of ADMM Pruner algorithm.
-    Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique,
-    by decomposing the original nonconvex problem into two subproblems that can be solved iteratively.
-    In weight pruning problem, these two subproblems are solved via 1) gradient descent algorithm and 2) Euclidean projection respectively.
-    This solution framework applies both to non-structured and different variations of structured pruning schemes.
-    For more details, please refer to the paper: https://arxiv.org/abs/1804.03294.
-    """
-    def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
-        """
     Parameters
     ----------
-    model : torch.nn.module
-        Model to be pruned
+    model : torch.nn.Module
+        Model to be pruned.
     config_list : list
-        List on pruning configs
+        List on pruning configs.
     trainer : function
         Function used for the first subproblem.
         Users should write this function as a normal function to train the Pytorch model
@@ -41,22 +31,21 @@ class ADMMPruner(OneshotPruner):
         The logic of `callback` is implemented inside the Pruner,
         users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
         Example::
-        >>> def trainer(model, criterion, optimizer, epoch, callback):
-        >>>     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        >>>     train_loader = ...
-        >>>     model.train()
-        >>>     for batch_idx, (data, target) in enumerate(train_loader):
-        >>>         data, target = data.to(device), target.to(device)
-        >>>         optimizer.zero_grad()
-        >>>         output = model(data)
-        >>>         loss = criterion(output, target)
-        >>>         loss.backward()
-        >>>         # callback should be inserted between loss.backward() and optimizer.step()
-        >>>         if callback:
-        >>>             callback()
-        >>>         optimizer.step()
+        ```
+        def trainer(model, criterion, optimizer, epoch, callback):
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            train_loader = ...
+            model.train()
+            for batch_idx, (data, target) in enumerate(train_loader):
+                data, target = data.to(device), target.to(device)
+                optimizer.zero_grad()
+                output = model(data)
+                loss = criterion(output, target)
+                loss.backward()
+                # callback should be inserted between loss.backward() and optimizer.step()
+                if callback:
+                    callback()
+                optimizer.step()
+        ```
     num_iterations : int
         Total number of iterations.
     training_epochs : int
@@ -66,7 +55,10 @@ class ADMMPruner(OneshotPruner):
     base_algo : str
         Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     """
+    def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'):
         self._base_algo = base_algo
         super().__init__(model, config_list)
@@ -83,7 +75,7 @@ class ADMMPruner(OneshotPruner):
     """
     Parameters
     ----------
-    model : torch.nn.module
+    model : torch.nn.Module
         Model to be pruned
     config_list : list
         List on pruning configs
...
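The description removed above mentions that ADMM's second subproblem is a Euclidean projection onto the sparsity constraint, which for unstructured pruning amounts to keeping the largest-magnitude weights and zeroing the rest. A minimal sketch of that projection step (illustrative only, not the pruner's internal code; the function name is invented here):

```python
def euclidean_projection(weights, sparsity):
    """Project a weight vector onto the sparsity constraint: zero out the
    `sparsity` fraction of entries with the smallest magnitude."""
    n_zero = int(len(weights) * sparsity)
    # indices of the smallest-magnitude entries
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_zero]:
        pruned[i] = 0.0
    return pruned

print(euclidean_projection([0.9, -0.1, 0.5, -0.7], 0.5))  # [0.9, 0.0, 0.0, -0.7]
```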
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+"""
+An automated gradual pruning algorithm that prunes the smallest magnitude
+weights to achieve a preset level of network sparsity.
+Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
+efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
+Learning of Phones and other Consumer Devices.
+"""
 import logging
 import torch
 from schema import And, Optional
@@ -8,34 +16,31 @@ from .constants import MASKER_DICT
 from ..utils.config_validation import CompressorSchema
 from ..compressor import Pruner
-__all__ = ['AGP_Pruner']
+__all__ = ['AGPPruner']
 logger = logging.getLogger('torch pruner')
-class AGP_Pruner(Pruner):
-    """
-    An automated gradual pruning algorithm that prunes the smallest magnitude
-    weights to achieve a preset level of network sparsity.
-    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
-    efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
-    Learning of Phones and other Consumer Devices,
-    https://arxiv.org/pdf/1710.01878.pdf
-    """
-    def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
-        """
+class AGPPruner(Pruner):
+    """
     Parameters
     ----------
-    model : torch.nn.module
-        Model to be pruned
+    model : torch.nn.Module
+        Model to be pruned.
-    config_list : list
-        List on pruning configs
+    config_list : list
+        Supported keys:
+        - initial_sparsity: This is to specify the sparsity when compressor starts to compress.
+        - final_sparsity: This is to specify the sparsity when compressor finishes to compress.
+        - start_epoch: This is to specify the epoch number when compressor starts to compress, default start from epoch 0.
+        - end_epoch: This is to specify the epoch number when compressor finishes to compress.
+        - frequency: This is to specify every *frequency* number epochs compressor compress once, default frequency=1.
     optimizer: torch.optim.Optimizer
-        Optimizer used to train model
+        Optimizer used to train model.
     pruning_algorithm: str
-        algorithms being used to prune model
+        Algorithms being used to prune model,
+        choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level`
     """
+    def __init__(self, model, config_list, optimizer, pruning_algorithm='level'):
         super().__init__(model, config_list, optimizer)
         assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
         self.masker = MASKER_DICT[pruning_algorithm](model, self)
@@ -47,7 +52,7 @@ class AGP_Pruner(Pruner):
     """
     Parameters
     ----------
-    model : torch.nn.module
+    model : torch.nn.Module
         Model to be pruned
     config_list : list
         List on pruning configs
...
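The keys documented above (`initial_sparsity`, `final_sparsity`, `start_epoch`, `end_epoch`, `frequency`) drive the cubic sparsity schedule from the Zhu & Gupta paper cited in the new module docstring. A standalone sketch of that schedule, assuming the standard formula from the paper rather than NNI's exact implementation:

```python
def agp_sparsity(epoch, initial_sparsity=0.0, final_sparsity=0.8,
                 start_epoch=0, end_epoch=10, frequency=1):
    """Cubic schedule of Zhu & Gupta (2017): sparsity ramps quickly at first,
    then flattens as it approaches final_sparsity."""
    if epoch < start_epoch:
        return initial_sparsity
    if epoch >= end_epoch:
        return final_sparsity
    span = end_epoch - start_epoch
    # mask is only updated every `frequency` epochs, so quantize the progress
    progress = ((epoch - start_epoch) // frequency * frequency) / span
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3

print(agp_sparsity(0))    # 0.0  (schedule starts at initial_sparsity)
print(agp_sparsity(10))   # 0.8  (schedule ends at final_sparsity)
```

Note the schedule requires `initial_sparsity <= final_sparsity`, which the old docstring called out explicitly.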
@@ -14,7 +14,7 @@ def apply_compression_results(model, masks_file, map_location=None):
     Parameters
     ----------
-    model : torch.nn.module
+    model : torch.nn.Module
         The model to be compressed
     masks_file : str
         The path of the mask file
...
@@ -21,28 +21,12 @@ _logger = logging.getLogger(__name__)
 class AutoCompressPruner(Pruner):
     """
-    This is a Pytorch implementation of AutoCompress pruning algorithm.
+    A Pytorch implementation of AutoCompress pruning algorithm.
-    For each round, AutoCompressPruner prune the model for the same sparsity to achive the ovrall sparsity:
-        1. Generate sparsities distribution using SimualtedAnnealingPruner
-        2. Perform ADMM-based structured pruning to generate pruning result for the next round.
-           Here we use 'speedup' to perform real pruning.
-    For more details, please refer to the paper: https://arxiv.org/abs/1907.03141.
-    """
-    def __init__(self, model, config_list, trainer, evaluator, dummy_input,
-                 num_iterations=3, optimize_mode='maximize', base_algo='l1',
-                 # SimulatedAnnealing related
-                 start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
-                 # ADMM related
-                 admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
-                 experiment_data_dir='./'):
-        """
     Parameters
     ----------
     model : pytorch model
-        The model to be pruned
+        The model to be pruned.
     config_list : list
         Supported keys:
         - sparsity : The target overall sparsity.
@@ -55,66 +39,74 @@ class AutoCompressPruner(Pruner):
         The logic of `callback` is implemented inside the Pruner,
         users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`.
         Example::
-        >>> def trainer(model, criterion, optimizer, epoch, callback):
-        >>>     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        >>>     train_loader = ...
-        >>>     model.train()
-        >>>     for batch_idx, (data, target) in enumerate(train_loader):
-        >>>         data, target = data.to(device), target.to(device)
-        >>>         optimizer.zero_grad()
-        >>>         output = model(data)
-        >>>         loss = criterion(output, target)
-        >>>         loss.backward()
-        >>>         # callback should be inserted between loss.backward() and optimizer.step()
-        >>>         if callback:
-        >>>             callback()
-        >>>         optimizer.step()
+        ```
+        def trainer(model, criterion, optimizer, epoch, callback):
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            train_loader = ...
+            model.train()
+            for batch_idx, (data, target) in enumerate(train_loader):
+                data, target = data.to(device), target.to(device)
+                optimizer.zero_grad()
+                output = model(data)
+                loss = criterion(output, target)
+                loss.backward()
+                # callback should be inserted between loss.backward() and optimizer.step()
+                if callback:
+                    callback()
+                optimizer.step()
+        ```
     evaluator : function
         function to evaluate the pruned model.
         This function should include `model` as the only parameter, and returns a scalar value.
         Example::
-        >>> def evaluator(model):
-        >>>     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        >>>     val_loader = ...
-        >>>     model.eval()
-        >>>     correct = 0
-        >>>     with torch.no_grad():
-        >>>         for data, target in val_loader:
-        >>>             data, target = data.to(device), target.to(device)
-        >>>             output = model(data)
-        >>>             # get the index of the max log-probability
-        >>>             pred = output.argmax(dim=1, keepdim=True)
-        >>>             correct += pred.eq(target.view_as(pred)).sum().item()
-        >>>     accuracy = correct / len(val_loader.dataset)
-        >>>     return accuracy
+        def evaluator(model):
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            val_loader = ...
+            model.eval()
+            correct = 0
+            with torch.no_grad():
+                for data, target in val_loader:
+                    data, target = data.to(device), target.to(device)
+                    output = model(data)
+                    # get the index of the max log-probability
+                    pred = output.argmax(dim=1, keepdim=True)
+                    correct += pred.eq(target.view_as(pred)).sum().item()
+            accuracy = correct / len(val_loader.dataset)
+            return accuracy
     dummy_input : pytorch tensor
-        The dummy input for ```jit.trace```, users should put it on right device before pass in
+        The dummy input for ```jit.trace```, users should put it on right device before pass in.
     num_iterations : int
-        Number of overall iterations
+        Number of overall iterations.
     optimize_mode : str
-        optimize mode, `maximize` or `minimize`, by default `maximize`
+        optimize mode, `maximize` or `minimize`, by default `maximize`.
     base_algo : str
         Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
         the assigned `base_algo` is used to decide which filters/channels/weights to prune.
     start_temperature : float
-        Simualated Annealing related parameter
+        Start temperature of the simulated annealing process.
     stop_temperature : float
-        Simualated Annealing related parameter
+        Stop temperature of the simulated annealing process.
     cool_down_rate : float
-        Simualated Annealing related parameter
+        Cool down rate of the temperature.
     perturbation_magnitude : float
-        Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature
+        Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
     admm_num_iterations : int
-        Number of iterations of ADMM Pruner
+        Number of iterations of ADMM Pruner.
     admm_training_epochs : int
-        Training epochs of the first optimization subproblem of ADMMPruner
+        Training epochs of the first optimization subproblem of ADMMPruner.
     row : float
-        Penalty parameters for ADMM training
+        Penalty parameters for ADMM training.
     experiment_data_dir : string
-        PATH to store temporary experiment data
+        PATH to store temporary experiment data.
     """
+    def __init__(self, model, config_list, trainer, evaluator, dummy_input,
+                 num_iterations=3, optimize_mode='maximize', base_algo='l1',
+                 # SimulatedAnnealing related
+                 start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
+                 # ADMM related
+                 admm_num_iterations=30, admm_training_epochs=5, row=1e-4,
+                 experiment_data_dir='./'):
         # original model
         self._model_to_prune = model
         self._base_algo = base_algo
@@ -147,7 +139,7 @@ class AutoCompressPruner(Pruner):
     """
     Parameters
    ----------
-    model : torch.nn.module
+    model : torch.nn.Module
         Model to be pruned
     config_list : list
         List on pruning configs
...
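The simulated-annealing parameters documented above (`start_temperature`, `stop_temperature`, `cool_down_rate`, `perturbation_magnitude`) can be pictured with a standalone cooling loop. This is a sketch under the assumption of geometric cool-down; the proportional perturbation-decay rule below is illustrative, not NNI's exact formula:

```python
def cooling_schedule(start_temperature=100, stop_temperature=20, cool_down_rate=0.9):
    """Temperatures visited by the annealing loop: T <- T * cool_down_rate
    until T falls to stop_temperature."""
    temps = []
    t = start_temperature
    while t > stop_temperature:
        temps.append(t)
        t *= cool_down_rate
    return temps

def perturbation(temperature, start_temperature=100, magnitude=0.35):
    # The docstring says the magnitude decreases with current temperature;
    # a simple proportional decay is one plausible rule.
    return magnitude * temperature / start_temperature

temps = cooling_schedule()
print(len(temps))   # 16 cooling steps from 100 down past 20 at rate 0.9
```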
@@ -12,18 +12,6 @@ from .finegrained_pruning import LevelPrunerMasker
 logger = logging.getLogger('torch pruner')
 class LotteryTicketPruner(Pruner):
-    """
-    This is a Pytorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",
-    following NNI model compression interface.
-    1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
-    2. Train the network for j iterations, arriving at parameters theta_j.
-    3. Prune p% of the parameters in theta_j, creating a mask m.
-    4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
-    5. Repeat step 2, 3, and 4.
-    """
-    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
-        """
+    """
     Parameters
     ----------
@@ -40,6 +28,7 @@ class LotteryTicketPruner(Pruner):
     reset_weights : bool
         Whether reset weights and optimizer at the beginning of each round.
     """
+    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
         # save init weights and optimizer
         self.reset_weights = reset_weights
         if self.reset_weights:
@@ -60,7 +49,7 @@ class LotteryTicketPruner(Pruner):
     """
     Parameters
     ----------
-    model : torch.nn.module
+    model : torch.nn.Module
         Model to be pruned
     config_list : list
         Supported keys:
...
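Steps 3 and 4 of the docstring removed above (prune p% of the trained parameters by magnitude, then reset the survivors to their initial values theta_0) can be sketched in a few lines. This is a toy illustration on flat weight lists, not the pruner's internal code:

```python
def lottery_ticket_round(theta_0, theta_j, prune_fraction):
    """One round of lottery-ticket pruning: mask the smallest-magnitude
    trained weights theta_j, then reset survivors to their init values theta_0."""
    n_prune = int(len(theta_j) * prune_fraction)
    # indices of the smallest-magnitude trained weights
    order = sorted(range(len(theta_j)), key=lambda i: abs(theta_j[i]))
    pruned = set(order[:n_prune])
    mask = [0.0 if i in pruned else 1.0 for i in range(len(theta_j))]
    # the winning ticket is f(x; m * theta_0)
    winning_ticket = [m * w0 for m, w0 in zip(mask, theta_0)]
    return mask, winning_ticket

mask, ticket = lottery_ticket_round([0.1, -0.4, 0.3, 0.2], [0.05, -2.0, 1.5, 0.8], 0.25)
print(mask)    # [0.0, 1.0, 1.0, 1.0]
print(ticket)  # [0.0, -0.4, 0.3, 0.2]
```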
...@@ -21,28 +21,12 @@ _logger = logging.getLogger(__name__) ...@@ -21,28 +21,12 @@ _logger = logging.getLogger(__name__)
class NetAdaptPruner(Pruner): class NetAdaptPruner(Pruner):
""" """
This is a Pytorch implementation of NetAdapt compression algorithm. A Pytorch implementation of NetAdapt compression algorithm.
The pruning procedure can be described as follows:
While Res_i > Bud:
1. Con = Res_i - delta_Res
2. for every layer:
Choose Num Filters to prune
Choose which filter to prune
Short-term fine tune the pruned model
3. Pick the best layer to prune
Long-term fine tune
For the details of this algorithm, please refer to the paper: https://arxiv.org/abs/1804.03230
"""
def __init__(self, model, config_list, short_term_fine_tuner, evaluator,
optimize_mode='maximize', base_algo='l1', sparsity_per_iteration=0.05, experiment_data_dir='./'):
"""
Parameters Parameters
---------- ----------
model : pytorch model model : pytorch model
The model to be pruned The model to be pruned.
config_list : list config_list : list
Supported keys: Supported keys:
- sparsity : The target overall sparsity. - sparsity : The target overall sparsity.
...@@ -51,50 +35,55 @@ class NetAdaptPruner(Pruner): ...@@ -51,50 +35,55 @@ class NetAdaptPruner(Pruner):
function to short-term fine tune the masked model. function to short-term fine tune the masked model.
This function should include `model` as the only parameter, This function should include `model` as the only parameter,
and fine tune the model for a short term after each pruning iteration. and fine tune the model for a short term after each pruning iteration.
Example: Example::
>>> def short_term_fine_tuner(model, epoch=3):
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def short_term_fine_tuner(model, epoch=3):
>>> train_loader = ... device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> criterion = torch.nn.CrossEntropyLoss() train_loader = ...
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01) criterion = torch.nn.CrossEntropyLoss()
>>> model.train() optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> for _ in range(epoch): model.train()
>>> for _, (data, target) in enumerate(train_loader): for _ in range(epoch):
>>> data, target = data.to(device), target.to(device) for batch_idx, (data, target) in enumerate(train_loader):
>>> optimizer.zero_grad() data, target = data.to(device), target.to(device)
>>> output = model(data) optimizer.zero_grad()
>>> loss = criterion(output, target) output = model(data)
>>> loss.backward() loss = criterion(output, target)
>>> optimizer.step() loss.backward()
optimizer.step()
evaluator : function evaluator : function
function to evaluate the masked model. function to evaluate the masked model.
This function should include `model` as the only parameter, and returns a scalar value. This function should include `model` as the only parameter, and returns a scalar value.
Example:: Example::
>>> def evaluator(model):
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def evaluator(model):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                val_loader = ...
                model.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in val_loader:
                        data, target = data.to(device), target.to(device)
                        output = model(data)
                        # get the index of the max log-probability
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                accuracy = correct / len(val_loader.dataset)
                return accuracy
    optimize_mode : str
        Optimize mode, `maximize` or `minimize`, by default `maximize`.
    base_algo : str
        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    sparsity_per_iteration : float
        Sparsity to prune in each iteration.
    experiment_data_dir : str
        PATH to save experiment data,
        including the config_list generated for the base pruning algorithm and the performance of the pruned model.
    """

    def __init__(self, model, config_list, short_term_fine_tuner, evaluator,
                 optimize_mode='maximize', base_algo='l1', sparsity_per_iteration=0.05, experiment_data_dir='./'):
        # models used for iterative pruning and evaluation
        self._model_to_prune = copy.deepcopy(model)
        self._base_algo = base_algo
@@ -124,7 +113,7 @@ class NetAdaptPruner(Pruner):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
@@ -21,7 +21,7 @@ class OneshotPruner(Pruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        List on pruning configs
@@ -41,7 +41,7 @@ class OneshotPruner(Pruner):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
@@ -85,12 +85,32 @@ class OneshotPruner(Pruner):
        return None
class LevelPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Operation types to prune.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='level')
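The level criterion can be sketched without NNI or PyTorch: mask out the `sparsity` fraction of smallest-magnitude weights. The helper `level_mask` and its numbers below are made up for illustration and are not the pruner's actual implementation.

```python
def level_mask(weights, sparsity):
    """Return a 0/1 mask that zeroes the `sparsity` fraction of smallest-magnitude weights."""
    k = int(len(weights) * sparsity)  # number of weights to prune
    if k == 0:
        return [1] * len(weights)
    # magnitude threshold below (or at) which weights are masked
    threshold = sorted(abs(w) for w in weights)[k - 1]
    return [0 if abs(w) <= threshold else 1 for w in weights]

weights = [0.5, -0.1, 0.9, 0.05, -0.7, 0.2, -0.3, 0.8]
print(level_mask(weights, sparsity=0.5))  # → [1, 0, 1, 0, 1, 0, 0, 1]
```

Note that ties at the threshold can mask slightly more than the requested fraction; the real pruner handles such edge cases internally.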
class SlimPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only BatchNorm2d is supported in Slim Pruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='slim')
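Slim operates on BatchNorm2d layers: channels whose BN scaling factors (gamma) have the smallest magnitude are selected for pruning. A rough, dependency-free sketch of that selection rule (the function name and numbers are illustrative, not NNI's code):

```python
def slim_channel_mask(bn_gammas, sparsity):
    """Mask the `sparsity` fraction of channels with the smallest |gamma|."""
    k = int(len(bn_gammas) * sparsity)
    # channel indices ordered by scaling-factor magnitude, smallest first
    order = sorted(range(len(bn_gammas)), key=lambda i: abs(bn_gammas[i]))
    pruned = set(order[:k])
    return [0 if i in pruned else 1 for i in range(len(bn_gammas))]

gammas = [0.9, 0.01, 0.4, 0.03, 0.7, 0.02]
print(slim_channel_mask(gammas, 0.5))  # → [1, 0, 1, 0, 1, 0]
```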
    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
@@ -118,27 +138,87 @@ class _StructuredFilterPruner(OneshotPruner):
        schema.validate(config_list)
class L1FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='l1')
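The L1 criterion scores each Conv2d output filter by the L1 norm of its weights and prunes the lowest-scoring ones. A toy sketch with flat weight lists standing in for filter tensors (names and numbers are made up for illustration):

```python
def l1_filter_mask(filters, sparsity):
    """`filters` is a list of flat weight lists, one per output filter."""
    scores = [sum(abs(w) for w in f) for f in filters]  # L1 norm per filter
    k = int(len(filters) * sparsity)
    # prune the k filters with the smallest L1 norm
    pruned = set(sorted(range(len(filters)), key=scores.__getitem__)[:k])
    return [0 if i in pruned else 1 for i in range(len(filters))]

filters = [[0.2, -0.1], [0.9, 0.8], [-0.05, 0.05], [0.6, -0.4]]
print(l1_filter_mask(filters, 0.5))  # → [0, 1, 0, 1]
```

L2FilterPruner below follows the same shape with a squared-sum score instead of the absolute sum.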
class L2FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='l2')
class FPGMPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : This is to specify the sparsity operations to be compressed to.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    """
    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='fpgm')
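FPGM differs from norm-based ranking: filters whose total distance to all other filters in the layer is smallest (i.e. closest to the geometric median) are treated as the most replaceable and pruned first. A dependency-free toy sketch of that criterion (function name and numbers are made up):

```python
import math

def fpgm_mask(filters, sparsity):
    """Prune filters nearest the geometric median (smallest total distance)."""
    def dist(a, b):
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    # sum of Euclidean distances from each filter to every other filter
    scores = [sum(dist(f, g) for g in filters) for f in filters]
    k = int(len(filters) * sparsity)
    pruned = set(sorted(range(len(filters)), key=scores.__getitem__)[:k])
    return [0 if i in pruned else 1 for i in range(len(filters))]

filters = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [-4.0, 3.0]]
print(fpgm_mask(filters, 0.25))  # → [0, 1, 1, 1]
```

Note how the pruned filter here is one with a *small* norm-neighbourhood, not a small norm: the two near-origin filters are nearly interchangeable, so one of them goes.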
class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : How much percentage of convolutional filters are to be pruned.
            - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
    """
    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer, statistics_batch_num=statistics_batch_num)
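This pruner needs an `optimizer` because its criterion is gradient-based: a filter's importance is estimated from the first-order Taylor expansion of the loss with respect to its weights. The sketch below uses a simplified squared-sum form with toy numbers; it is an illustration of the idea, not NNI's exact computation:

```python
def taylor_fo_mask(weights, grads, sparsity):
    """First-order Taylor importance: (sum of weight * gradient)^2 per filter."""
    scores = [sum(w * g for w, g in zip(ws, gs)) ** 2
              for ws, gs in zip(weights, grads)]
    k = int(len(weights) * sparsity)
    # filters whose removal changes the loss least are pruned first
    pruned = set(sorted(range(len(scores)), key=scores.__getitem__)[:k])
    return [0 if i in pruned else 1 for i in range(len(scores))]

w = [[0.5, 0.5], [0.1, 0.0], [1.0, -1.0], [0.3, 0.2]]
g = [[0.2, 0.2], [0.9, 0.9], [0.5, 0.5], [0.0, 0.1]]
print(taylor_fo_mask(w, g, 0.5))  # → [1, 1, 0, 0]
```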
class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : How much percentage of convolutional filters are to be pruned.
            - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
    """
    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, \
            activation=activation, statistics_batch_num=statistics_batch_num)
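The APoZ criterion explains the `activation` argument: filters whose post-activation (e.g. post-ReLU) outputs are zero most of the time, measured as the Average Percentage of Zeros over `statistics_batch_num` batches, contribute least and are pruned first. A toy sketch with made-up activation data (not the pruner's implementation):

```python
def apoz_mask(activations, sparsity):
    """`activations` maps each filter to its post-ReLU outputs over a batch."""
    # fraction of zero activations per filter (APoZ)
    apoz = [sum(1 for a in acts if a == 0) / len(acts) for acts in activations]
    k = int(len(activations) * sparsity)
    # filters with the HIGHEST APoZ are pruned first
    pruned = set(sorted(range(len(apoz)), key=lambda i: -apoz[i])[:k])
    return [0 if i in pruned else 1 for i in range(len(apoz))]

acts = [[0, 0, 0, 1], [2, 3, 0, 1], [0, 0, 0, 0], [1, 1, 2, 0]]
print(apoz_mask(acts, 0.5))  # → [0, 1, 0, 1]
```

ActivationMeanRankFilterPruner below uses the same machinery but ranks filters by mean activation magnitude instead, pruning the smallest.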
class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : How much percentage of convolutional filters are to be pruned.
            - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
    """
    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, \
            activation=activation, statistics_batch_num=statistics_batch_num)
@@ -22,62 +22,56 @@ _logger = logging.getLogger(__name__)
class SimulatedAnnealingPruner(Pruner):
    """
    A PyTorch implementation of the Simulated Annealing compression algorithm.

    - Randomly initialize a pruning rate distribution (sparsities).
    - While current_temperature < stop_temperature:
        1. Generate a perturbation to the current distribution.
        2. Perform fast evaluation on the perturbed distribution.
        3. Accept the perturbation according to the performance and probability; if not accepted, return to step 1.
        4. Cool down, current_temperature <- current_temperature * cool_down_rate.

    Parameters
    ----------
    model : pytorch model
        The model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target overall sparsity.
            - op_types : The operation type to prune.
    evaluator : function
        Function to evaluate the pruned model.
        This function should include `model` as the only parameter, and returns a scalar value.
        Example::

            def evaluator(model):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                val_loader = ...
                model.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in val_loader:
                        data, target = data.to(device), target.to(device)
                        output = model(data)
                        # get the index of the max log-probability
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                accuracy = correct / len(val_loader.dataset)
                return accuracy
    optimize_mode : str
        Optimize mode, `maximize` or `minimize`, by default `maximize`.
    base_algo : str
        Base pruning algorithm. `level`, `l1` or `l2`, by default `l1`. Given the sparsity distribution among the ops,
        the assigned `base_algo` is used to decide which filters/channels/weights to prune.
    start_temperature : float
        Start temperature of the simulated annealing process.
    stop_temperature : float
        Stop temperature of the simulated annealing process.
    cool_down_rate : float
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
    experiment_data_dir : string
        PATH to save experiment data,
        including the config_list generated for the base pruning algorithm, the performance of the pruned model and the pruning history.
    """

    def __init__(self, model, config_list, evaluator, optimize_mode='maximize', base_algo='l1',
                 start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35, experiment_data_dir='./'):
        # original model
        self._model_to_prune = copy.deepcopy(model)
        self._base_algo = base_algo
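The annealing loop described in the docstring above can be sketched end-to-end with a stand-in evaluator. Everything here is a toy illustration under stated assumptions — the fake `evaluate`, the rescaling step that keeps the mean sparsity near the target, and the acceptance constant are all made up, not SimulatedAnnealingPruner's implementation:

```python
import math
import random

def evaluate(sparsities):
    # fake "accuracy": pruning layer 0 hurts less than pruning layer 1
    return 1.0 - 0.1 * sparsities[0] - 0.5 * sparsities[1]

def simulated_annealing(target=0.5, start_t=100.0, stop_t=20.0,
                        cool_down=0.9, magnitude=0.35, seed=0):
    rng = random.Random(seed)
    current = [target, target]  # initial per-layer sparsity distribution
    best, best_score = current, evaluate(current)
    t = start_t
    while t > stop_t:
        # 1. perturb the distribution; magnitude shrinks with temperature
        scale = magnitude * t / start_t
        candidate = [min(0.95, max(0.05, s + rng.uniform(-scale, scale)))
                     for s in current]
        # rescale so the overall sparsity stays on target
        mean = sum(candidate) / len(candidate)
        candidate = [min(0.95, s * target / mean) for s in candidate]
        # 2-3. evaluate, then accept by improvement or by probability
        score = evaluate(candidate)
        delta = score - evaluate(current)
        if delta > 0 or rng.random() < math.exp(delta / t * 1000):
            current = candidate
            if score > best_score:
                best, best_score = candidate, score
        # 4. cool down
        t *= cool_down
    return best, best_score

best, score = simulated_annealing()
print(best, score)  # layer 0 ends up more heavily pruned than layer 1
```

With this toy evaluator the search reliably shifts sparsity from the sensitive layer to the insensitive one, which is exactly the distribution the real pruner hands to `base_algo`.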
@@ -114,7 +108,7 @@ class SimulatedAnnealingPruner(Pruner):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List on pruning configs
@@ -153,7 +153,7 @@ class QAT_Quantizer(Quantizer):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list of dict
            List of configurations
@@ -179,7 +179,7 @@ class QAT_Quantizer(Quantizer):
        ----------
        bits : int
            quantization bits length
        op : torch.nn.Module
            target module
        real_val : float
            real value to be quantized
@@ -271,7 +271,7 @@ class DoReFaQuantizer(Quantizer):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list of dict
            List of configurations
@@ -322,7 +322,7 @@ class BNNQuantizer(Quantizer):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list of dict
            List of configurations
@@ -88,9 +88,8 @@ class CompressorTestCase(TestCase):
    def test_torch_level_pruner(self):
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_compressor.LevelPruner(model, configure_list).compress()
    @tf2
    def test_tf_level_pruner(self):
@@ -129,7 +128,7 @@ class CompressorTestCase(TestCase):
        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_compressor.FPGMPruner(model, config_list)
        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
@@ -315,7 +314,7 @@ class CompressorTestCase(TestCase):
    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_compressor.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', \
             'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

        bad_configs = [
@@ -337,11 +336,10 @@ class CompressorTestCase(TestCase):
            ]
        ]
        model = TorchModel()
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    pruner_class(model, config_list)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
@@ -8,7 +8,7 @@ import torch.nn.functional as F
import math
from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
    L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \
    TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, AutoCompressPruner

def validate_sparsity(wrapper, sparsity, bias=False):
@@ -33,7 +33,7 @@ prune_config = {
        ]
    },
    'agp': {
        'pruner_class': AGPPruner,
        'config_list': [{
            'initial_sparsity': 0.,
            'final_sparsity': 0.8,
@@ -192,7 +192,9 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
        pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'])
    elif pruner_name == 'autocompress':
        pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x)
    elif pruner_name in ['level', 'slim', 'fpgm', 'l1', 'l2']:
        pruner = prune_config[pruner_name]['pruner_class'](model, config_list)
    elif pruner_name in ['agp', 'taylorfo', 'mean_activation', 'apoz']:
        pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
    pruner.compress()
@@ -225,7 +227,7 @@ def test_agp(pruning_algorithm):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    config_list = prune_config['agp']['config_list']
    pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm)
    pruner.compress()

    x = torch.randn(2, 1, 28, 28)