Commit 81fcff86 authored by Cjkkkk, committed by QuanluZhang

Api refactor (#1728)

API refactor for compression, especially the quantization APIs
parent 7c4e81b5
@@ -180,12 +180,54 @@ class YourQuantizer(nni.compression.tensorflow.Quantizer):
     def quantize_weight(self, weight, config, **kwargs):
         """
-        weight is the target weight tensor
-        config is the selected dict object in config_list for this layer
-        kwargs contains op, op_types, and op_name
-        design your quantizer and return new weight
+        quantize should overload this method to quantize weight tensors.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        weight : Tensor
+            weight that needs to be quantized
+        config : dict
+            the configuration for weight quantization
         """
+        # Put your code to generate `new_weight` here
         return new_weight
+
+    def quantize_output(self, output, config, **kwargs):
+        """
+        quantize should overload this method to quantize output.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        output : Tensor
+            output that needs to be quantized
+        config : dict
+            the configuration for output quantization
+        """
+        # Put your code to generate `new_output` here
+        return new_output
+
+    def quantize_input(self, *inputs, config, **kwargs):
+        """
+        quantize should overload this method to quantize input.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        inputs : Tensor
+            inputs that need to be quantized
+        config : dict
+            the configuration for input quantization
+        """
+        # Put your code to generate `new_input` here
+        return new_input

     # note for pytorch version, there is no sess in input arguments
     def update_epoch(self, epoch_num, sess):
@@ -200,8 +242,6 @@ class YourQuantizer(nni.compression.tensorflow.Quantizer):
         pass
 ```

-__[TODO]__ Will add another member function `quantize_layer_output`, as some quantization algorithms also quantize layers' output.
-
 ### Usage of user customized compression algorithm

 __[TODO]__ ...
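For context, here is a minimal sketch of how the refactored overloads might be used. It assumes the PyTorch base class is importable as `nni.compression.torch.Quantizer` in this NNI version; the class name `Toy8bitQuantizer`, the symmetric 8-bit rounding scheme, and the toy model are illustrative and not part of this commit.

```python
import torch
import torch.nn as nn
from nni.compression.torch import Quantizer  # assumed import path for this NNI version

class Toy8bitQuantizer(Quantizer):
    """Illustrative only: round weights to symmetric 8-bit levels at forward time."""
    def quantize_weight(self, weight, config, **kwargs):
        scale = weight.abs().max() / 127
        if scale == 0:
            return weight
        return torch.round(weight / scale) * scale

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
configure_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Linear'],
}]
# compress() wraps the forward of every matched layer with the quantize_weight hook
Toy8bitQuantizer(model, configure_list).compress()
```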
@@ -32,7 +32,23 @@ class Compressor:
         """
         self.bound_model = model
         self.config_list = config_list
-        self.modules_to_compress = []
+        self.modules_to_compress = None
+
+    def detect_modules_to_compress(self):
+        """
+        Detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
+        The model will be instrumented and user should never edit it after calling this method.
+        """
+        if self.modules_to_compress is None:
+            self.modules_to_compress = []
+            for name, module in self.bound_model.named_modules():
+                layer = LayerInfo(name, module)
+                config = self.select_config(layer)
+                if config is not None:
+                    self.modules_to_compress.append((layer, config))
+        return self.modules_to_compress

     def compress(self):
         """
@@ -41,12 +57,9 @@ class Compressor:
         The model will be instrumented and user should never edit it after calling this method.
         `self.modules_to_compress` records all the to-be-compressed layers
         """
-        for name, module in self.bound_model.named_modules():
-            layer = LayerInfo(name, module)
-            config = self.select_config(layer)
-            if config is not None:
-                self._instrument_layer(layer, config)
-                self.modules_to_compress.append((layer, config))
+        modules_to_compress = self.detect_modules_to_compress()
+        for layer, config in modules_to_compress:
+            self._instrument_layer(layer, config)
         return self.bound_model

     def get_modules_to_compress(self):
@@ -55,7 +68,7 @@ class Compressor:
         Returns
         -------
-        self.modules_to_compress : list
+        list
             a list of the layers, each of which is a tuple (`layer`, `config`),
             `layer` is `LayerInfo`, `config` is a `dict`
         """
@@ -72,7 +85,7 @@ class Compressor:
         Returns
         -------
-        ret : config or None
+        config or None
             the retrieved configuration for this layer, if None, this layer should
             not be compressed
         """
@@ -240,26 +253,87 @@ class Quantizer(Compressor):
     """
     def quantize_weight(self, weight, config, op, op_type, op_name):
-        """user should know where dequantize goes and implement it in quantize method
-        we now do not provide dequantize method
+        """
+        quantize should overload this method to quantize weight.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        weight : Tensor
+            weight that needs to be quantized
+        config : dict
+            the configuration for weight quantization
         """
         raise NotImplementedError("Quantizer must overload quantize_weight()")

+    def quantize_output(self, output, config, op, op_type, op_name):
+        """
+        quantize should overload this method to quantize output.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        output : Tensor
+            output that needs to be quantized
+        config : dict
+            the configuration for output quantization
+        """
+        raise NotImplementedError("Quantizer must overload quantize_output()")
+
+    def quantize_input(self, *inputs, config, op, op_type, op_name):
+        """
+        quantize should overload this method to quantize input.
+        This method is effectively hooked to :meth:`forward` of the model.
+
+        Parameters
+        ----------
+        inputs : Tensor
+            inputs that need to be quantized
+        config : dict
+            the configuration for input quantization
+        """
+        raise NotImplementedError("Quantizer must overload quantize_input()")
+
     def _instrument_layer(self, layer, config):
+        """
+        Create a wrapper forward function to replace the original one.
+
+        Parameters
+        ----------
+        layer : LayerInfo
+            the layer to instrument
+        config : dict
+            the configuration for quantization
+        """
         assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
+        assert "quant_types" in config, 'must provide quant_types in config'
+        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
+
+        if 'weight' in config["quant_types"]:
+            if not _check_weight(layer.module):
+                _logger.warning('Module %s does not have parameter "weight"', layer.name)
         layer._forward = layer.module.forward

         def new_forward(*inputs):
-            weight = layer.module.weight.data
-            new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-            layer.module.weight.data = new_weight
-            return layer._forward(*inputs)
+            if 'input' in config["quant_types"]:
+                inputs = self.quantize_input(inputs, config=config, op=layer.module, op_type=layer.type, op_name=layer.name)
+
+            if 'weight' in config["quant_types"] and _check_weight(layer.module):
+                weight = layer.module.weight.data
+                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+                layer.module.weight.data = new_weight
+                result = layer._forward(*inputs)
+                layer.module.weight.data = weight
+            else:
+                result = layer._forward(*inputs)
+
+            if 'output' in config["quant_types"]:
+                result = self.quantize_output(result, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+            return result

         layer.module.forward = new_forward

 def _check_weight(module):
     try:
 ...
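As a follow-up illustration, a hypothetical config entry that would exercise all three hooks wired up by the new `_instrument_layer`; only the `'weight'` form appears in this commit's test, so the `'input'` and `'output'` bit settings below are assumptions for illustration.

```python
# Hypothetical configuration: each entry in 'quant_types' triggers the corresponding
# quantize_input / quantize_weight / quantize_output call inside the wrapped forward.
config_list = [{
    'quant_types': ['input', 'weight', 'output'],
    'quant_bits': {'input': 8, 'weight': 8, 'output': 8},
    'op_types': ['Conv2d', 'Linear'],
}]
```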
@@ -114,7 +114,14 @@ class CompressorTestCase(TestCase):
     def test_torch_quantizer(self):
         model = TorchMnist()
-        torch_compressor.NaiveQuantizer(model, [{'op_types': ['default']}]).compress()
+        configure_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': {
+                'weight': 8,
+            },
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        torch_compressor.NaiveQuantizer(model, configure_list).compress()

 if __name__ == '__main__':
 ...