OpenDAS / nni · Commits

Unverified commit c9cd53aa, authored Oct 13, 2021 by chenbohua3, committed by GitHub on Oct 13, 2021.

support dtype&scheme customization for QAT quantizer (#4137)

Parent: b0f34da1

Showing 10 changed files with 775 additions and 190 deletions (+775 / -190).
docs/en_US/Compression/CustomizeCompressor.rst                   +2   -2
examples/model_compress/quantization/QAT_torch_quantizer.py     +29  -16
nni/algorithms/compression/pytorch/quantization/quantizers.py   +226 -122
nni/common/version.py                                           +7   -0
nni/compression/pytorch/compressor.py                           +77  -31
nni/compression/pytorch/quantization/literal.py                 +65  -0
nni/compression/pytorch/quantization/observers.py               +15  -0
nni/compression/pytorch/quantization/settings.py                +118 -0
nni/compression/pytorch/quantization/utils.py                   +83  -0
test/ut/sdk/test_compressor_torch.py                            +153 -19
docs/en_US/Compression/CustomizeCompressor.rst

@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
     grad_output : Tensor
         gradient of the output of quantization operation
     quant_type : QuantType
-        the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
+        the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
         you can define different behavior for different types.
     Returns
     -------
@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backw
     """
     # for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
-    if quant_type == QuantType.QUANT_OUTPUT:
+    if quant_type == QuantType.OUTPUT:
         grad_output[torch.abs(tensor) > 1] = 0
     return grad_output
examples/model_compress/quantization/QAT_torch_quantizer.py

@@ -2,11 +2,13 @@ import torch
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
+from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype
 import sys
 sys.path.append('../models')
 from mnist.naive import NaiveModel

 def train(model, device, train_loader, optimizer):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
@@ -58,22 +60,32 @@ def main():
     # {'quant_types': ['input'], 'op_names': ['b']} in the configure_list.
     configure_list = [{
         'quant_types': ['weight', 'input'],
         'quant_bits': {'weight': 8, 'input': 8},
         'op_names': ['conv1', 'conv2']
     }, {
         'quant_types': ['output'],
         'quant_bits': {'output': 8, },
         'op_names': ['relu1', 'relu2']
     }, {
         'quant_types': ['output', 'weight', 'input'],
         'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
-        'op_names': ['fc1'],
-    }, {
-        'quant_types': ['output', 'weight', 'input'],
-        'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
-        'op_names': ['fc2'],
+        'op_names': ['fc1', 'fc2'],
     }]
+
+    # you can also set the quantization dtype and scheme layer-wise through configure_list like:
+    # configure_list = [{
+    #         'quant_types': ['weight', 'input'],
+    #         'quant_bits': {'weight': 8, 'input': 8},
+    #         'op_names': ['conv1', 'conv2'],
+    #         'quant_dtype': 'int',
+    #         'quant_scheme': 'per_channel_symmetric'
+    #     }]
+    # For now quant_dtype's options are 'int' and 'uint'. And quant_scheme's options are per_tensor_affine,
+    # per_tensor_symmetric, per_channel_affine and per_channel_symmetric.
+    set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
+    set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
+    set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')
+
     model = NaiveModel().to(device)
     dummy_input = torch.randn(1, 1, 28, 28).to(device)
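Editorial aside: besides a single string, the new `quant_dtype` and `quant_scheme` keys can also be given as a per-type dict, because `LayerQuantSetting.parse_optional_config` (added later in this commit) looks the quant type up when the configured value is a dict. A hypothetical variant of the commented config above, with a different scheme per quant type (not part of the shipped example):

# hypothetical layer-wise config; anything omitted falls back to the global defaults
configure_list = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': 8, 'input': 8},
    'op_names': ['conv1', 'conv2'],
    'quant_dtype': 'int',
    'quant_scheme': {'weight': 'per_channel_symmetric', 'input': 'per_tensor_symmetric'}
}]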
@@ -98,5 +110,6 @@ def main():
     calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
     print("Generated calibration config is: ", calibration_config)

 if __name__ == '__main__':
     main()
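Since the quantizer now records per-layer shapes at construction time, the QAT quantizer is typically built with a `dummy_input`, exactly as the updated unit tests below do. A minimal sketch of the wiring, assuming the `NaiveModel`, `configure_list` and `device` used in this example:

model = NaiveModel().to(device)
dummy_input = torch.randn(1, 1, 28, 28).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# dummy_input lets the quantizer find conv-bn patterns and record per-layer shapes
quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
quantizer.compress()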
nni/algorithms/compression/pytorch/quantization/quantizers.py

(This diff is collapsed in the commit view and is not reproduced here; it accounts for +226/-122 of the change.)
nni/common/version.py (new file, mode 100644)

import logging

try:
    import torch
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except Exception:
    logging.info("PyTorch is not installed.")
    TORCH_VERSION = None
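`TORCH_VERSION` is just the `(major, minor)` pair (or `None` when PyTorch cannot be imported), so callers can gate on features with a tuple comparison; `utils.py` below does exactly this for `torch.amin`/`torch.amax`. A small illustration of the pattern (the tensor and reduction are only an example):

import torch
from nni.common.version import TORCH_VERSION

x = torch.randn(4, 3)
# e.g. torch 1.9.0 -> (1, 9)
if TORCH_VERSION is not None and TORCH_VERSION > (1, 6):
    col_min = torch.amin(x, [0], keepdim=True)        # newer reduction API
else:
    col_min = torch.min(x, dim=0, keepdim=True)[0]    # single-dim fallback on older torch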
nni/compression/pytorch/compressor.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import copy
 import types
 import logging
 import torch
 from nni.common.graph_utils import build_module_graph
+from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG
+from nni.compression.pytorch.quantization.observers import RecordingObserver
 from . import default_layers

 _logger = logging.getLogger(__name__)

@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             assert len(inputs) == 1, "Quantization of input only supports ops with single input."
             new_inp = self.quantizer.quant_grad(
                 inputs[0],
-                QuantType.QUANT_INPUT,
+                QuantType.INPUT,
                 self)
             inputs = (new_inp,)

@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             self.quantizer.quant_grad(
                 new_weight,
-                QuantType.QUANT_WEIGHT,
+                QuantType.WEIGHT,
                 self, inputs[0])
             result = self.module(*inputs)

@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
         if 'output' in self.config['quant_types']:
             result = self.quantizer.quant_grad(
                 result,
-                QuantType.QUANT_OUTPUT,
+                QuantType.OUTPUT,
                 self)
         return result
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
     def __init__(self, model, config_list, optimizer=None, dummy_input=None):
         if isinstance(model, torch.nn.DataParallel):
             model = model.module
+        model_copied = copy.deepcopy(model)
         self.identity_wrappers = []
         self.conv_bn_patterns = {}
         self.find_conv_bn_patterns(model, dummy_input)
         super().__init__(model, config_list, optimizer)
+        self.all_shapes = {}
+        self.record_shape(model_copied, dummy_input)
         self.quant_grad = QuantGrad.apply
         if self.optimizer is not None:
             self.patch_optimizer(self.step_with_optimizer)

@@ -845,25 +851,54 @@ class Quantizer(Compressor):
                 if successor.op_type == 'BatchNorm2d':
                     self.conv_bn_patterns[node_group.name] = successor.name

+    def record_shape(self, model, dummy_input):
+        """
+        Record input/output's shapes of each module to be quantized
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            model to be recorded.
+        dummy_input : tuple of torch.tensor
+            inputs to the model.
+        """
+        def _pre_forward_hook(self, inp):
+            # Only record the first tensor of the input
+            return self.pre_forward(inp[0])
+
+        def _post_forward_hook(self, _, out):
+            return self.post_forward(out)
+
+        if dummy_input is None:
+            return
+
+        all_handles = []
+        all_observers = {}
+        modules_to_compress = self.get_modules_to_compress()
+        compress_names = [layer_info[0].name for layer_info in modules_to_compress]
+        for name, module in model.named_modules():
+            if name in compress_names:
+                all_observers[name] = {}
+                all_observers[name]['input_hook'] = RecordingObserver()
+                all_observers[name]['output_hook'] = RecordingObserver()
+                module.add_module('pre_forward', all_observers[name]['input_hook'])
+                module.add_module('post_forward', all_observers[name]['output_hook'])
+                all_handles.append(module.register_forward_pre_hook(_pre_forward_hook))
+                all_handles.append(module.register_forward_hook(_post_forward_hook))
+        model(dummy_input)
+        for name, hooks in all_observers.items():
+            # only support single input
+            input_val = hooks['input_hook'].tensor_val
+            input_shape = input_val[0].shape if input_val else None
+            output_val = hooks['output_hook'].tensor_val
+            output_shape = output_val[0].shape if output_val else None
+            shapes = [input_shape, output_shape]
+            self.all_shapes[name] = shapes
+        return
+
     def step_with_optimizer(self):
         pass

-class QuantType:
-    """
-    Enum class for quantization type.
-    """
-    QUANT_INPUT = 0
-    QUANT_WEIGHT = 1
-    QUANT_OUTPUT = 2
-
-QType_Dict = {
-    0: "input",
-    1: "weight",
-    2: "output"
-}
-
-BN_FOLD_OP = ["Conv2d"]
-BN_FOLD_TAG = 'BN_FOLD_TAG'
-
 class QuantGrad(torch.autograd.Function):
     """
@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
         grad_output : Tensor
             gradient of the output of quantization operation
         scale : Tensor
-            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
-            `QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
+            the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
+            `QuantType.OUTPUT`, you can define different behavior for different types.
         zero_point : Tensor
             zero_point for quantizing tensor
         qmin : Tensor

@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
     def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
         output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)

-        bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
-        qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
-        if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
+        if hasattr(wrapper.module, "layer_quant_setting"):
+            layer_quant_setting = wrapper.module.layer_quant_setting
+            qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax()
+        else:
+            # todo: when dtype/scheme customization is ready for all quantizers, remove this
+            bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
+            qmin, qmax = 0, (1 << bits) - 1
+
+        scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
+        if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name):
+            scale = getattr(wrapper.module, scale_name)
+            zero_point = getattr(wrapper.module, zero_point_name)
+        # todo: remove this when other quantizers use different scale & zero point for input/weight/output
+        elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
             scale = wrapper.module.scale
             zero_point = wrapper.module.zero_point
         else:
             scale, zero_point = None, None
-        ctx.save_for_backward(tensor)
+
         # Only tensors have gradients flowing back needs to be saved by save_for_backward.
         # Others should directly assign to ctx.
-        ctx.scale = scale
-        ctx.zero_point = zero_point
-        ctx.quant_type = quant_type
-        ctx.qmin, ctx.qmax = qmin, qmax
+        ctx.save_for_backward(tensor)
+        ctx.quant_type = quant_type
+        ctx.qmin, ctx.qmax = qmin, qmax
+        ctx.scale = scale
+        ctx.zero_point = zero_point
         return output

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor = ctx.saved_variables[0]
         scale, zero_point = ctx.scale, ctx.zero_point
-        qmin, qmax = ctx.qmin, ctx.qmax
         quant_type = ctx.quant_type
+        qmin, qmax = ctx.qmin, ctx.qmax
         output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
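With qmin/qmax now carried through `ctx`, a `QuantGrad` subclass can use the real integer range in its straight-through estimator instead of rederiving it from the bit width. A minimal sketch, not part of this commit, that zeroes the gradient wherever the fake-quantized code would have saturated:

import torch
from nni.compression.pytorch.compressor import QuantGrad

class ClipGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        # hypothetical straight-through estimator: pass gradients only where the
        # quantized code stays strictly inside [qmin, qmax]
        if scale is not None and zero_point is not None:
            q = torch.round(tensor / scale + zero_point)
            grad_output = torch.where((q > qmin) & (q < qmax),
                                      grad_output, torch.zeros_like(grad_output))
        return grad_output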
@@ -977,11 +1023,11 @@ def _check_bias(module):
         return False

 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
-    if quant_type == QuantType.QUANT_INPUT:
+    if quant_type == QuantType.INPUT:
         output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
-    elif quant_type == QuantType.QUANT_WEIGHT:
+    elif quant_type == QuantType.WEIGHT:
         output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
-    elif quant_type == QuantType.QUANT_OUTPUT:
+    elif quant_type == QuantType.OUTPUT:
         output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
     else:
         raise ValueError("unrecognized QuantType.")
nni/compression/pytorch/quantization/literal.py (new file, mode 100644)

from enum import Enum, EnumMeta


class _QuantLiteralEnumMeta(EnumMeta):
    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True


class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta):
    pass


class QuantScheme(str, _QuantLiteralEnum):
    PER_TENSOR_AFFINE = 'per_tensor_affine'
    PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric'
    PER_CHANNEL_AFFINE = 'per_channel_affine'
    PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric'


PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]


class QuantDtype(str, _QuantLiteralEnum):
    UINT = 'uint'
    INT = 'int'


class QuantType(str, _QuantLiteralEnum):
    INPUT = 'input'
    WEIGHT = 'weight'
    OUTPUT = 'output'

    def type_to_scale_zero_point_name(self):
        if self == QuantType.INPUT:
            return 'input_scale', 'input_zero_point'
        elif self == QuantType.WEIGHT:
            return 'weight_scale', 'weight_zero_point'
        elif self == QuantType.OUTPUT:
            return 'output_scale', 'output_zero_point'
        else:
            raise TypeError


# Just show each attribute's name, no practical effect
class QuantConfigLiteral(str, _QuantLiteralEnum):
    QUANT_SETTINGS = 'quant_settings'
    QUANT_SCHEME = 'quant_scheme'
    QUANT_DTYPE = 'quant_dtype'
    BITS = 'bits'
    QMIN = 'qmin'
    QMAX = 'qmax'
    INPUT_SCALE = 'input_scale'
    INPUT_ZERO_POINT = 'input_zero_point'
    OUTPUT_SCALE = 'output_scale'
    OUTPUT_ZERO_POINT = 'output_zero_point'
    WEIGHT_SCALE = 'weight_scale'
    WEIGHT_ZERO_POINT = 'weight_zero_point'


BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
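Because these enums subclass `str` and the metaclass overrides `__contains__`, plain strings from a user config compare and test for membership directly against the members; that is what lets `set_quant_scheme_dtype` (below) validate raw strings. A few illustrative checks that follow from the definitions above:

from nni.compression.pytorch.quantization.literal import (
    QuantDtype, QuantType, PER_CHANNEL_QUANT_SCHEME)

assert 'int' in QuantDtype and 'float' not in QuantDtype          # metaclass __contains__
assert QuantType.WEIGHT == 'weight'                               # str-valued enum member
assert 'per_channel_symmetric' in PER_CHANNEL_QUANT_SCHEME
assert QuantType.WEIGHT.type_to_scale_zero_point_name() == ('weight_scale', 'weight_zero_point')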
nni/algorithms/compression/pytorch/quantization/observers.py → nni/compression/pytorch/quantization/observers.py

 from torch.quantization import default_weight_observer, default_histogram_observer
+from torch.quantization import RecordingObserver as _RecordingObserver

-__all__ = ["default_weight_observer", "default_histogram_observer"]
+__all__ = ["default_weight_observer", "default_histogram_observer", "RecordingObserver"]
+
+
+class RecordingObserver(_RecordingObserver):
+    """
+    An extended version of PyTorch's RecordingObserver, used to record gpu tensor
+    """
+
+    def forward(self, x):
+        val = x.cpu()
+        super().forward(val)
+        return x
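The override exists so that tensors living on GPU can be recorded: the value is copied to CPU before the parent class stores it, while the original tensor is returned untouched so the forward pass is unaffected. A small sketch of reading the recording back, using the `tensor_val` attribute that `record_shape` relies on (inherited from `torch.quantization.RecordingObserver`):

import torch
from nni.compression.pytorch.quantization.observers import RecordingObserver

obs = RecordingObserver()
x = torch.randn(1, 3, 8, 8)      # would equally work for a CUDA tensor
assert obs(x) is x               # forward returns the input unchanged
recorded = obs.tensor_val        # list of recorded (CPU) tensors
print(recorded[0].shape)         # torch.Size([1, 3, 8, 8])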
nni/compression/pytorch/quantization/settings.py (new file, mode 100644)

from typing import Any, Optional

from .literal import QuantDtype, QuantType, QuantScheme
from .utils import calculate_qmin_qmax, get_bits_length


# default settings for quantization module
quant_default_settings = {
    QuantType.WEIGHT: {
        'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
        'quant_dtype': QuantDtype.UINT,
    },
    QuantType.INPUT: {
        'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
        'quant_dtype': QuantDtype.UINT
    },
    QuantType.OUTPUT: {
        'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
        'quant_dtype': QuantDtype.UINT
    }
}


class TensorQuantSetting(object):
    def __init__(self, **kwargs):
        self._fields = {}
        for k, v in kwargs.items():
            self._fields[k] = v

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self._fields[name] = val

    def __getattr__(self, name):
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name))
        return self._fields[name]

    def get_qmin_qmax(self):
        assert 'qmin' in self._fields and 'qmax' in self._fields, \
            "Can not found qmin & qmax in TensorQuantSetting"
        return self._fields['qmin'], self._fields['qmax']


class LayerQuantSetting(object):
    def __init__(self, config):
        self.input: Optional[TensorQuantSetting] = None
        self.weight: Optional[TensorQuantSetting] = None
        self.output: Optional[TensorQuantSetting] = None
        self._extra_layer_setting = {}

        for quant_type in QuantType:
            if quant_type in config.get("quant_types", []):
                setting = TensorQuantSetting()

                quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme')
                setting.quant_scheme = quant_scheme
                quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype')
                setting.quant_dtype = quant_dtype

                bits = get_bits_length(config, quant_type)
                qmin, qmax = calculate_qmin_qmax(bits, quant_dtype)
                setting.bits = bits
                setting.qmin = qmin
                setting.qmax = qmax
                setattr(self, quant_type, setting)

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_") or name in QuantType:
            super().__setattr__(name, val)
        else:
            self._extra_layer_setting[name] = val

    def __getattr__(self, name):
        if name == "_extra_layer_setting" or name not in self._extra_layer_setting:
            raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name))
        return self._extra_layer_setting[name]

    @staticmethod
    def parse_optional_config(config, quant_type, target):
        def get_config(config, quant_type, target):
            if not config.get(target):
                return None

            if isinstance(config[target], dict):
                return config[target].get(quant_type)
            else:
                return config[target]

        default_val = quant_default_settings[quant_type].get(target, None)
        config_val = get_config(config, quant_type, target)
        val = config_val if config_val else default_val
        return val


def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None):
    # todo: remove this if we convert string config to enum type.
    if isinstance(quant_type, str):
        assert quant_type in QuantType, "Wrong quant_type"
    if isinstance(new_scheme, str):
        assert new_scheme in QuantScheme, "Wrong quant_scheme"
    if isinstance(new_dtype, str):
        assert new_dtype in QuantDtype, "Wrong quant_dtype"

    # TODO: It is not a good idea to directly modify global settings. A better choice is
    # making this function an attribute function of Quantizer and call this function after
    # the quantizer is initialized. However, within current framework of quantization, if
    # we want to modify the dtype & scheme when the quantizer is initialized, we must do
    # some other things (like changing the shapes of scales and zero_points and other quantization
    # information in the subclass).
    global quant_default_settings
    if new_scheme is not None:
        quant_default_settings[quant_type]['quant_scheme'] = new_scheme
    if new_dtype is not None:
        quant_default_settings[quant_type]['quant_dtype'] = new_dtype
    return
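A rough sketch of how these pieces behave, assuming an 8-bit config like the ones in the new unit test: `LayerQuantSetting` materializes one `TensorQuantSetting` per requested quant type, filling scheme/dtype from the config or from the mutable global defaults, while `set_quant_scheme_dtype` simply rewrites those defaults for layers configured afterwards.

from nni.compression.pytorch.quantization.settings import LayerQuantSetting, set_quant_scheme_dtype

config = {
    'quant_types': ['weight'],
    'quant_bits': 8,
    'quant_dtype': 'int',
    'quant_scheme': 'per_channel_symmetric',
}
layer_setting = LayerQuantSetting(config)
print(layer_setting.weight.quant_scheme)      # per_channel_symmetric
print(layer_setting.weight.get_qmin_qmax())   # (-127, 127) for 8-bit 'int'

# change the fallback used when a config does not spell out dtype/scheme
set_quant_scheme_dtype('weight', 'per_tensor_symmetric', 'uint')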
nni/compression/pytorch/quantization/utils.py (new file, mode 100644)

import torch

from nni.common.version import TORCH_VERSION
from .literal import QuantDtype, QuantScheme, QuantType


def calculate_qmin_qmax(bits, dtype):
    if dtype == QuantDtype.INT:
        qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
    elif dtype == QuantDtype.UINT:
        qmin, qmax = 0, 2 ** bits - 1
    else:
        raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.")
    return qmin, qmax


def get_bits_length(config, quant_type):
    if isinstance(config["quant_bits"], int):
        return config["quant_bits"]
    else:
        return config["quant_bits"].get(quant_type)


def get_target_dim(quant_type, quant_scheme):
    # for weight: c_out x c_in x (h) * (w)
    # for feature maps: batch * channel * (t) * h * w
    # other type is not supported for now
    default_idx = 0 if quant_type == QuantType.WEIGHT else 1
    if is_per_channel(quant_scheme):
        target_dim = default_idx
    else:
        target_dim = None
    return target_dim


def get_min_max_value(x, quant_type, quant_scheme):
    target_dim = get_target_dim(quant_type, quant_scheme)
    if target_dim is None:
        return torch.min(x), torch.max(x)

    indices = list(range(len(x.shape)))
    assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
    del indices[target_dim]

    if TORCH_VERSION > (1, 6):
        min_val = torch.amin(x, indices, keepdims=True)
        max_val = torch.amax(x, indices, keepdims=True)
    else:
        min_val = max_val = x
        for ind in indices:
            min_val = torch.min(min_val, dim=ind, keepdim=True)[0]
            max_val = torch.max(max_val, dim=ind, keepdim=True)[0]
    return min_val, max_val


def get_mean_value(x, target_dim=None):
    if target_dim is None:
        return torch.mean(x)

    indices = list(range(len(x.shape)))
    assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
    del indices[target_dim]

    mean_val = torch.mean(x, dim=indices, keepdim=True)
    return mean_val


def is_per_channel(quant_scheme):
    if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]:
        return True
    else:
        return False


def get_quant_shape(shape, quant_type, quant_scheme):
    default_idx = 0 if quant_type == QuantType.WEIGHT else 1
    if is_per_channel(quant_scheme):
        quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
    else:
        quant_shape = []
    return quant_shape
test/ut/sdk/test_compressor_torch.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import schema
 import nni.algorithms.compression.pytorch.pruning as torch_pruner
 import nni.algorithms.compression.pytorch.quantization as torch_quantizer
+from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
 import math

@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
         model.relu = torch.nn.ReLU()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
+        dummy = torch.randn(1, 1, 28, 28)
+        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
         quantizer.compress()
         modules_to_compress = quantizer.get_modules_to_compress()
         modules_to_compress_name = [t[0].name for t in modules_to_compress]
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
         self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
         self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))

+    def test_quantization_dtype_scheme(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
+                self.bn1 = torch.nn.BatchNorm2d(2)
+
+            def forward(self, x):
+                x = self.bn1(self.conv1(x))
+                return x
+
+        dtypes = ['int', 'uint']
+        qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
+        for dtype in dtypes:
+            for qscheme in qschemes:
+                config_list = [{
+                    'quant_types': ['weight', 'input'],
+                    'quant_bits': 8,
+                    'op_types': ['Conv2d'],
+                    'quant_dtype': dtype,
+                    'quant_scheme': qscheme
+                }]
+                model = TestModel()
+                optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+                # only QAT_quantizer is supported for now
+                dummy = torch.randn(1, 1, 4, 4)
+                quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
+
+                # test layer setting
+                for layer, config in quantizer.modules_to_compress:
+                    module = layer.module
+                    name = layer.name
+                    layer_setting = module.layer_quant_setting
+                    qmin, qmax = calculate_qmin_qmax(8, dtype)
+                    all_quant_types = ['input', 'weight']
+                    for quant_type in all_quant_types:
+                        # check for settings
+                        tensor_setting = getattr(layer_setting, quant_type)
+                        self.assertTrue(tensor_setting is not None)
+                        self.assertTrue(tensor_setting.quant_scheme == qscheme)
+                        self.assertTrue(tensor_setting.quant_dtype == dtype)
+                        self.assertTrue(tensor_setting.qmin == qmin)
+                        self.assertTrue(tensor_setting.qmax == qmax)
+
+                        input_shape, output_shape = quantizer.all_shapes[name]
+                        shape = input_shape if quant_type == 'input' else module.weight.shape
+                        quant_shape = get_quant_shape(shape, quant_type, qscheme)
+                        scale_name = quant_type + '_scale'
+                        zero_point_name = quant_type + '_zero_point'
+                        scale = getattr(module, scale_name)
+                        zero_point = getattr(module, zero_point_name)
+                        self.assertTrue(list(scale.shape) == quant_shape)
+                        self.assertTrue(list(zero_point.shape) == quant_shape)
+
+                    weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
+                    if qscheme == 'per_channel_symmetric':
+                        if dtype == 'int':
+                            target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
+                            target_zero_point = torch.ones([2, 1, 1, 1]) * 0
+                        else:
+                            target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
+                            target_zero_point = torch.ones([2, 1, 1, 1]) * 127
+                    elif qscheme == 'per_tensor_symmetric':
+                        if dtype == 'int':
+                            target_scale = torch.tensor(18. / 127)
+                            target_zero_point = torch.zeros([])
+                        else:
+                            target_scale = torch.tensor(18. / 127.5)
+                            target_zero_point = torch.ones([]) * 127
+                    elif qscheme == 'per_channel_affine':
+                        min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
+                        if dtype == 'int':
+                            target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
+                            target_zero_point = -127 - torch.round(min_val / target_scale)
+                        else:
+                            target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
+                            target_zero_point = 0 - torch.round(min_val / target_scale)
+                    else:
+                        if dtype == 'int':
+                            target_scale = torch.tensor(18. / 254)
+                            target_zero_point = -127 - torch.round(0 / target_scale)
+                        else:
+                            target_scale = torch.tensor(18. / 255)
+                            target_zero_point = 0 - torch.round(0 / target_scale)
+
+                    wrapper = getattr(model, name)
+                    wrapper.module.weight = weight
+                    quantizer.quantize_weight(wrapper)
+                    self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
+                    self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))
+
+                    inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
+                    if qscheme == 'per_channel_symmetric':
+                        if dtype == 'int':
+                            target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
+                            target_zero_point = torch.ones([1, 1, 1, 1]) * 0
+                        else:
+                            target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
+                            target_zero_point = torch.ones([1, 1, 1, 1]) * 127
+                    elif qscheme == 'per_tensor_symmetric':
+                        if dtype == 'int':
+                            target_scale = torch.tensor(15. / 127)
+                            target_zero_point = torch.zeros([])
+                        else:
+                            target_scale = torch.tensor(15. / 127.5)
+                            target_zero_point = torch.ones([]) * 127
+                    elif qscheme == 'per_channel_affine':
+                        min_val = torch.tensor([0.]).view([1, 1, 1, 1])
+                        if dtype == 'int':
+                            target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
+                            target_zero_point = -127 - torch.round(min_val / target_scale)
+                        else:
+                            target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
+                            target_zero_point = 0 - torch.round(min_val / target_scale)
+                    else:
+                        if dtype == 'int':
+                            target_scale = torch.tensor(15. / 254)
+                            target_zero_point = -127 - torch.round(0 / target_scale)
+                        else:
+                            target_scale = torch.tensor(15. / 255)
+                            target_zero_point = 0 - torch.round(0 / target_scale)
+                    quantizer.quantize_input(inp, wrapper)
+                    self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
+                    self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))
+
     def test_torch_QAT_quantizer(self):
         model = TorchModel()
         config_list = [{
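Editorial aside: the expected scales and zero points asserted above follow a simple recipe from the qmin/qmax ranges; this is only the arithmetic restated, not new code.

# symmetric: scale = max(|min_val|, |max_val|) / ((qmax - qmin) / 2)
#   'int'  -> qmin, qmax = -127, 127 -> divide by 127,   zero_point = 0
#   'uint' -> qmin, qmax = 0, 255    -> divide by 127.5, zero_point = 127
# affine:    scale = (max_val - min_val) / (qmax - qmin)
#            zero_point = qmin - round(min_val / scale)
# e.g. the weight rows above have maxima 9 and 18 with the range forced to include 0:
#   per_channel_symmetric, 'int'  -> scales [9/127, 18/127], zero_points [0, 0]
#   per_channel_affine,    'uint' -> scales [9/255, 18/255], zero_points [0, 0]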
@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
         model.relu = torch.nn.ReLU()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
+        dummy = torch.randn(1, 1, 28, 28)
+        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
         quantizer.compress()

         # test quantize

@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
         weight = torch.tensor([[1, 2], [3, 5]]).float()
         model.conv2.module.weight.data = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
-        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
-        assert model.conv2.module.zero_point == 0
+        assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
+        assert model.conv2.module.weight_zero_point == 0
         quantizer.quantize_input(input, model.conv2)
-        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
-        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
+        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
+        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))

         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
         model.conv2.module.weight = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
-        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
-        assert model.conv2.module.zero_point in (42, 43)
+        assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
+        assert model.conv2.module.weight_zero_point in (42, 43)
         quantizer.quantize_input(input, model.conv2)
-        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
-        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
+        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
+        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))

         # test value of weight and bias after quantization
         weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
         weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])

@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
         # test ema
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
-        out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)
+        model.relu(x)
+        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
+        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))

         quantizer.step_with_optimizer()
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
-        out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)
+        model.relu(x)
+        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
+        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))

     def test_torch_quantizer_export(self):
         config_list_qat = [{

@@ -424,12 +551,15 @@ class CompressorTestCase(TestCase):
         }]
         config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
         quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
+        dummy = torch.randn(1, 1, 28, 28)
         for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
             model = TorchModel()
             model.relu = torch.nn.ReLU()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-            quantizer = quantize_algorithm(model, config, optimizer)
+            if quantize_algorithm == torch_quantizer.QAT_Quantizer:
+                quantizer = quantize_algorithm(model, config, optimizer, dummy)
+            else:
+                quantizer = quantize_algorithm(model, config, optimizer)
             quantizer.compress()

             x = torch.rand((1, 1, 28, 28), requires_grad=True)

@@ -461,7 +591,11 @@ class CompressorTestCase(TestCase):
         model = TorchModel().eval()
         model.relu = torch.nn.ReLU()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        quantizer = quantize_algorithm(model, configure_list, optimizer)
+        if quantize_algorithm == torch_quantizer.QAT_Quantizer:
+            dummy = torch.randn(1, 1, 28, 28)
+            quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
+        else:
+            quantizer = quantize_algorithm(model, configure_list, optimizer)
         quantizer.compress()

         if calibration_config is not None:
             quantizer.load_calibration_config(calibration_config)
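To exercise only the new test locally, something like `python -m pytest test/ut/sdk/test_compressor_torch.py -k test_quantization_dtype_scheme` (path as in this diff, filter is simply the new test's name) should pick it up.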