To simplify the process of writing new compression algorithms, we have designed a simple and flexible programming interface that covers both pruning and quantization. Below, we first demonstrate how to customize a new pruning algorithm and then how to customize a new quantization algorithm.
**Important Note:** To better understand how to customize new pruning/quantization algorithms, users should first understand the framework that supports various pruning algorithms in NNI; see [Framework overview of model compression](https://nni.readthedocs.io/en/latest/Compressor/Framework.html).
## Customize a new pruning algorithm
Implementing a new pruning algorithm requires implementing a `weight masker` class, which should be a subclass of `WeightMasker`, and a `pruner` class, which should be a subclass of `Pruner`.
An implementation of `weight masker` may look like this:
```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # You can do some initialization here, such as collecting some statistics data,
        # if it is necessary for your algorithm to calculate the masks.

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        # calculate the masks based on the wrapper.weight, and sparsity,
        # and anything else
        # mask = ...
        return {'weight_mask': mask}
```
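For example, a magnitude-based masker in the spirit of NNI's level pruner might fill in `calc_mask` like this. This is a hedged sketch: `MagnitudeMasker` is a hypothetical name, and it assumes the module wrapper exposes the original module as `wrapper.module`.

```python
import torch

class MagnitudeMasker(WeightMasker):
    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        weight = wrapper.module.weight.data
        # number of weights to prune in this layer
        num_prune = int(weight.numel() * sparsity)
        if num_prune == 0:
            return {'weight_mask': torch.ones_like(weight)}
        # keep the weights with the largest absolute values, mask out the rest
        threshold = torch.topk(weight.abs().view(-1), num_prune, largest=False)[0].max()
        mask = torch.gt(weight.abs(), threshold).type_as(weight)
        return {'weight_mask': mask}
```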
You can refer to the NNI-provided [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations when implementing your own weight masker.
Similarly, refer to the NNI-provided [pruner](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py) implementations when implementing your own pruner class.
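Putting the two pieces together, a one-shot style pruner built on `MyMasker` might look roughly like this. This is a hedged sketch: it assumes the `Pruner` base class provides `set_wrappers_attribute` and invokes `calc_mask(wrapper, ...)` for each wrapped layer, as the built-in one-shot pruners do.

```python
class MyPruner(Pruner):
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # add an `if_calculated` flag to every module wrapper so each layer is pruned only once
        self.set_wrappers_attribute("if_calculated", False)
        # construct the weight masker that actually computes the masks
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        sparsity = wrapper.config['sparsity']
        if wrapper.if_calculated:
            # this layer has already been pruned; a one-shot pruner does not prune it again
            return None
        masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
        if masks is not None:
            wrapper.if_calculated = True
        return masks
```

A user would then apply it in the usual way, e.g. `pruner = MyPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}], optimizer)` followed by `model = pruner.compress()`.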
***
## Customize a new quantization algorithm
To write a new quantization algorithm, you can write a class that inherits `nni.compression.torch.Quantizer` and override its member functions with the logic of your algorithm. The key member function to override is `quantize_weight`. `quantize_weight` directly returns the quantized weights rather than a mask, because for quantization the quantized weights cannot be obtained by simply applying a mask.
```python
from nni.compression.torch import Quantizer

class YourQuantizer(Quantizer):
    def __init__(self, model, config_list):
        """
        It is suggested to use the NNI-defined spec for config
        """
        super().__init__(model, config_list)

    def quantize_weight(self, weight, config, **kwargs):
        """
        Subclasses should overload this method to quantize weight tensors.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        weight : Tensor
            weight that needs to be quantized
        config : dict
            the configuration for weight quantization
        """

        # Put your code to generate `new_weight` here

        return new_weight

    def quantize_output(self, output, config, **kwargs):
        """
        Subclasses should overload this method to quantize output.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        output : Tensor
            output that needs to be quantized
        config : dict
            the configuration for output quantization
        """

        # Put your code to generate `new_output` here

        return new_output

    def quantize_input(self, *inputs, config, **kwargs):
        """
        Subclasses should overload this method to quantize input.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        inputs : Tensor
            inputs that need to be quantized
        config : dict
            the configuration for inputs quantization
        """

        # Put your code to generate `new_input` here

        return new_input

    def update_epoch(self, epoch_num):
        pass

    def step(self):
        """
        Can do some processing based on the model or weights bound
        in the func bind_model
        """
        pass
```
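For completeness, applying the quantizer follows the same `compress()` pattern as pruners do. The snippet below is a hedged usage sketch: `MyModel` is a placeholder, and the config keys assume NNI's quantization config spec with `quant_types`, `quant_bits`, and `op_types`.

```python
model = MyModel()  # placeholder for your PyTorch model

# quantize the weights of all Conv2d and Linear layers to 8 bits
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Conv2d', 'Linear']
}]

quantizer = YourQuantizer(model, config_list)
model = quantizer.compress()
```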
### Customize backward function
Sometimes it's necessary for a quantization operation to have a customized backward function, such as the [Straight-Through Estimator](https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste). Users can customize a backward function as follows:
```python
import torch
from nni.compression.torch.compressor import Quantizer, QuantGrad, QuantType

class ClipGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        """
        This method should be overridden by subclasses to provide a customized backward function;
        the default implementation is Straight-Through Estimator.

        Parameters
        ----------
        tensor : Tensor
            input of quantization operation
        grad_output : Tensor
            gradient of the output of quantization operation
        quant_type : QuantType
            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`;
            you can define different behavior for different types.

        Returns
        -------
        tensor
            gradient of the input of quantization operation
        """
        # for quantized output, set grad to zero if the absolute value of tensor is larger than 1
        if quant_type == QuantType.QUANT_OUTPUT:
            grad_output[torch.abs(tensor) > 1] = 0
        return grad_output

class YourQuantizer(Quantizer):
    def __init__(self, model, config_list):
        super().__init__(model, config_list)
        # set your customized backward function to overwrite the default backward function
        self.quant_grad = ClipGrad
```
If you do not customize `QuantGrad`, the default backward is Straight-Through Estimator.

## Multi-GPU support

On multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs. If buffers and parameters are updated in the `forward` method, an `in-place` update is needed to ensure the update is effective.
Since `calc_mask` is called in the `optimizer.step` method, which happens after the `forward` method and runs on only one GPU, it supports multi-GPU naturally.
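As a hedged illustration of the in-place requirement (the wrapper and buffer names below are hypothetical, not part of NNI), a wrapper's `forward` should update registered buffers in place rather than rebinding the attribute:

```python
import torch
import torch.nn as nn

class StatWrapper(nn.Module):
    """Hypothetical wrapper that tracks a statistic (the largest absolute activation seen so far)
    in a registered buffer for later use in mask calculation."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.register_buffer('running_max', torch.zeros(1))

    def forward(self, x):
        # rebinding the attribute (`self.running_max = ...`) would replace the registered
        # buffer with a plain tensor; per the note above, use an in-place update instead
        self.running_max.copy_(torch.max(self.running_max, x.detach().abs().max().reshape(1)))
        return self.module(x)
```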