"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "82791fc11494e30b26a8801beb52f65ec2dc6d18"
Unverified Commit 76086583 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix compressor op_types (#1670)

* fix compressor op_types
parent 7c4b8c0d
...@@ -8,12 +8,14 @@ You can easily compress a model with NNI compression. Take pruning for example, ...@@ -8,12 +8,14 @@ You can easily compress a model with NNI compression. Take pruning for example,
```python ```python
from nni.compression.torch import LevelPruner from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list) pruner = LevelPruner(config_list)
pruner(model) pruner(model)
``` ```
```{ 'sparsity': 0.8, 'op_types': 'default' }```means that **all layers with weight will be compressed with the same 0.8 sparsity**. When ```pruner(model)``` called, the model is compressed with masks and after that you can normally fine tune this model and **pruned weights won't be updated** which have been masked. The 'default' op_type stands for the module types defined in [default_layers.py](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/default_layers.py) for pytorch.
Therefore ```{ 'sparsity': 0.8, 'op_types': ['default'] }```means that **all layers with specified op_types will be compressed with the same 0.8 sparsity**. When ```pruner(model)``` called, the model is compressed with masks and after that you can normally fine tune this model and **pruned weights won't be updated** which have been masked.
## Then, make this automatic ## Then, make this automatic
......
...@@ -22,7 +22,7 @@ We use a simple example to show how to modify your trial code in order to apply ...@@ -22,7 +22,7 @@ We use a simple example to show how to modify your trial code in order to apply
Tensorflow code Tensorflow code
```python ```python
from nni.compression.tensorflow import LevelPruner from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list) pruner = LevelPruner(config_list)
pruner(tf.get_default_graph()) pruner(tf.get_default_graph())
``` ```
...@@ -30,7 +30,7 @@ pruner(tf.get_default_graph()) ...@@ -30,7 +30,7 @@ pruner(tf.get_default_graph())
PyTorch code PyTorch code
```python ```python
from nni.compression.torch import LevelPruner from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list) pruner = LevelPruner(config_list)
pruner(model) pruner(model)
``` ```
...@@ -58,7 +58,7 @@ A simple example of configuration is shown below: ...@@ -58,7 +58,7 @@ A simple example of configuration is shown below:
[ [
{ {
'sparsity': 0.8, 'sparsity': 0.8,
'op_types': 'default' 'op_types': ['default']
}, },
{ {
'sparsity': 0.6, 'sparsity': 0.6,
...@@ -115,7 +115,7 @@ class YourPruner(nni.compression.tensorflow.Pruner): ...@@ -115,7 +115,7 @@ class YourPruner(nni.compression.tensorflow.Pruner):
def calc_mask(self, weight, config, **kwargs): def calc_mask(self, weight, config, **kwargs):
# weight is the target weight tensor # weight is the target weight tensor
# config is the selected dict object in config_list for this layer # config is the selected dict object in config_list for this layer
# kwargs contains op, op_type, and op_name # kwargs contains op, op_types, and op_name
# design your mask and return your mask # design your mask and return your mask
return your_mask return your_mask
...@@ -158,7 +158,7 @@ class YourPruner(nni.compression.tensorflow.Quantizer): ...@@ -158,7 +158,7 @@ class YourPruner(nni.compression.tensorflow.Quantizer):
def quantize_weight(self, weight, config, **kwargs): def quantize_weight(self, weight, config, **kwargs):
# weight is the target weight tensor # weight is the target weight tensor
# config is the selected dict object in config_list for this layer # config is the selected dict object in config_list for this layer
# kwargs contains op, op_type, and op_name # kwargs contains op, op_types, and op_name
# design your quantizer and return new weight # design your quantizer and return new weight
return new_weight return new_weight
......
...@@ -12,7 +12,7 @@ We first sort the weights in the specified layer by their absolute values. And t ...@@ -12,7 +12,7 @@ We first sort the weights in the specified layer by their absolute values. And t
Tensorflow code Tensorflow code
``` ```
from nni.compression.tensorflow import LevelPruner from nni.compression.tensorflow import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list) pruner = LevelPruner(config_list)
pruner(model_graph) pruner(model_graph)
``` ```
...@@ -20,7 +20,7 @@ pruner(model_graph) ...@@ -20,7 +20,7 @@ pruner(model_graph)
PyTorch code PyTorch code
``` ```
from nni.compression.torch import LevelPruner from nni.compression.torch import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': 'default' }] config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(config_list) pruner = LevelPruner(config_list)
pruner(model) pruner(model)
``` ```
......
...@@ -31,14 +31,14 @@ You can quantize your model to 8 bits with the code below before your training c ...@@ -31,14 +31,14 @@ You can quantize your model to 8 bits with the code below before your training c
Tensorflow code Tensorflow code
```python ```python
from nni.compressors.tensorflow import QAT_Quantizer from nni.compressors.tensorflow import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': 'default' }] config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(config_list) quantizer = QAT_Quantizer(config_list)
quantizer(tf.get_default_graph()) quantizer(tf.get_default_graph())
``` ```
PyTorch code PyTorch code
```python ```python
from nni.compressors.torch import QAT_Quantizer from nni.compressors.torch import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': 'default' }] config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(config_list) quantizer = QAT_Quantizer(config_list)
quantizer(model) quantizer(model)
``` ```
......
...@@ -20,7 +20,7 @@ configure_list = [{ ...@@ -20,7 +20,7 @@ configure_list = [{
'start_epoch': 0, 'start_epoch': 0,
'end_epoch': 10, 'end_epoch': 10,
'frequency': 1, 'frequency': 1,
'op_type': 'default' 'op_types': ['default']
}] }]
pruner = AGP_Pruner(configure_list) pruner = AGP_Pruner(configure_list)
``` ```
......
...@@ -6,4 +6,4 @@ AGPruner: ...@@ -6,4 +6,4 @@ AGPruner:
frequency: 1 frequency: 1
initial_sparsity: 0.05 initial_sparsity: 0.05
final_sparsity: 0.8 final_sparsity: 0.8
op_type: 'default' op_types: ['default']
...@@ -91,7 +91,7 @@ def main(): ...@@ -91,7 +91,7 @@ def main():
'start_epoch': 0, 'start_epoch': 0,
'end_epoch': 10, 'end_epoch': 10,
'frequency': 1, 'frequency': 1,
'op_type': 'default' 'op_types': ['default']
}] }]
pruner = AGP_Pruner(configure_list) pruner = AGP_Pruner(configure_list)
# if you want to load from yaml file # if you want to load from yaml file
......
...@@ -82,7 +82,7 @@ def main(): ...@@ -82,7 +82,7 @@ def main():
'''you can change this to DoReFaQuantizer to implement it '''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(tf.get_default_graph()) DoReFaQuantizer(configure_list).compress(tf.get_default_graph())
''' '''
configure_list = [{'q_bits':8, 'op_type':'default'}] configure_list = [{'q_bits':8, 'op_types':['default']}]
quantizer = QAT_Quantizer(configure_list) quantizer = QAT_Quantizer(configure_list)
quantizer(tf.get_default_graph()) quantizer(tf.get_default_graph())
# you can also use compress(model) or compress_default_graph() # you can also use compress(model) or compress_default_graph()
......
...@@ -76,7 +76,7 @@ def main(): ...@@ -76,7 +76,7 @@ def main():
'start_epoch': 0, 'start_epoch': 0,
'end_epoch': 10, 'end_epoch': 10,
'frequency': 1, 'frequency': 1,
'op_type': 'default' 'op_types': ['default']
}] }]
pruner = AGP_Pruner(configure_list) pruner = AGP_Pruner(configure_list)
......
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
'''you can change this to DoReFaQuantizer to implement it '''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model) DoReFaQuantizer(configure_list).compress(model)
''' '''
configure_list = [{'q_bits':8, 'op_type':'default'}] configure_list = [{'q_bits':8, 'op_types':['default']}]
quantizer = QAT_Quantizer(configure_list) quantizer = QAT_Quantizer(configure_list)
quantizer(model) quantizer(model)
# you can also use compress(model) method # you can also use compress(model) method
......
...@@ -58,10 +58,8 @@ class Compressor: ...@@ -58,10 +58,8 @@ class Compressor:
def _select_config(self, layer): def _select_config(self, layer):
ret = None ret = None
for config in self._config_list: for config in self._config_list:
op_types = config.get('op_types') config['op_types'] = self._expand_config_op_types(config)
if op_types == 'default': if layer.type not in config['op_types']:
op_types = default_layers.weighted_modules
if op_types and layer.type not in op_types:
continue continue
if config.get('op_names') and layer.name not in config['op_names']: if config.get('op_names') and layer.name not in config['op_names']:
continue continue
...@@ -70,6 +68,16 @@ class Compressor: ...@@ -70,6 +68,16 @@ class Compressor:
return None return None
return ret return ret
def _expand_config_op_types(self, config):
if config is None:
return []
expanded_op_types = []
for op_type in config.get('op_types', []):
if op_type == 'default':
expanded_op_types.extend(default_layers.weighted_modules)
else:
expanded_op_types.append(op_type)
return expanded_op_types
class Pruner(Compressor): class Pruner(Compressor):
""" """
...@@ -112,10 +120,6 @@ class Quantizer(Compressor): ...@@ -112,10 +120,6 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer Base quantizer for pytorch quantizer
""" """
def __call__(self, model):
self.compress(model)
return model
def quantize_weight(self, weight, config, op, op_type, op_name): def quantize_weight(self, weight, config, op, op_type, op_name):
"""user should know where dequantize goes and implement it in quantize method """user should know where dequantize goes and implement it in quantize method
we now do not provide dequantize method we now do not provide dequantize method
......
...@@ -100,21 +100,21 @@ class TorchMnist(torch.nn.Module): ...@@ -100,21 +100,21 @@ class TorchMnist(torch.nn.Module):
class CompressorTestCase(TestCase): class CompressorTestCase(TestCase):
def test_tf_pruner(self): def test_tf_pruner(self):
model = TfMnist() model = TfMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(configure_list).compress_default_graph() tf_compressor.LevelPruner(configure_list).compress_default_graph()
def test_tf_quantizer(self): def test_tf_quantizer(self):
model = TfMnist() model = TfMnist()
tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph() tf_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress_default_graph()
def test_torch_pruner(self): def test_torch_pruner(self):
model = TorchMnist() model = TorchMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(configure_list).compress(model) torch_compressor.LevelPruner(configure_list).compress(model)
def test_torch_quantizer(self): def test_torch_quantizer(self):
model = TorchMnist() model = TorchMnist()
torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model) torch_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress(model)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment