OpenDAS / nni · commit 81fcff86
authored Nov 12, 2019 by Cjkkkk, committed by QuanluZhang on Nov 12, 2019
Api refactor (#1728)

API refactor for compression; in particular, the quantization APIs.
parent 7c4e81b5
Showing 3 changed files with 147 additions and 26 deletions (+147 -26):
- docs/en_US/Compressor/Overview.md (+46 -6)
- src/sdk/pynni/nni/compression/torch/compressor.py (+93 -19)
- src/sdk/pynni/tests/test_compressor.py (+8 -1)
docs/en_US/Compressor/Overview.md

````diff
@@ -180,13 +180,55 @@ class YourQuantizer(nni.compression.tensorflow.Quantizer):
     def quantize_weight(self, weight, config, **kwargs):
         """
-        weight is the target weight tensor
-        config is the selected dict object in config_list for this layer
-        kwargs contains op, op_types, and op_name
-        design your quantizer and return new weight
+        quantize should overload this method to quantize weight tensors.
+        This method is effectively hooked to :meth:`forward` of the model.
+        Parameters
+        ----------
+        weight : Tensor
+            weight that needs to be quantized
+        config : dict
+            the configuration for weight quantization
         """
+        # Put your code to generate `new_weight` here
         return new_weight
 
+    def quantize_output(self, output, config, **kwargs):
+        """
+        quantize should overload this method to quantize output.
+        This method is effectively hooked to :meth:`forward` of the model.
+        Parameters
+        ----------
+        output : Tensor
+            output that needs to be quantized
+        config : dict
+            the configuration for output quantization
+        """
+        # Put your code to generate `new_output` here
+        return new_output
+
+    def quantize_input(self, *inputs, config, **kwargs):
+        """
+        quantize should overload this method to quantize input.
+        This method is effectively hooked to :meth:`forward` of the model.
+        Parameters
+        ----------
+        inputs : Tensor
+            inputs that needs to be quantized
+        config : dict
+            the configuration for inputs quantization
+        """
+        # Put your code to generate `new_input` here
+        return new_input
+
     # note for pytorch version, there is no sess in input arguments
     def update_epoch(self, epoch_num, sess):
         pass
@@ -200,8 +242,6 @@ class YourQuantizer(nni.compression.tensorflow.Quantizer):
         pass
 ```
 
-__[TODO]__ Will add another member function `quantize_layer_output`, as some quantization algorithms also quantize layers' output.
-
 ### Usage of user customized compression algorithm
 
 __[TODO]__ ...
````
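Under the refactored API, a custom quantizer overrides whichever of the three hooks it needs. Below is a minimal sketch of a weight-only quantizer; the class name `RoundQuantizer` and its rounding scheme are illustrative, not part of this commit, and the PyTorch base class is assumed to live at `nni.compression.torch.Quantizer`, mirroring the TensorFlow path above.

```python
import torch
from nni.compression.torch import Quantizer  # PyTorch counterpart; import path assumed

class RoundQuantizer(Quantizer):
    """Illustrative quantizer: symmetric rounding to the configured bit width."""

    def quantize_weight(self, weight, config, **kwargs):
        bits = config.get('quant_bits', {}).get('weight', 8)
        # map the largest weight magnitude to the top of the signed grid
        scale = weight.abs().max() / (2 ** (bits - 1) - 1)
        scale = torch.clamp(scale, min=1e-8)  # guard against all-zero weights
        # round onto the quantization grid, then scale back to float
        return torch.round(weight / scale) * scale
```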
src/sdk/pynni/nni/compression/torch/compressor.py

```diff
@@ -32,21 +32,34 @@ class Compressor:
         """
         self.bound_model = model
         self.config_list = config_list
-        self.modules_to_compress = []
+        self.modules_to_compress = None
 
-    def compress(self):
+    def detect_modules_to_compress(self):
         """
-        Compress the model with algorithm implemented by subclass.
+        detect all modules should be compressed, and save the result in `self.modules_to_compress`.
         The model will be instrumented and user should never edit it after calling this method.
-        `self.modules_to_compress` records all the to-be-compressed layers
         """
-        for name, module in self.bound_model.named_modules():
-            layer = LayerInfo(name, module)
-            config = self.select_config(layer)
-            if config is not None:
-                self._instrument_layer(layer, config)
-                self.modules_to_compress.append((layer, config))
+        if self.modules_to_compress is None:
+            self.modules_to_compress = []
+            for name, module in self.bound_model.named_modules():
+                layer = LayerInfo(name, module)
+                config = self.select_config(layer)
+                if config is not None:
+                    self.modules_to_compress.append((layer, config))
+        return self.modules_to_compress
+
+    def compress(self):
+        """
+        Compress the model with algorithm implemented by subclass.
+        The model will be instrumented and user should never edit it after calling this method.
+        `self.modules_to_compress` records all the to-be-compressed layers
+        """
+        modules_to_compress = self.detect_modules_to_compress()
+        for layer, config in modules_to_compress:
+            self._instrument_layer(layer, config)
         return self.bound_model
 
     def get_modules_to_compress(self):
```
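Splitting detection from instrumentation lets callers inspect which layers were selected for compression. A usage sketch with the new-style config follows; the toy model and bit widths are illustrative, and `NaiveQuantizer` with the `torch_compressor` import alias follows the updated test file shown further below.

```python
import torch.nn as nn
import nni.compression.torch as torch_compressor  # import alias assumed, as in the tests

model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.Flatten(), nn.Linear(4 * 26 * 26, 10))
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Conv2d', 'Linear']
}]

quantizer = torch_compressor.NaiveQuantizer(model, config_list)
model = quantizer.compress()  # detect_modules_to_compress() caches, then each layer is instrumented

# the cached (LayerInfo, config) pairs are now queryable
for layer, config in quantizer.get_modules_to_compress():
    print(layer.name, config['quant_types'])
```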
```diff
@@ -55,7 +68,7 @@ class Compressor:
         Returns
         -------
-        list
+        self.modules_to_compress : list
             a list of the layers, each of which is a tuple (`layer`, `config`),
             `layer` is `LayerInfo`, `config` is a `dict`
         """
@@ -72,7 +85,7 @@ class Compressor:
         Returns
         -------
-        config or None
+        ret : config or None
             the retrieved configuration for this layer, if None, this layer should
             not be compressed
         """
```
```diff
@@ -240,26 +253,87 @@ class Quantizer(Compressor):
     """
     def quantize_weight(self, weight, config, op, op_type, op_name):
-        """user should know where dequantize goes and implement it in quantize method
-        we now do not provide dequantize method
+        """
+        quantize should overload this method to quantize weight.
+        This method is effectively hooked to :meth:`forward` of the model.
+        Parameters
+        ----------
+        weight : Tensor
+            weight that needs to be quantized
+        config : dict
+            the configuration for weight quantization
         """
         raise NotImplementedError("Quantizer must overload quantize_weight()")
 
+    def quantize_output(self, output, config, op, op_type, op_name):
+        """
+        quantize should overload this method to quantize output.
+        This method is effectively hooked to :meth:`forward` of the model.
+        Parameters
+        ----------
+        output : Tensor
+            output that needs to be quantized
+        config : dict
+            the configuration for output quantization
+        """
+        raise NotImplementedError("Quantizer must overload quantize_output()")
+
+    def quantize_input(self, *inputs, config, op, op_type, op_name):
+        """
+        quantize should overload this method to quantize input.
+        This method is effectively hooked to :meth:`forward` of the model.
+        Parameters
+        ----------
+        inputs : Tensor
+            inputs that needs to be quantized
+        config : dict
+            the configuration for inputs quantization
+        """
+        raise NotImplementedError("Quantizer must overload quantize_input()")
+
     def _instrument_layer(self, layer, config):
+        """
+        Create a wrapper forward function to replace the original one.
+        Parameters
+        ----------
+        layer : LayerInfo
+            the layer to instrument the mask
+        config : dict
+            the configuration for quantization
+        """
         assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
+        assert "quant_types" in config, 'must provide quant_types in config'
+        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
+
+        if 'weight' in config["quant_types"]:
+            if not _check_weight(layer.module):
+                _logger.warning('Module %s does not have parameter "weight"', layer.name)
+                return
         layer._forward = layer.module.forward
 
         def new_forward(*inputs):
-            weight = layer.module.weight.data
-            new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-            layer.module.weight.data = new_weight
-            return layer._forward(*inputs)
+            if 'input' in config["quant_types"]:
+                inputs = self.quantize_input(inputs, config=config, op=layer.module, op_type=layer.type, op_name=layer.name)
 
-        layer.module.forward = new_forward
+            if 'weight' in config["quant_types"] and _check_weight(layer.module):
+                weight = layer.module.weight.data
+                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+                layer.module.weight.data = new_weight
+                result = layer._forward(*inputs)
+                layer.module.weight.data = weight
+            else:
+                result = layer._forward(*inputs)
+
+            if 'output' in config["quant_types"]:
+                result = self.quantize_output(result, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+            return result
+
+        layer.module.forward = new_forward
 
 def _check_weight(module):
     try:
```
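One behavioral change worth noting in the wrapper: `new_forward` now saves `weight.data`, swaps in the quantized tensor for the duration of `_forward`, and restores the original afterwards, so weight quantization is applied transiently per call rather than destructively. A self-contained sketch of the observable effect, assuming `NaiveQuantizer` behaves as exercised by the updated test below (model and shapes illustrative):

```python
import torch
import torch.nn as nn
import nni.compression.torch as torch_compressor  # import alias assumed, as in the tests

model = nn.Linear(8, 4)  # toy module, illustrative only
config_list = [{'quant_types': ['weight'], 'quant_bits': {'weight': 8}, 'op_types': ['Linear']}]
torch_compressor.NaiveQuantizer(model, config_list).compress()

before = model.weight.data.clone()
_ = model(torch.randn(2, 8))  # forward computes with the quantized weight...

# ...but the stored parameter is unchanged afterwards, because new_forward
# restores the saved weight.data once _forward returns.
assert torch.equal(model.weight.data, before)
```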
src/sdk/pynni/tests/test_compressor.py

```diff
@@ -114,7 +114,14 @@ class CompressorTestCase(TestCase):
     def test_torch_quantizer(self):
         model = TorchMnist()
-        torch_compressor.NaiveQuantizer(model, [{'op_types': ['default']}]).compress()
+        configure_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': {
+                'weight': 8,
+            },
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        torch_compressor.NaiveQuantizer(model, configure_list).compress()
 
 
 if __name__ == '__main__':
```
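The updated test exercises only `'weight'`. Under the refactored `_instrument_layer`, `quant_types` may name any subset of the three hooks; a quantizer configured with `'input'` or `'output'` must also override `quantize_input` / `quantize_output`, since the base class otherwise raises `NotImplementedError`. A config entry exercising all three might look like this (bit widths illustrative):

```python
configure_list = [{
    'quant_types': ['input', 'weight', 'output'],  # must be a list, per the new asserts
    'quant_bits': {'input': 8, 'weight': 8, 'output': 8},
    'op_types': ['Conv2d', 'Linear']
}]
```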