Unverified Commit d6b61e2f authored by liuzhe-lz, committed by GitHub

Resolve comments in PR 1571 (#1590)

* Resolve comments in PR 1571

* try to pass ut

* fix typo

* format doc-string

* use tensorflow.compat.v1

* Revert "use tensorflow.compat.v1"

This reverts commit 97a4ed923677c6dfd545fd654c55c424cf490a19.
parent ca2253c3
...@@ -7,7 +7,7 @@ We have provided two naive compression algorithms and four popular ones for user
|Name|Brief Introduction of Algorithm|
|---|---|
| [Level Pruner](./Pruner.md#level-pruner) | Pruning the specified ratio on each weight based on absolute values of weights |
| [AGP Pruner](./Pruner.md#agp-pruner) | Automated gradual pruning (To prune, or not to prune: exploring the efficacy of pruning for model compression) [Reference Paper](https://arxiv.org/abs/1710.01878)|
| [Sensitivity Pruner](./Pruner.md#sensitivity-pruner) | Learning both Weights and Connections for Efficient Neural Networks. [Reference Paper](https://arxiv.org/abs/1506.02626)|
| [Naive Quantizer](./Quantizer.md#naive-quantizer) | Quantize weights to default 8 bits |
| [QAT Quantizer](./Quantizer.md#qat-quantizer) | Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. [Reference Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)|
...@@ -72,7 +72,7 @@ It means following the algorithm's default setting for compressed operations wit
### Other APIs
Some compression algorithms use epochs to control the progress of compression (e.g. [AGP](./Pruner.md#agp-pruner)), and some algorithms need to do something after every minibatch. Therefore, we provide another two APIs for users to invoke. One is `update_epoch`, which you can use as follows:
Tensorflow code
```python
...@@ -138,7 +138,7 @@ Some algorithms may want global information for generating masks, for example, a
The interface for customizing a quantization algorithm is similar to that of pruning algorithms. The only difference is that `calc_mask` is replaced with `quantize_weight`. `quantize_weight` directly returns the quantized weights rather than a mask, because for quantization the quantized weights cannot be obtained by applying a mask.
```python
# This is writing a Quantizer in tensorflow.
# For writing a Quantizer in PyTorch, you can simply replace
# nni.compression.tensorflow.Quantizer with
......
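To make the `quantize_weight` contract concrete, here is a hedged, pure-Python sketch of naive symmetric 8-bit quantization. It is illustrative only: NNI's quantizers operate on TensorFlow/PyTorch tensors, and this standalone helper is hypothetical.

```python
# Hypothetical sketch: what a naive quantize_weight might compute.
# It returns quantized (then de-quantized) weights directly, not a mask.
def quantize_weight(weight, q_bits=8):
    # Symmetric linear quantization: map [-max|w|, max|w|]
    # onto 2**(q_bits-1) - 1 integer levels per side.
    scale = max(abs(w) for w in weight) / (2 ** (q_bits - 1) - 1)
    if scale == 0:
        return list(weight)
    # Round each weight to the nearest representable level, then scale back.
    return [round(w / scale) * scale for w in weight]
```

Note that the output is a list of the same shape as the input, which is why quantization cannot be expressed as a mask multiplied into the weights.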
...@@ -38,7 +38,7 @@ In [To prune, or not to prune: exploring the efficacy of pruning for model compr
>The binary weight masks are updated every ∆t steps as the network is trained to gradually increase the sparsity of the network while allowing the network training steps to recover from any pruning-induced loss in accuracy. In our experience, varying the pruning frequency ∆t between 100 and 1000 training steps had a negligible impact on the final model quality. Once the model achieves the target sparsity sf, the weight masks are no longer updated. The intuition behind this sparsity function in equation
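The sparsity schedule quoted above can be sketched in plain Python. This is a hedged illustration of the paper's cubic ramp, not NNI's implementation; the function and parameter names are chosen for clarity.

```python
# Cubic sparsity schedule from Zhu & Gupta (2017): sparsity ramps from
# initial_sparsity to final_sparsity over `total_steps` pruning steps,
# starting at `start_step`.
def agp_sparsity(step, initial_sparsity, final_sparsity, start_step, total_steps):
    if step < start_step:
        return initial_sparsity
    if step >= start_step + total_steps:
        return final_sparsity          # masks are no longer updated afterwards
    progress = (step - start_step) / total_steps
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** 3
```

The cubic exponent makes pruning aggressive early (when redundant weights are plentiful) and gentle near the end, matching the intuition described in the quote.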
### Usage
You can prune all weights from 0% to 80% sparsity in 10 epochs with the code below.
First, you should import the pruner and add a mask to the model.
......
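The snippet itself is collapsed in this diff. A hedged sketch of what such a configuration might look like, with key names taken from the `AGP_Pruner` docstrings that appear later in this diff:

```python
# Illustrative config_list for pruning from 0% to 80% sparsity over 10 epochs.
# The exact schema is defined by the pruner; treat this as a sketch.
config_list = [{
    'initial_sparsity': 0.0,
    'final_sparsity': 0.8,
    'start_epoch': 0,
    'end_epoch': 10,
    'frequency': 1,    # update the mask every epoch
}]
```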
...@@ -127,4 +127,6 @@ def main():
    })
    print('final result is', test_acc)

if __name__ == '__main__':
    main()
...@@ -114,4 +114,6 @@ def main():
    })
    print('final result is', test_acc)

if __name__ == '__main__':
    main()
...@@ -89,7 +89,7 @@ def main():
        test(model, device, test_loader)
        pruner.update_epoch(epoch)

if __name__ == '__main__':
    main()
...@@ -81,7 +81,6 @@ def main():
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)

if __name__ == '__main__':
    main()
...@@ -10,8 +10,8 @@ _logger = logging.getLogger(__name__)
class LevelPruner(Pruner):
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - sparsity
        """
        super().__init__(config_list)
...@@ -21,8 +21,7 @@ class LevelPruner(Pruner):
class AGP_Pruner(Pruner):
    """An automated gradual pruning algorithm that prunes the smallest magnitude
    weights to achieve a preset level of network sparsity.
    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
...@@ -32,12 +31,12 @@ class AGP_Pruner(Pruner):
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - initial_sparsity
            - final_sparsity: you should make sure initial_sparsity <= final_sparsity
            - start_epoch: epoch number at which to begin updating the mask
            - end_epoch: epoch number at which to stop updating the mask
            - frequency: if you want to update every 2 epochs, set it to 2
        """
        super().__init__(config_list)
        self.now_epoch = tf.Variable(0)
...@@ -77,8 +76,7 @@ class AGP_Pruner(Pruner):
class SensitivityPruner(Pruner):
    """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
    https://arxiv.org/pdf/1506.02626v3.pdf
    I.e.: "The pruning threshold is chosen as a quality parameter multiplied
...@@ -86,8 +84,8 @@ class SensitivityPruner(Pruner):
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - sparsity: chosen pruning sparsity
        """
        super().__init__(config_list)
        self.layer_mask = {}
......
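To make the `sparsity` key concrete, here is a hedged, pure-Python sketch of the magnitude mask a level pruner's `calc_mask` would compute. It is illustrative only; the real code operates on tensors.

```python
# Zero out the smallest-magnitude fraction `sparsity` of the weights,
# returning a mask of the same shape as the weight.
def level_mask(weights, sparsity):
    k = int(len(weights) * sparsity)        # number of weights to prune
    if k == 0:
        return [1.0] * len(weights)
    # The k-th smallest absolute value becomes the pruning threshold.
    # Ties at the threshold may prune slightly more than k weights.
    threshold = sorted(abs(w) for w in weights)[k - 1]
    return [0.0 if abs(w) <= threshold else 1.0 for w in weights]
```

Applying the mask is then an elementwise multiply with the weight tensor, which is exactly how the `Pruner` base classes in this diff describe it.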
...@@ -8,8 +8,7 @@ _logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
    """quantize weight to 8 bits"""
    def __init__(self, config_list):
        super().__init__(config_list)
...@@ -24,15 +23,14 @@ class NaiveQuantizer(Quantizer):
class QAT_Quantizer(Quantizer):
    """Quantizer using the QAT scheme, as defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
    http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - q_bits
        """
        super().__init__(config_list)
...@@ -50,15 +48,14 @@ class QAT_Quantizer(Quantizer):
class DoReFaQuantizer(Quantizer):
    """Quantizer using the DoReFa scheme, as defined in:
    Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
    (https://arxiv.org/abs/1606.06160)
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - q_bits
        """
        super().__init__(config_list)
......
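To illustrate what `q_bits` controls, here is a hedged sketch of the affine quantization step from the QAT paper, written in plain Python. It is illustrative only; NNI's quantizer operates on tensors and also simulates quantization during training.

```python
# Affine (asymmetric) quantization: map floats in [rmin, rmax] onto the
# integers [0, 2**q_bits - 1], then de-quantize back to a float.
def affine_quantize(x, rmin, rmax, q_bits=8):
    qmax = 2 ** q_bits - 1
    scale = (rmax - rmin) / qmax
    zero_point = round(-rmin / scale)
    q = min(qmax, max(0, round(x / scale) + zero_point))   # clamp to range
    return scale * (q - zero_point)                        # de-quantized value
```

The round trip introduces a quantization error of at most half a quantization step, which is what quantization-aware training teaches the network to tolerate.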
...@@ -13,20 +13,21 @@ class LayerInfo:
class Compressor:
    """Abstract base TensorFlow compressor"""
    def __init__(self, config_list):
        self._bound_model = None
        self._config_list = config_list
    def __call__(self, model):
        """Compress given graph with algorithm implemented by subclass.
        The graph will be edited and returned.
        """
        self.compress(model)
        return model
    def compress(self, model):
        """Compress given graph with algorithm implemented by subclass.
        This will edit the graph.
        """
        assert self._bound_model is None, "Each NNI compressor instance can only compress one model"
...@@ -39,30 +40,26 @@ class Compressor:
            self._instrument_layer(layer, config)
    def compress_default_graph(self):
        """Compress the default graph with algorithm implemented by subclass.
        This will edit the default graph.
        """
        self.compress(tf.get_default_graph())
    def bind_model(self, model):
        """This method is called when a model is bound to the compressor.
        Compressors can optionally overload this method to do model-specific initialization.
        It is guaranteed that only one model will be bound to each compressor instance.
        """
        pass
    def update_epoch(self, epoch, sess):
        """If the user wants to update the mask every epoch, they can override this method."""
        pass
    def step(self, sess):
        """If the user wants to update the mask every step, they can override this method."""
        pass
...@@ -87,15 +84,13 @@ class Compressor:
class Pruner(Compressor):
    """Abstract base TensorFlow pruner"""
    def __init__(self, config_list):
        super().__init__(config_list)
    def calc_mask(self, weight, config, op, op_type, op_name):
        """Pruners should overload this method to provide a mask for weight tensors.
        The mask must have the same shape and type as the weight.
        It will be applied with the `multiply()` operation.
        This method works as a subgraph which will be inserted into the bound model.
...@@ -103,13 +98,11 @@ class Pruner(Compressor):
        raise NotImplementedError("Pruners must overload calc_mask()")
    def _instrument_layer(self, layer, config):
        # It seems the graph editor can only swap edges of nodes or remove all edges from a node;
        # it cannot remove one edge from a node, nor can it assign a new edge to a node.
        # We assume there is a proxy operation between the weight and the Conv2D layer.
        # This is true as long as the weight is `tf.Value`;
        # not sure what will happen if the weight is calculated from other operations.
        weight_index = _detect_weight_index(layer)
        if weight_index is None:
            _logger.warning('Failed to detect weight for layer {}'.format(layer.name))
class Quantizer(Compressor): class Quantizer(Compressor):
""" """Abstract base TensorFlow quantizer"""
Abstract base TensorFlow quantizer
"""
def __init__(self, config_list): def __init__(self, config_list):
super().__init__(config_list) super().__init__(config_list)
......
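A minimal, self-contained sketch of the pruner interface described above. The stub base class stands in for `nni.compression.tensorflow.Pruner` so the example runs on its own, and the threshold logic is hypothetical.

```python
# Stub standing in for the abstract base class shown in this diff.
class Pruner:
    def __init__(self, config_list):
        self._config_list = config_list

    def calc_mask(self, weight, config, op, op_type, op_name):
        raise NotImplementedError("Pruners must overload calc_mask()")

# A toy pruner: keep weights whose magnitude reaches a configured threshold.
class ThresholdPruner(Pruner):
    def calc_mask(self, weight, config, op, op_type, op_name):
        threshold = config.get('threshold', 0.1)
        # Same shape as the weight; applied elementwise with multiply().
        return [1.0 if abs(w) >= threshold else 0.0 for w in weight]
```

In the real API the weight is a tensor and the returned mask is multiplied into the graph as a subgraph; plain lists are used here only to keep the sketch runnable.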
...@@ -12,19 +12,8 @@ class LevelPruner(Pruner):
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - sparsity
        """
        super().__init__(config_list)
...@@ -38,8 +27,7 @@ class LevelPruner(Pruner):
class AGP_Pruner(Pruner):
    """An automated gradual pruning algorithm that prunes the smallest magnitude
    weights to achieve a preset level of network sparsity.
    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
...@@ -49,12 +37,12 @@ class AGP_Pruner(Pruner):
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - initial_sparsity
            - final_sparsity: you should make sure initial_sparsity <= final_sparsity
            - start_epoch: epoch number at which to begin updating the mask
            - end_epoch: epoch number at which to stop updating the mask; you should make sure start_epoch <= end_epoch
            - frequency: if you want to update every 2 epochs, set it to 2
        """
        super().__init__(config_list)
        self.mask_list = {}
...@@ -99,8 +87,7 @@ class AGP_Pruner(Pruner):
class SensitivityPruner(Pruner):
    """Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
    https://arxiv.org/pdf/1506.02626v3.pdf
    I.e.: "The pruning threshold is chosen as a quality parameter multiplied
...@@ -108,8 +95,8 @@ class SensitivityPruner(Pruner):
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - sparsity: chosen pruning sparsity
        """
        super().__init__(config_list)
        self.mask_list = {}
......
...@@ -8,8 +8,7 @@ logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
    """quantize weight to 8 bits"""
    def __init__(self, config_list):
        super().__init__(config_list)
...@@ -24,15 +23,14 @@ class NaiveQuantizer(Quantizer):
class QAT_Quantizer(Quantizer):
    """Quantizer using the QAT scheme, as defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
    http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - q_bits
        """
        super().__init__(config_list)
...@@ -51,15 +49,14 @@ class QAT_Quantizer(Quantizer):
class DoReFaQuantizer(Quantizer):
    """Quantizer using the DoReFa scheme, as defined in:
    Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
    (https://arxiv.org/abs/1606.06160)
    """
    def __init__(self, config_list):
        """
        config_list: supported keys:
            - q_bits
        """
        super().__init__(config_list)
......
...@@ -15,9 +15,8 @@ class LayerInfo:
class Compressor:
    """Abstract base PyTorch compressor"""
    def __init__(self, config_list):
        self._bound_model = None
        self._config_list = config_list
...@@ -27,8 +26,7 @@ class Compressor:
        return model
    def compress(self, model):
        """Compress the model with algorithm implemented by subclass.
        The model will be instrumented and the user should never edit it after calling this method.
        """
        assert self._bound_model is None, "Each NNI compressor instance can only compress one model"
...@@ -42,22 +40,19 @@ class Compressor:
    def bind_model(self, model):
        """This method is called when a model is bound to the compressor.
        Users can optionally overload this method to do model-specific initialization.
        It is guaranteed that only one model will be bound to each compressor instance.
        """
        pass
    def update_epoch(self, epoch):
        """If the user wants to update the model every epoch, they can override this method."""
        pass
    def step(self):
        """If the user wants to update the model every step, they can override this method."""
        pass
...@@ -82,15 +77,13 @@ class Compressor:
class Pruner(Compressor):
    """Abstract base PyTorch pruner"""
    def __init__(self, config_list):
        super().__init__(config_list)
    def calc_mask(self, weight, config, op, op_type, op_name):
        """Pruners should overload this method to provide a mask for weight tensors.
        The mask must have the same shape and type as the weight.
        It will be applied with the `mul()` operation.
        This method is effectively hooked to the `forward()` method of the model.
...@@ -122,9 +115,8 @@ class Pruner(Compressor):
class Quantizer(Compressor):
    """Base quantizer for PyTorch quantizers"""
    def __init__(self, config_list):
        super().__init__(config_list)
...@@ -133,8 +125,7 @@ class Quantizer(Compressor):
        return model
    def quantize_weight(self, weight, config, op, op_type, op_name):
        """The user should know where dequantization goes and implement it in the quantize method;
        we do not currently provide a dequantize method.
        """
        raise NotImplementedError("Quantizer must overload quantize_weight()")
......
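To illustrate what a `quantize_weight` override might compute, here is a hedged, pure-Python sketch of DoReFa-style k-bit weight quantization. It is illustrative only; NNI's version operates on torch tensors.

```python
import math

# DoReFa weight quantization (Zhou et al., 2016): squash weights with tanh,
# normalize into [0, 1], quantize uniformly to 2**q_bits - 1 levels,
# then map back to [-1, 1].
def dorefa_quantize_weights(weights, q_bits):
    levels = 2 ** q_bits - 1
    max_t = max(abs(math.tanh(w)) for w in weights)
    quantized = []
    for w in weights:
        x = math.tanh(w) / (2 * max_t) + 0.5     # normalized into [0, 1]
        x = round(x * levels) / levels           # uniform quantization
        quantized.append(2 * x - 1)              # back to [-1, 1]
    return quantized
```

As the docstring above notes, the quantized values are returned directly; any dequantization step is the implementer's responsibility.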
from unittest import TestCase, main
import tensorflow as tf
import torch
import torch.nn.functional as F
import nni.compression.tensorflow as tf_compressor
import nni.compression.torch as torch_compressor
def weight_variable(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev = 0.1))
......