OpenDAS / nni / Commits / c9cd53aa

Unverified commit c9cd53aa, authored Oct 13, 2021 by chenbohua3, committed by GitHub on Oct 13, 2021:

    support dtype&scheme customization for QAT quantizer (#4137)
Parent: b0f34da1

Showing 10 changed files with 775 additions and 190 deletions (+775 -190).
    docs/en_US/Compression/CustomizeCompressor.rst                    +2    -2
    examples/model_compress/quantization/QAT_torch_quantizer.py       +29   -16
    nni/algorithms/compression/pytorch/quantization/quantizers.py     +226  -122
    nni/common/version.py                                             +7    -0
    nni/compression/pytorch/compressor.py                             +77   -31
    nni/compression/pytorch/quantization/literal.py                   +65   -0
    nni/compression/pytorch/quantization/observers.py                 +15   -0
    nni/compression/pytorch/quantization/settings.py                  +118  -0
    nni/compression/pytorch/quantization/utils.py                     +83   -0
    test/ut/sdk/test_compressor_torch.py                              +153  -19
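Taken together, the diffs below give users two ways to pick the quantization dtype and scheme for the QAT quantizer. A minimal sketch of both, assuming the post-commit API (the config keys and the set_quant_scheme_dtype signature come from the example diff below; the op names are placeholders):

    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
    from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

    # Option 1: per-layer, through the config list.
    config_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1', 'conv2'],          # placeholder op names
        'quant_dtype': 'int',                    # 'int' or 'uint'
        'quant_scheme': 'per_channel_symmetric'  # one of the four schemes defined in literal.py
    }]

    # Option 2: change the global defaults before constructing the quantizer.
    set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')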
docs/en_US/Compression/CustomizeCompressor.rst

@@ -155,7 +155,7 @@ Sometimes it's necessary for a quantization operation to have a customized backward...
     grad_output : Tensor
         gradient of the output of quantization operation
     quant_type : QuantType
-        the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
+        the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,
         you can define different behavior for different types.
     Returns
     -------
...
@@ -164,7 +164,7 @@ Sometimes it's necessary for a quantization operation to have a customized backward...
     """
     # for quant_output function, set grad to zero if the absolute value of tensor is larger than 1
-    if quant_type == QuantType.QUANT_OUTPUT:
+    if quant_type == QuantType.OUTPUT:
         grad_output[torch.abs(tensor) > 1] = 0
     return grad_output
...
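The doc snippet above is a fragment of the customized-backward pattern; for context, here is a sketch of a complete custom gradient class, using the renamed enum members and the quant_backward signature visible in this commit's ClipGrad hunk further down. Treat it as an illustration assembled from the diff, not documented API:

    import torch
    from nni.compression.pytorch.compressor import QuantGrad
    from nni.compression.pytorch.quantization.literal import QuantType

    class MyClipGrad(QuantGrad):
        @staticmethod
        def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
            # straight-through estimator, but zero the gradient of saturated outputs
            if quant_type == QuantType.OUTPUT:
                grad_output[torch.abs(tensor) > 1] = 0
            return grad_output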
examples/model_compress/quantization/QAT_torch_quantizer.py

@@ -2,11 +2,13 @@ import torch
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
+from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype
 import sys
 sys.path.append('../models')
 from mnist.naive import NaiveModel


 def train(model, device, train_loader, optimizer):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
...
@@ -68,13 +70,23 @@ def main():
     }, {
         'quant_types': ['output', 'weight', 'input'],
         'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
-        'op_names': ['fc1'],
-    }, {
-        'quant_types': ['output', 'weight', 'input'],
-        'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
-        'op_names': ['fc2'],
+        'op_names': ['fc1', 'fc2'],
     }]
+    # you can also set the quantization dtype and scheme layer-wise through configure_list like:
+    # configure_list = [{
+    #     'quant_types': ['weight', 'input'],
+    #     'quant_bits': {'weight': 8, 'input': 8},
+    #     'op_names': ['conv1', 'conv2'],
+    #     'quant_dtype': 'int',
+    #     'quant_scheme': 'per_channel_symmetric'
+    # }]
+    # For now quant_dtype's options are 'int' and 'uint'. And quant_scheme's options are per_tensor_affine,
+    # per_tensor_symmetric, per_channel_affine and per_channel_symmetric.
+    set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
+    set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
+    set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')
+
     model = NaiveModel().to(device)
     dummy_input = torch.randn(1, 1, 28, 28).to(device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
...
@@ -98,5 +110,6 @@ def main():
     calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
     print("Generated calibration config is: ", calibration_config)


 if __name__ == '__main__':
     main()
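One ordering detail worth calling out in this example: set_quant_scheme_dtype mutates the global defaults that LayerQuantSetting reads when the quantizer wraps each layer (see settings.py below), so the calls have to come before QAT_Quantizer is constructed. A hedged sketch of the required ordering (model and configure_list as in the example, the dummy_input keyword as in the test diff):

    set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')   # 1. adjust defaults first
    quantizer = QAT_Quantizer(model, configure_list, optimizer,
                              dummy_input=dummy_input)                 # 2. then build the quantizer
    quantizer.compress()   # defaults were already snapshotted into each layer's LayerQuantSetting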
nni/algorithms/compression/pytorch/quantization/quantizers.py

@@ -6,9 +6,21 @@ from collections import defaultdict
 import torch
 from schema import Schema, And, Or, Optional
 from nni.compression.pytorch.utils.config_validation import QuantizerSchema
-from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
-from .observers import default_weight_observer, default_histogram_observer
+from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad
+from nni.compression.pytorch.quantization.literal import (
+    PER_CHANNEL_QUANT_SCHEME, QuantScheme, QuantDtype, QuantType
+)
+from nni.compression.pytorch.quantization.observers import default_weight_observer, default_histogram_observer
+from nni.compression.pytorch.quantization.settings import LayerQuantSetting
+from nni.compression.pytorch.quantization.utils import (
+    calculate_qmin_qmax, get_bits_length, get_min_max_value, get_quant_shape
+)

 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer']
...
@@ -65,7 +77,7 @@ def update_ema(biased_ema, value, decay):
     return biased_ema


-def update_quantization_param(bits, rmin, rmax):
+def update_quantization_param(bits, rmin, rmax, dtype, scheme):
     """
     calculate the `zero_point` and `scale`.
...
@@ -77,41 +89,46 @@ def update_quantization_param(bits, rmin, rmax):
         min value of real value
     rmax : Tensor
         max value of real value
+    dtype : QuantDtype
+        quantized data type
+    scheme : QuantScheme
+        quantization scheme to be used

     Returns
     -------
     float, float
     """
     # extend the [min, max] interval to ensure that it contains 0.
     # Otherwise, we would not meet the requirement that 0 be an exactly
     # representable value.
-    rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device))
-    rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device))
-    qmin = torch.Tensor([0]).to(rmin.device)
-    qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device)
+    # I think this is for activations that need to be pad in the training.
+    # However this is a default behavior in PyTorch quantization observer.
+    # So we also make it a default behavior
+    rmin = torch.min(rmin, torch.zeros_like(rmin))
+    rmax = torch.max(rmax, torch.zeros_like(rmax))
+    zero_point = torch.zeros_like(rmin)

-    # First determine the scale.
-    scale = (rmax - rmin) / (qmax - qmin)
+    # todo: there is no need to calculate qmin and qmax again
+    qmin, qmax = calculate_qmin_qmax(bits, dtype)

-    # Zero-point computation.
-    initial_zero_point = qmin - rmin / scale
-
-    # Now we need to nudge the zero point to be an integer
-    if initial_zero_point < qmin:
-        nudged_zero_point = qmin
-    elif initial_zero_point > qmax:
-        nudged_zero_point = qmax
-    else:
-        nudged_zero_point = torch.round(initial_zero_point)
-
-    return scale, nudged_zero_point
+    if scheme in [QuantScheme.PER_TENSOR_SYMMETRIC, QuantScheme.PER_CHANNEL_SYMMETRIC]:
+        abs_max = torch.max(torch.abs(rmin), torch.abs(rmax))
+        scale = abs_max / (float(qmax - qmin) / 2)
+        if dtype == QuantDtype.UINT:
+            zero_point_val = (qmin + qmax) // 2
+            zero_point = zero_point.new_full(zero_point.size(), zero_point_val)
+    else:
+        scale = (rmax - rmin) / float(qmax - qmin)
+        zero_point = qmin - torch.round(rmin / scale)

+    zero_point = torch.clamp(zero_point, qmin, qmax)

-def get_bits_length(config, quant_type):
-    if isinstance(config["quant_bits"], int):
-        return config["quant_bits"]
-    else:
-        return config["quant_bits"].get(quant_type)
+    # todo: add these lines
+    # eps = torch.finfo(torch.float32).eps
+    # scale = torch.max(scale, eps)
+
+    return scale, zero_point
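To make the new dtype/scheme branches concrete, here is the arithmetic for 8-bit quantization of a tensor with rmin = 0.0 and rmax = 18.0 (the same values the new unit test further down asserts against). This is plain arithmetic following the function above, not an NNI API call:

    # dtype='int'  -> qmin, qmax = -127, 127   (see calculate_qmin_qmax in utils.py below)
    # dtype='uint' -> qmin, qmax = 0, 255

    # symmetric schemes: scale from the absolute max, zero point centered
    abs_max = max(abs(0.0), abs(18.0))                # 18.0
    scale_int_sym = abs_max / ((127 - (-127)) / 2)    # 18 / 127
    scale_uint_sym = abs_max / ((255 - 0) / 2)        # 18 / 127.5
    zp_int_sym = 0                                    # 'int' keeps the zero point at 0
    zp_uint_sym = (0 + 255) // 2                      # 127

    # affine schemes: scale from the full range, zero point derived from rmin
    scale_int_aff = (18.0 - 0.0) / float(127 - (-127))  # 18 / 254
    zp_int_aff = -127 - round(0.0 / scale_int_aff)      # -127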
...
@@ -384,22 +401,49 @@ class QAT_Quantizer(Quantizer):
 class QATGrad(QuantGrad):
     @staticmethod
...
         self.bound_model.register_buffer("steps", torch.tensor(1))
         for layer, config in modules_to_compress:
             module = layer.module
-            module.register_buffer("zero_point", torch.tensor([0.0]))
-            module.register_buffer("scale", torch.tensor([1.0]))
-            module.register_buffer('ema_decay', torch.tensor([0.99]))
+            name = layer.name
+            # TODO: may relax this limitation?
+            assert name in self.all_shapes, "Could not found shapes for layer {}".format(name)
+            input_shape, output_shape = self.all_shapes[name]
+            layer_quant_setting = LayerQuantSetting(config)
+            layer_quant_setting.ema_decay = 0.99
+            quant_start_step = config.get('quant_start_step', 0)
+            layer_quant_setting.quant_start_step = quant_start_step
+            # todo: support other ranks and remove this check
+            if isinstance(module, torch.nn.Linear):
+                if "input" in config.get("quant_types", []) and \
+                        layer_quant_setting.input.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
+                    if len(input_shape) != 2:
+                        logger.warning("When quantize torch.nn.Linear, make sure that the rank of the inputs "
+                                       "of the layer is 2. Skip quantization of layer %s.", name)
+                        continue
+                if "output" in config.get("quant_types", []) and \
+                        layer_quant_setting.output.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
+                    if len(output_shape) != 2:
+                        logger.warning("When quantize torch.nn.Linear, make sure that the rank of the outputs "
+                                       "of the layer is 2. Skip quantization of layer %s.", name)
+                        continue
             if "weight" in config.get("quant_types", []):
-                weight_bits = get_bits_length(config, 'weight')
-                layer.module.register_buffer('weight_bits', torch.Tensor([int(weight_bits)]))
+                quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT, layer_quant_setting.weight.quant_scheme)
+                module.register_buffer('weight_scale', torch.zeros(quant_shape))
+                module.register_buffer('weight_zero_point', torch.zeros(quant_shape))
             if "input" in config.get("quant_types", []):
-                input_bits = get_bits_length(config, 'input')
-                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
-                layer.module.register_buffer('input_bits', torch.Tensor([int(input_bits)]))
+                quant_shape = get_quant_shape(input_shape, QuantType.INPUT, layer_quant_setting.input.quant_scheme)
+                module.register_buffer('tracked_min_input', torch.zeros(quant_shape))
+                module.register_buffer('tracked_max_input', torch.zeros(quant_shape))
+                module.register_buffer('input_scale', torch.zeros(quant_shape))
+                module.register_buffer('input_zero_point', torch.zeros(quant_shape))
             if "output" in config.get("quant_types", []):
-                output_bits = get_bits_length(config, 'output')
-                layer.module.register_buffer('output_bits', torch.Tensor([int(output_bits)]))
-                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
+                quant_shape = get_quant_shape(output_shape, QuantType.OUTPUT, layer_quant_setting.output.quant_scheme)
+                module.register_buffer('tracked_min_output', torch.zeros(quant_shape))
+                module.register_buffer('tracked_max_output', torch.zeros(quant_shape))
+                module.register_buffer('output_scale', torch.zeros(quant_shape))
+                module.register_buffer('output_zero_point', torch.zeros(quant_shape))
+            setattr(module, "layer_quant_setting", layer_quant_setting)
         self.bound_model.to(device)
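The shape of every scale/zero-point buffer registered above comes from get_quant_shape (new in utils.py below): per-tensor schemes get a scalar buffer, per-channel schemes keep only the channel dimension. A sketch of what a Conv2d(1, 2, 3) quantized per-channel on weight and input ends up with (shapes as in the new unit test at the bottom of this commit):

    module.weight_scale.shape        # torch.Size([2, 1, 1, 1]): one scale per output channel
    module.input_scale.shape         # torch.Size([1, 1, 1, 1]): per input channel (dim 1)
    module.tracked_min_input.shape   # same shape: min/max statistics are now per channel too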
@@ -407,8 +451,9 @@ class QAT_Quantizer(Quantizer):
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
-                         'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
-                         'output_bits', 'BN_FOLD_TAG', 'input_bits']
+        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
+                         'tracked_min_input', 'tracked_max_input', 'BN_FOLD_TAG',
+                         'weight_scale', 'weight_zero_point', 'input_scale', 'input_zero_point',
+                         'output_scale', 'output_zero_point', 'layer_quant_setting']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
...
@@ -422,6 +467,7 @@ class QAT_Quantizer(Quantizer):
         config_list : list of dict
             List of configurations
         """
+        SUPPORTED_OPS = ['Conv2d', 'Linear', 'ReLU', 'ReLU6']
         schema = QuantizerSchema([{
             Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
             Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
...
@@ -429,41 +475,51 @@ class QAT_Quantizer(Quantizer):
                 Optional('weight'): And(int, lambda n: 0 < n < 32),
                 Optional('output'): And(int, lambda n: 0 < n < 32),
             })),
+            Optional('quant_scheme'): Or(lambda x: x in QuantScheme, Schema({
+                Optional('input'): lambda x: x in QuantScheme,
+                Optional('weight'): lambda x: x in QuantScheme,
+                Optional('output'): lambda x: x in QuantScheme
+            })),
+            Optional('quant_dtype'): Or(lambda x: x in QuantDtype, Schema({
+                Optional('input'): lambda x: x in QuantDtype,
+                Optional('weight'): lambda x: x in QuantDtype,
+                Optional('output'): lambda x: x in QuantDtype
+            })),
             Optional('quant_start_step'): And(int, lambda n: n >= 0),
-            Optional('op_types'): [str],
+            Optional('op_types'): [And(str, lambda n: n in SUPPORTED_OPS)],
             Optional('op_names'): [str],
             Optional('exclude'): bool
         }], model, logger)

         schema.validate(config_list)
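Note that the schema accepts quant_scheme and quant_dtype either as a single value or as a per-tensor-type dict, mirroring how quant_bits already works. A config sketch of the dict form the validator accepts (op types are constrained to SUPPORTED_OPS now; this only shows what passes validation):

    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        # dict form: a different scheme/dtype per tensor type
        'quant_scheme': {'weight': 'per_channel_symmetric', 'output': 'per_tensor_affine'},
        'quant_dtype': {'weight': 'int', 'output': 'uint'},
        'op_types': ['Conv2d'],
    }]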
-    def _quantize(self, bits, op, real_val):
+    def _quantize(self, real_value, scale, zero_point, qmin, qmax):
         """
         quantize real value.

         Parameters
         ----------
-        bits : int
-            quantization bits length
-        op : torch.nn.Module
-            target module
-        real_val : Tensor
-            real value to be quantized
+        real_value : torch.Tensor
+            the real value to be quantized
+        scale : torch.Tensor
+            quantization scale
+        zero_point : torch.Tensor
+            quantization zero point
+        qmin : int
+            lower bound of the int range
+        qmax : int
+            upper bound of the int range

         Returns
         -------
         Tensor
         """
-        op.zero_point = op.zero_point.to(real_val.device)
-        op.scale = op.scale.to(real_val.device)
-        transformed_val = op.zero_point + real_val / op.scale
-        qmin = 0
-        qmax = (1 << bits) - 1
+        transformed_val = zero_point + real_value / scale
         clamped_val = torch.clamp(transformed_val, qmin, qmax)
         quantized_val = torch.round(clamped_val)
         return quantized_val

-    def _dequantize(self, op, quantized_val):
+    def _dequantize(self, quantized_val, scale, zero_point):
         """
         dequantize quantized value.
         Because we simulate quantization in training process, all the computations still happen as float point computations, which means we
...
@@ -471,103 +527,149 @@ class QAT_Quantizer(Quantizer):
         Parameters
         ----------
-        op : torch.nn.Module
-            target module
-        quantized_val : float
-            quantized_val value to be dequantized
+        quantized_val : torch.Tensor
+            the quantized value to be de-quantized
+        scale : torch.Tensor
+            quantization scale
+        zero_point : torch.Tensor
+            quantization zero point

         Returns
         -------
-        float
+        Tensor
         """
-        real_val = op.scale * (quantized_val - op.zero_point)
+        real_val = scale * (quantized_val - zero_point)
         return real_val
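The refactored helpers make the simulated-quantization round trip explicit: _quantize maps a real value onto the integer grid, _dequantize maps it back, and the difference is the quantization error. A small self-contained sketch of the same formulas in plain PyTorch (not calling the private methods):

    import torch

    scale, zero_point = torch.tensor(18. / 254), torch.tensor(-127.)
    qmin, qmax = -127, 127
    real = torch.tensor([0.0, 4.3, 18.0])
    q = torch.round(torch.clamp(zero_point + real / scale, qmin, qmax))  # _quantize
    deq = scale * (q - zero_point)                                       # _dequantize
    # deq is `real` rounded onto the 8-bit grid; the worst-case error is scale / 2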
     def quantize_weight(self, wrapper, **kwargs):
-        config = wrapper.config
         module = wrapper.module
         weight = module.weight
-        weight_bits = int(module.weight_bits)
-        quant_start_step = config.get('quant_start_step', 0)
-        assert weight_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > int(self.bound_model.steps):
-            return weight
+        layer_quant_setting = module.layer_quant_setting
+        tensor_quant_setting = layer_quant_setting.weight
+
+        # layer-wise settings
+        quant_start_step = layer_quant_setting.quant_start_step
+
+        # tensor-wise settings
+        dtype = tensor_quant_setting.quant_dtype
+        scheme = tensor_quant_setting.quant_scheme
+        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
+        bits = tensor_quant_setting.bits

+        # In evaluation mode, we only quantize weight without updating statistics
         if not wrapper.training:
+            scale, zero_point = module.weight_scale, module.weight_zero_point
+            weight = self._quantize(weight, scale, zero_point, qmin, qmax)
+            weight = self._dequantize(weight, scale, zero_point)
+            module.weight = weight
             return weight

+        if quant_start_step > int(self.bound_model.steps):
+            return weight
+
         # quantize weight
-        rmin, rmax = torch.min(weight), torch.max(weight)
-        scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        module.scale.copy_(scale)
-        module.zero_point.copy_(zero_point)
-        weight = self._quantize(weight_bits, module, weight)
-        weight = self._dequantize(module, weight)
+        current_min, current_max = get_min_max_value(weight, QuantType.WEIGHT, scheme)
+        scale, zero_point = update_quantization_param(bits, current_min, current_max, dtype, scheme)
+        module.weight_scale.copy_(scale)
+        module.weight_zero_point.copy_(zero_point)
+        weight = self._quantize(weight, scale, zero_point, qmin, qmax)
+        weight = self._dequantize(weight, scale, zero_point)
         # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
         # will be lost after each forward process. However, this update takes effect on each
         # replicated module during each forward process, which will make the quantized weight
         # be used correctly.
         wrapper.module.weight = weight
         return weight
     def quantize_input(self, inputs, wrapper, **kwargs):
-        config = wrapper.config
         module = wrapper.module
-        input_bits = int(module.input_bits)
-        quant_start_step = config.get('quant_start_step', 0)
-        assert input_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > int(self.bound_model.steps):
-            current_min, current_max = torch.min(inputs), torch.max(inputs)
-            module.tracked_min_input.copy_(current_min)
-            module.tracked_max_input.copy_(current_max)
-            return inputs
+        layer_quant_setting = module.layer_quant_setting
+        tensor_quant_setting = layer_quant_setting.input
+
+        # layer-wise settings
+        quant_start_step = layer_quant_setting.quant_start_step
+        ema_decay = layer_quant_setting.ema_decay
+
+        # tensor-wise settings
+        dtype = tensor_quant_setting.quant_dtype
+        scheme = tensor_quant_setting.quant_scheme
+        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
+        bits = tensor_quant_setting.bits
+
+        if not wrapper.training:
+            scale = module.input_scale
+            zero_point = module.input_zero_point
+            inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
+            inputs = self._dequantize(inputs, scale, zero_point)
+            return inputs

-        # we dont update output quantization parameters in evaluation stage
-        if wrapper.training:
-            current_min, current_max = torch.min(inputs), torch.max(inputs)
-            current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
-            current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
-            module.tracked_min_input.copy_(current_min)
-            module.tracked_max_input.copy_(current_max)
+        current_min, current_max = get_min_max_value(inputs, QuantType.INPUT, scheme)
+
+        if int(self.bound_model.steps) == 1:
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
+
+        tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay)
+        tracked_max_input = update_ema(module.tracked_max_input, current_max, ema_decay)
+        module.tracked_min_input.copy_(tracked_min_input)
+        module.tracked_max_input.copy_(tracked_max_input)
+
+        if quant_start_step > int(self.bound_model.steps):
+            return inputs

-        scale, zero_point = update_quantization_param(
-            input_bits, module.tracked_min_input, module.tracked_max_input)
-        module.scale.copy_(scale)
-        module.zero_point.copy_(zero_point)
-        inp = self._quantize(input_bits, module, inputs)
-        inp = self._dequantize(module, inp)
-        return inp
+        scale, zero_point = update_quantization_param(
+            bits, module.tracked_min_input, module.tracked_max_input, dtype, scheme)
+        module.input_scale.copy_(scale)
+        module.input_zero_point.copy_(zero_point)
+        inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
+        inputs = self._dequantize(inputs, scale, zero_point)
+        return inputs
     def quantize_output(self, output, wrapper, **kwargs):
-        config = wrapper.config
         module = wrapper.module
-        output_bits = int(module.output_bits)
-        quant_start_step = config.get('quant_start_step', 0)
-        assert output_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > int(self.bound_model.steps):
-            current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_output.copy_(current_min)
-            module.tracked_max_output.copy_(current_max)
-            return output
+        layer_quant_setting = module.layer_quant_setting
+        tensor_quant_setting = layer_quant_setting.output
+
+        # layer-wise settings
+        quant_start_step = layer_quant_setting.quant_start_step
+        ema_decay = layer_quant_setting.ema_decay
+
+        # tensor-wise settings
+        dtype = tensor_quant_setting.quant_dtype
+        scheme = tensor_quant_setting.quant_scheme
+        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
+        bits = tensor_quant_setting.bits
+
+        if not wrapper.training:
+            scale = module.output_scale
+            zero_point = module.output_zero_point
+            output = self._quantize(output, scale, zero_point, qmin, qmax)
+            output = self._dequantize(output, scale, zero_point)
+            return output

-        # we dont update output quantization parameters in evaluation stage
-        if wrapper.training:
-            current_min, current_max = torch.min(output), torch.max(output)
-            tracked_min_output = update_ema(module.tracked_min_output, current_min, module.ema_decay)
-            tracked_max_output = update_ema(module.tracked_max_output, current_max, module.ema_decay)
-            module.tracked_min_output.copy_(tracked_min_output)
-            module.tracked_max_output.copy_(tracked_max_output)
+        current_min, current_max = get_min_max_value(output, QuantType.OUTPUT, scheme)
+
+        if int(self.bound_model.steps) == 1:
+            module.tracked_min_output.copy_(current_min)
+            module.tracked_max_output.copy_(current_max)
+
+        tracked_min_output = update_ema(module.tracked_min_output, current_min, ema_decay)
+        tracked_max_output = update_ema(module.tracked_max_output, current_max, ema_decay)
+        module.tracked_min_output.copy_(tracked_min_output)
+        module.tracked_max_output.copy_(tracked_max_output)
+
+        if quant_start_step > int(self.bound_model.steps):
+            return output

-        scale, zero_point = update_quantization_param(
-            output_bits, module.tracked_min_output, module.tracked_max_output)
-        module.scale.copy_(scale)
-        module.zero_point.copy_(zero_point)
-        out = self._quantize(output_bits, module, output)
-        out = self._dequantize(module, out)
-        return out
+        scale, zero_point = update_quantization_param(
+            bits, module.tracked_min_output, module.tracked_max_output, dtype, scheme)
+        module.output_scale.copy_(scale)
+        module.output_zero_point.copy_(zero_point)
+        output = self._quantize(output, scale, zero_point, qmin, qmax)
+        output = self._dequantize(output, scale, zero_point)
+        return output
     def load_calibration_config(self, calibration_config):
         modules_to_compress = self.get_modules_to_compress()
...
@@ -581,12 +683,12 @@ class QAT_Quantizer(Quantizer):
             assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
         if hasattr(module, 'input_bits'):
             assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
-            module.tracked_min_input.data = torch.Tensor([calibration_config[name]['tracked_min_input']])
-            module.tracked_max_input.data = torch.Tensor([calibration_config[name]['tracked_max_input']])
+            module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
+            module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
         if hasattr(module, 'output_bits'):
             assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
-            module.tracked_min_output.data = torch.Tensor([calibration_config[name]['tracked_min_output']])
-            module.tracked_max_output.data = torch.Tensor([calibration_config[name]['tracked_max_output']])
+            module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
+            module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
...
@@ -619,6 +721,8 @@ class QAT_Quantizer(Quantizer):
             calibration_config[name] = {}
             if hasattr(module, 'weight_bits'):
                 calibration_config[name]['weight_bits'] = int(module.weight_bits)
+                calibration_config[name]['weight_scale'] = module.weight_scale
+                calibration_config[name]['weight_zero_point'] = module.weight_zero_point
                 # Recover weight/bias for batch normalization folding
                 actual_weight = getattr(module, 'old_weight', None)
...
@@ -759,7 +863,7 @@ class DoReFaQuantizer(Quantizer):
 class ClipGrad(QuantGrad):
     @staticmethod
     def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
-        if quant_type == QuantType.QUANT_OUTPUT:
+        if quant_type == QuantType.OUTPUT:
             grad_output[torch.abs(tensor) > 1] = 0
         return grad_output
...
nni/common/version.py (new file, 0 → 100644)

import logging

try:
    import torch
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
except Exception:
    logging.info("PyTorch is not installed.")
    TORCH_VERSION = None
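TORCH_VERSION exists so callers can branch on the installed PyTorch version; utils.py below uses it to pick torch.amin/amax (available from PyTorch 1.7) over a manual per-dimension reduction. A hedged usage sketch, with an explicit None guard that the utility itself does not carry:

    import torch
    from nni.common.version import TORCH_VERSION

    if TORCH_VERSION is not None and TORCH_VERSION > (1, 6):
        reduce_min = torch.amin   # multi-dim reduction in one call
    else:
        reduce_min = None         # fall back to repeated torch.min, as get_min_max_value does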
nni/compression/pytorch/compressor.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import copy
 import types
 import logging
 import torch

 from nni.common.graph_utils import build_module_graph
+from nni.compression.pytorch.quantization.literal import QuantType, BN_FOLD_OP, BN_FOLD_TAG
+from nni.compression.pytorch.quantization.observers import RecordingObserver
 from . import default_layers

 _logger = logging.getLogger(__name__)
...
@@ -547,7 +550,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             assert len(inputs) == 1, "Quantization of input only supports ops with single input."
             new_inp = self.quantizer.quant_grad(
                 inputs[0],
-                QuantType.QUANT_INPUT,
+                QuantType.INPUT,
                 self)
             inputs = (new_inp,)
...
@@ -563,7 +566,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
             self.quantizer.quant_grad(
                 new_weight,
-                QuantType.QUANT_WEIGHT,
+                QuantType.WEIGHT,
                 self, inputs[0])
             result = self.module(*inputs)
...
@@ -571,7 +574,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
         if 'output' in self.config['quant_types']:
             result = self.quantizer.quant_grad(
                 result,
-                QuantType.QUANT_OUTPUT,
+                QuantType.OUTPUT,
                 self)
         return result
...
@@ -604,10 +607,13 @@ class Quantizer(Compressor):
     def __init__(self, model, config_list, optimizer=None, dummy_input=None):
         if isinstance(model, torch.nn.DataParallel):
             model = model.module
+        model_copied = copy.deepcopy(model)
         self.identity_wrappers = []
         self.conv_bn_patterns = {}
         self.find_conv_bn_patterns(model, dummy_input)
         super().__init__(model, config_list, optimizer)
+        self.all_shapes = {}
+        self.record_shape(model_copied, dummy_input)
         self.quant_grad = QuantGrad.apply
         if self.optimizer is not None:
             self.patch_optimizer(self.step_with_optimizer)
...
@@ -845,25 +851,54 @@ class Quantizer(Compressor):
                 if successor.op_type == 'BatchNorm2d':
                     self.conv_bn_patterns[node_group.name] = successor.name

+    def record_shape(self, model, dummy_input):
+        """
+        Record input/output's shapes of each module to be quantized
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            model to be recorded.
+        dummy_input : tuple of torch.tensor
+            inputs to the model.
+        """
+        def _pre_forward_hook(self, inp):
+            # Only record the first tensor of the input
+            return self.pre_forward(inp[0])
+
+        def _post_forward_hook(self, _, out):
+            return self.post_forward(out)
+
+        if dummy_input is None:
+            return
+
+        all_handles = []
+        all_observers = {}
+        modules_to_compress = self.get_modules_to_compress()
+        compress_names = [layer_info[0].name for layer_info in modules_to_compress]
+        for name, module in model.named_modules():
+            if name in compress_names:
+                all_observers[name] = {}
+                all_observers[name]['input_hook'] = RecordingObserver()
+                all_observers[name]['output_hook'] = RecordingObserver()
+                module.add_module('pre_forward', all_observers[name]['input_hook'])
+                module.add_module('post_forward', all_observers[name]['output_hook'])
+                all_handles.append(module.register_forward_pre_hook(_pre_forward_hook))
+                all_handles.append(module.register_forward_hook(_post_forward_hook))
+        model(dummy_input)
+        for name, hooks in all_observers.items():
+            # only support single input
+            input_val = hooks['input_hook'].tensor_val
+            input_shape = input_val[0].shape if input_val else None
+            output_val = hooks['output_hook'].tensor_val
+            output_shape = output_val[0].shape if output_val else None
+            shapes = [input_shape, output_shape]
+            self.all_shapes[name] = shapes
+        return
+
     def step_with_optimizer(self):
         pass

-class QuantType:
-    """
-    Enum class for quantization type.
-    """
-    QUANT_INPUT = 0
-    QUANT_WEIGHT = 1
-    QUANT_OUTPUT = 2
-
-QType_Dict = {
-    0: "input",
-    1: "weight",
-    2: "output"
-}
-
-BN_FOLD_OP = ["Conv2d"]
-BN_FOLD_TAG = 'BN_FOLD_TAG'
...
 class QuantGrad(torch.autograd.Function):
     """
...
@@ -920,8 +955,8 @@ class QuantGrad(torch.autograd.Function):
         grad_output : Tensor
             gradient of the output of quantization operation
         scale : Tensor
-            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`,
-            `QuantType.QUANT_OUTPUT`, you can define different behavior for different types.
+            the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`,
+            `QuantType.OUTPUT`, you can define different behavior for different types.
         zero_point : Tensor
             zero_point for quantizing tensor
         qmin : Tensor
...
@@ -939,28 +974,39 @@ class QuantGrad(torch.autograd.Function):
     def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
         output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs)

-        bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
-        qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
-        if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
+        if hasattr(wrapper.module, "layer_quant_setting"):
+            layer_quant_setting = wrapper.module.layer_quant_setting
+            qmin, qmax = getattr(layer_quant_setting, quant_type).get_qmin_qmax()
+        else:
+            # todo: when dtype/scheme customization is ready for all quantizers, remove this
+            bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
+            qmin, qmax = 0, (1 << bits) - 1
+
+        scale_name, zero_point_name = quant_type.type_to_scale_zero_point_name()
+        if hasattr(wrapper.module, scale_name) and hasattr(wrapper.module, zero_point_name):
+            scale = getattr(wrapper.module, scale_name)
+            zero_point = getattr(wrapper.module, zero_point_name)
+        # todo: remove this when other quantizers use different scale & zero point for input/weight/output
+        elif hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
             scale = wrapper.module.scale
             zero_point = wrapper.module.zero_point
         else:
             scale, zero_point = None, None
-        ctx.save_for_backward(tensor)
-        ctx.quant_type = quant_type
-        ctx.qmin, ctx.qmax = qmin, qmax
+
+        # Only tensors have gradients flowing back needs to be saved by save_for_backward.
+        # Others should directly assign to ctx.
+        ctx.save_for_backward(tensor)
         ctx.scale = scale
         ctx.zero_point = zero_point
+        ctx.quant_type = quant_type
+        ctx.qmin, ctx.qmax = qmin, qmax
         return output

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor = ctx.saved_variables[0]
         scale, zero_point = ctx.scale, ctx.zero_point
-        qmin, qmax = ctx.qmin, ctx.qmax
         quant_type = ctx.quant_type
+        qmin, qmax = ctx.qmin, ctx.qmax
         output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
...
@@ -977,11 +1023,11 @@ def _check_bias(module):
         return False

 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
-    if quant_type == QuantType.QUANT_INPUT:
+    if quant_type == QuantType.INPUT:
         output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
-    elif quant_type == QuantType.QUANT_WEIGHT:
+    elif quant_type == QuantType.WEIGHT:
         output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
-    elif quant_type == QuantType.QUANT_OUTPUT:
+    elif quant_type == QuantType.OUTPUT:
         output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
     else:
         raise ValueError("unrecognized QuantType.")
...
nni/compression/pytorch/quantization/literal.py (new file, 0 → 100644)

from enum import Enum, EnumMeta


class _QuantLiteralEnumMeta(EnumMeta):
    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True


class _QuantLiteralEnum(Enum, metaclass=_QuantLiteralEnumMeta):
    pass


class QuantScheme(str, _QuantLiteralEnum):
    PER_TENSOR_AFFINE = 'per_tensor_affine'
    PER_TENSOR_SYMMETRIC = 'per_tensor_symmetric'
    PER_CHANNEL_AFFINE = 'per_channel_affine'
    PER_CHANNEL_SYMMETRIC = 'per_channel_symmetric'


PER_CHANNEL_QUANT_SCHEME = [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]


class QuantDtype(str, _QuantLiteralEnum):
    UINT = 'uint'
    INT = 'int'


class QuantType(str, _QuantLiteralEnum):
    INPUT = 'input'
    WEIGHT = 'weight'
    OUTPUT = 'output'

    def type_to_scale_zero_point_name(self):
        if self == QuantType.INPUT:
            return 'input_scale', 'input_zero_point'
        elif self == QuantType.WEIGHT:
            return 'weight_scale', 'weight_zero_point'
        elif self == QuantType.OUTPUT:
            return 'output_scale', 'output_zero_point'
        else:
            raise TypeError


# Just show each attribute's name, no practical effect
class QuantConfigLiteral(str, _QuantLiteralEnum):
    QUANT_SETTINGS = 'quant_settings'
    QUANT_SCHEME = 'quant_scheme'
    QUANT_DTYPE = 'quant_dtype'
    BITS = 'bits'
    QMIN = 'qmin'
    QMAX = 'qmax'
    INPUT_SCALE = 'input_scale'
    INPUT_ZERO_POINT = 'input_zero_point'
    OUTPUT_SCALE = 'output_scale'
    OUTPUT_ZERO_POINT = 'output_zero_point'
    WEIGHT_SCALE = 'weight_scale'
    WEIGHT_ZERO_POINT = 'weight_zero_point'


BN_FOLD_OP = ["Conv2d"]
BN_FOLD_TAG = 'BN_FOLD_TAG'
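The custom metaclass is what lets plain strings from user configs be checked directly against these enums with `in`, and the str mixin makes members compare equal to their literal values. A small sketch of both behaviors:

    from nni.compression.pytorch.quantization.literal import QuantType, QuantScheme

    'input' in QuantType                                   # True, via _QuantLiteralEnumMeta.__contains__
    'int8' in QuantType                                    # False: QuantType('int8') raises ValueError
    QuantScheme.PER_TENSOR_AFFINE == 'per_tensor_affine'   # True, thanks to the str mixin
    QuantType.WEIGHT.type_to_scale_zero_point_name()       # ('weight_scale', 'weight_zero_point')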
nni/algorithms/compression/pytorch/quantization/observers.py → nni/compression/pytorch/quantization/observers.py

 from torch.quantization import default_weight_observer, default_histogram_observer
+from torch.quantization import RecordingObserver as _RecordingObserver

-__all__ = ["default_weight_observer", "default_histogram_observer"]
+__all__ = ["default_weight_observer", "default_histogram_observer", "RecordingObserver"]
+
+
+class RecordingObserver(_RecordingObserver):
+    """
+    An extended version of PyTorch's RecordingObserver, used to record GPU tensors.
+    """
+    def forward(self, x):
+        val = x.cpu()
+        super().forward(val)
+        return x
nni/compression/pytorch/quantization/settings.py (new file, 0 → 100644)

from typing import Any, Optional

from .literal import QuantDtype, QuantType, QuantScheme
from .utils import calculate_qmin_qmax, get_bits_length


# default settings for quantization module
quant_default_settings = {
    QuantType.WEIGHT: {
        'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
        'quant_dtype': QuantDtype.UINT,
    },
    QuantType.INPUT: {
        'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
        'quant_dtype': QuantDtype.UINT
    },
    QuantType.OUTPUT: {
        'quant_scheme': QuantScheme.PER_TENSOR_AFFINE,
        'quant_dtype': QuantDtype.UINT
    }
}


class TensorQuantSetting(object):
    def __init__(self, **kwargs):
        self._fields = {}
        for k, v in kwargs.items():
            self._fields[k] = v

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self._fields[name] = val

    def __getattr__(self, name):
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find {} in TensorQuantSetting!".format(name))
        return self._fields[name]

    def get_qmin_qmax(self):
        assert 'qmin' in self._fields and 'qmax' in self._fields, \
            "Can not found qmin & qmax in TensorQuantSetting"
        return self._fields['qmin'], self._fields['qmax']


class LayerQuantSetting(object):
    def __init__(self, config):
        self.input: Optional[TensorQuantSetting] = None
        self.weight: Optional[TensorQuantSetting] = None
        self.output: Optional[TensorQuantSetting] = None
        self._extra_layer_setting = {}

        for quant_type in QuantType:
            if quant_type in config.get("quant_types", []):
                setting = TensorQuantSetting()

                quant_scheme = self.parse_optional_config(config, quant_type, 'quant_scheme')
                setting.quant_scheme = quant_scheme
                quant_dtype = self.parse_optional_config(config, quant_type, 'quant_dtype')
                setting.quant_dtype = quant_dtype

                bits = get_bits_length(config, quant_type)
                qmin, qmax = calculate_qmin_qmax(bits, quant_dtype)
                setting.bits = bits
                setting.qmin = qmin
                setting.qmax = qmax
                setattr(self, quant_type, setting)

    def __setattr__(self, name: str, val: Any) -> None:
        if name.startswith("_") or name in QuantType:
            super().__setattr__(name, val)
        else:
            self._extra_layer_setting[name] = val

    def __getattr__(self, name):
        if name == "_extra_layer_setting" or name not in self._extra_layer_setting:
            raise AttributeError("Cannot find {} in LayerQuantSetting!".format(name))
        return self._extra_layer_setting[name]

    @staticmethod
    def parse_optional_config(config, quant_type, target):
        def get_config(config, quant_type, target):
            if not config.get(target):
                return None

            if isinstance(config[target], dict):
                return config[target].get(quant_type)
            else:
                return config[target]

        default_val = quant_default_settings[quant_type].get(target, None)
        config_val = get_config(config, quant_type, target)
        val = config_val if config_val else default_val
        return val


def set_quant_scheme_dtype(quant_type, new_scheme=None, new_dtype=None):
    # todo: remove this if we convert string config to enum type.
    if isinstance(quant_type, str):
        assert quant_type in QuantType, "Wrong quant_type"
    if isinstance(new_scheme, str):
        assert new_scheme in QuantScheme, "Wrong quant_scheme"
    if isinstance(new_dtype, str):
        assert new_dtype in QuantDtype, "Wrong quant_dtype"

    # TODO: It is not a good idea to directly modify global settings. A better choice is
    # making this function an attribute function of Quantizer and call this function after
    # the quantizer is initialized. However, within current framework of quantization, if
    # we want to modify the dtype & scheme when the quantizer is initialized, we must do
    # some other things (like changing the shapes of scales and zero_points and other quantization
    # information in the subclass).
    global quant_default_settings
    if new_scheme is not None:
        quant_default_settings[quant_type]['quant_scheme'] = new_scheme
    if new_dtype is not None:
        quant_default_settings[quant_type]['quant_dtype'] = new_dtype
    return
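To see how these pieces compose: LayerQuantSetting resolves each tensor type's scheme and dtype from the config, falling back to the global defaults, and precomputes bits/qmin/qmax. A hedged sketch of constructing one by hand (config keys as defined above):

    from nni.compression.pytorch.quantization.settings import LayerQuantSetting

    config = {'quant_types': ['weight'], 'quant_bits': 8,
              'quant_dtype': 'int', 'quant_scheme': 'per_channel_symmetric'}
    setting = LayerQuantSetting(config)
    setting.weight.quant_dtype      # 'int' (the config value, untouched)
    setting.weight.get_qmin_qmax()  # (-127, 127) for 8-bit 'int', per calculate_qmin_qmax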
nni/compression/pytorch/quantization/utils.py (new file, 0 → 100644)

import torch

from nni.common.version import TORCH_VERSION
from .literal import QuantDtype, QuantScheme, QuantType


def calculate_qmin_qmax(bits, dtype):
    if dtype == QuantDtype.INT:
        qmin, qmax = -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1
    elif dtype == QuantDtype.UINT:
        qmin, qmax = 0, 2 ** bits - 1
    else:
        raise TypeError("Wrong quantization dtype, please make sure it is one of 'int' and 'uint'.")
    return qmin, qmax


def get_bits_length(config, quant_type):
    if isinstance(config["quant_bits"], int):
        return config["quant_bits"]
    else:
        return config["quant_bits"].get(quant_type)


def get_target_dim(quant_type, quant_scheme):
    # for weight: c_out x c_in x (h) x (w)
    # for feature maps: batch * channel * (t) * h * w
    # other type is not supported for now
    default_idx = 0 if quant_type == QuantType.WEIGHT else 1
    if is_per_channel(quant_scheme):
        target_dim = default_idx
    else:
        target_dim = None
    return target_dim


def get_min_max_value(x, quant_type, quant_scheme):
    target_dim = get_target_dim(quant_type, quant_scheme)
    if target_dim is None:
        return torch.min(x), torch.max(x)

    indices = list(range(len(x.shape)))
    assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
    del indices[target_dim]

    if TORCH_VERSION > (1, 6):
        min_val = torch.amin(x, indices, keepdims=True)
        max_val = torch.amax(x, indices, keepdims=True)
    else:
        min_val = max_val = x
        for ind in indices:
            min_val = torch.min(min_val, dim=ind, keepdim=True)[0]
            max_val = torch.max(max_val, dim=ind, keepdim=True)[0]
    return min_val, max_val


def get_mean_value(x, target_dim=None):
    if target_dim is None:
        return torch.mean(x)

    indices = list(range(len(x.shape)))
    assert target_dim < len(indices), "target_dim needs to be less than the number of dim of the tensor"
    del indices[target_dim]

    mean_val = torch.mean(x, dim=indices, keepdim=True)
    return mean_val


def is_per_channel(quant_scheme):
    if quant_scheme in [QuantScheme.PER_CHANNEL_AFFINE, QuantScheme.PER_CHANNEL_SYMMETRIC]:
        return True
    else:
        return False


def get_quant_shape(shape, quant_type, quant_scheme):
    default_idx = 0 if quant_type == QuantType.WEIGHT else 1
    if is_per_channel(quant_scheme):
        quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
    else:
        quant_shape = []
    return quant_shape
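calculate_qmin_qmax fixes the numeric ranges used everywhere above, so it is worth spelling out the 8-bit values (the same constants the new unit test below asserts against):

    calculate_qmin_qmax(8, 'int')   # (-127, 127): qmin is -2**7 + 1, keeping the signed range symmetric
    calculate_qmin_qmax(8, 'uint')  # (0, 255)
    calculate_qmin_qmax(4, 'int')   # (-7, 7)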
test/ut/sdk/test_compressor_torch.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import schema
 import nni.algorithms.compression.pytorch.pruning as torch_pruner
 import nni.algorithms.compression.pytorch.quantization as torch_quantizer
+from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
 import math
...
@@ -50,7 +51,8 @@ class CompressorTestCase(TestCase):
         model.relu = torch.nn.ReLU()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
+        dummy = torch.randn(1, 1, 28, 28)
+        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
         quantizer.compress()
         modules_to_compress = quantizer.get_modules_to_compress()
         modules_to_compress_name = [t[0].name for t in modules_to_compress]
...
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
...
@@ -332,6 +334,130 @@ class CompressorTestCase(TestCase):
         self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
         self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))
 
+    def test_quantization_dtype_scheme(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
+                self.bn1 = torch.nn.BatchNorm2d(2)
+
+            def forward(self, x):
+                x = self.bn1(self.conv1(x))
+                return x
+
+        dtypes = ['int', 'uint']
+        qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
+        for dtype in dtypes:
+            for qscheme in qschemes:
+                config_list = [{
+                    'quant_types': ['weight', 'input'],
+                    'quant_bits': 8,
+                    'op_types': ['Conv2d'],
+                    'quant_dtype': dtype,
+                    'quant_scheme': qscheme
+                }]
+                model = TestModel()
+                optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+                # only QAT_quantizer is supported for now
+                dummy = torch.randn(1, 1, 4, 4)
+                quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
+
+                # test layer setting
+                for layer, config in quantizer.modules_to_compress:
+                    module = layer.module
+                    name = layer.name
+                    layer_setting = module.layer_quant_setting
+                    qmin, qmax = calculate_qmin_qmax(8, dtype)
+                    all_quant_types = ['input', 'weight']
+                    for quant_type in all_quant_types:
+                        # check for settings
+                        tensor_setting = getattr(layer_setting, quant_type)
+                        self.assertTrue(tensor_setting is not None)
+                        self.assertTrue(tensor_setting.quant_scheme == qscheme)
+                        self.assertTrue(tensor_setting.quant_dtype == dtype)
+                        self.assertTrue(tensor_setting.qmin == qmin)
+                        self.assertTrue(tensor_setting.qmax == qmax)
+
+                        input_shape, output_shape = quantizer.all_shapes[name]
+                        shape = input_shape if quant_type == 'input' else module.weight.shape
+                        quant_shape = get_quant_shape(shape, quant_type, qscheme)
+                        scale_name = quant_type + '_scale'
+                        zero_point_name = quant_type + '_zero_point'
+                        scale = getattr(module, scale_name)
+                        zero_point = getattr(module, zero_point_name)
+                        self.assertTrue(list(scale.shape) == quant_shape)
+                        self.assertTrue(list(zero_point.shape) == quant_shape)
+
+                weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
+                if qscheme == 'per_channel_symmetric':
+                    if dtype == 'int':
+                        target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
+                        target_zero_point = torch.ones([2, 1, 1, 1]) * 0
+                    else:
+                        target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
+                        target_zero_point = torch.ones([2, 1, 1, 1]) * 127
+                elif qscheme == 'per_tensor_symmetric':
+                    if dtype == 'int':
+                        target_scale = torch.tensor(18. / 127)
+                        target_zero_point = torch.zeros([])
+                    else:
+                        target_scale = torch.tensor(18. / 127.5)
+                        target_zero_point = torch.ones([]) * 127
+                elif qscheme == 'per_channel_affine':
+                    min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
+                    if dtype == 'int':
+                        target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
+                        target_zero_point = -127 - torch.round(min_val / target_scale)
+                    else:
+                        target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
+                        target_zero_point = 0 - torch.round(min_val / target_scale)
+                else:
+                    if dtype == 'int':
+                        target_scale = torch.tensor(18. / 254)
+                        target_zero_point = -127 - torch.round(0 / target_scale)
+                    else:
+                        target_scale = torch.tensor(18. / 255)
+                        target_zero_point = 0 - torch.round(0 / target_scale)
+
+                wrapper = getattr(model, name)
+                wrapper.module.weight = weight
+                quantizer.quantize_weight(wrapper)
+                self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
+                self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))
+
+                inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
+                if qscheme == 'per_channel_symmetric':
+                    if dtype == 'int':
+                        target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
+                        target_zero_point = torch.ones([1, 1, 1, 1]) * 0
+                    else:
+                        target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
+                        target_zero_point = torch.ones([1, 1, 1, 1]) * 127
+                elif qscheme == 'per_tensor_symmetric':
+                    if dtype == 'int':
+                        target_scale = torch.tensor(15. / 127)
+                        target_zero_point = torch.zeros([])
+                    else:
+                        target_scale = torch.tensor(15. / 127.5)
+                        target_zero_point = torch.ones([]) * 127
+                elif qscheme == 'per_channel_affine':
+                    min_val = torch.tensor([0.]).view([1, 1, 1, 1])
+                    if dtype == 'int':
+                        target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
+                        target_zero_point = -127 - torch.round(min_val / target_scale)
+                    else:
+                        target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
+                        target_zero_point = 0 - torch.round(min_val / target_scale)
+                else:
+                    if dtype == 'int':
+                        target_scale = torch.tensor(15. / 254)
+                        target_zero_point = -127 - torch.round(0 / target_scale)
+                    else:
+                        target_scale = torch.tensor(15. / 255)
+                        target_zero_point = 0 - torch.round(0 / target_scale)
+
+                quantizer.quantize_input(inp, wrapper)
+                self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
+                self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))
+
     def test_torch_QAT_quantizer(self):
         model = TorchModel()
         config_list = [{
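Note: the target scales and zero points asserted in test_quantization_dtype_scheme above follow from the standard 8-bit quantization formulas. The sketch below reproduces them; calc_qmin_qmax and scale_zero_point are illustrative helpers written for this note under stated assumptions, not NNI's actual implementations (the test's real helper is calculate_qmin_qmax).

# Illustrative sketch only; reproduces the asserted target values above.
import torch

def calc_qmin_qmax(bits, dtype):
    # 'int': symmetric signed range; 'uint': full unsigned range
    if dtype == 'int':
        return -(2 ** (bits - 1)) + 1, 2 ** (bits - 1) - 1   # (-127, 127) for 8 bits
    return 0, 2 ** bits - 1                                  # (0, 255) for 8 bits

def scale_zero_point(min_val, max_val, qmin, qmax, qscheme):
    # the observed range is always extended to include 0
    min_val = torch.clamp(min_val, max=0.)
    max_val = torch.clamp(max_val, min=0.)
    if 'symmetric' in qscheme:
        scale = torch.max(-min_val, max_val) / ((qmax - qmin) / 2)
        zero_point = torch.full_like(scale, float((qmin + qmax) // 2))  # 0 for int, 127 for uint
    else:
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - torch.round(min_val / scale)
    return scale, zero_point

# weight = torch.arange(1, 19).view(2, 1, 3, 3) has per-channel maxima 9 and 18
qmin, qmax = calc_qmin_qmax(8, 'int')
scale, zp = scale_zero_point(torch.zeros(2, 1, 1, 1),
                             torch.tensor([9., 18.]).view(2, 1, 1, 1),
                             qmin, qmax, 'per_channel_symmetric')
# scale == tensor([9/127, 18/127]), zp == tensor of zeros: the test's target values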
...
@@ -347,7 +473,8 @@ class CompressorTestCase(TestCase):
         model.relu = torch.nn.ReLU()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
+        dummy = torch.randn(1, 1, 28, 28)
+        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
         quantizer.compress()
         # test quantize
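Note: the new dummy_input argument (threaded through this hunk and the ones below) is presumably what lets the quantizer run the model once and record per-layer input/output shapes, since the test reads them back via quantizer.all_shapes[name]. A minimal sketch of that idea with forward hooks, under that assumption and not NNI's actual tracing code:

import torch

def record_shapes(model, dummy_input):
    all_shapes = {}
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # only record leaf modules
        def hook(mod, inputs, output, name=name):
            all_shapes[name] = (list(inputs[0].shape), list(output.shape))
        handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(dummy_input)
    for h in handles:
        h.remove()
    return all_shapes

# e.g. record_shapes(TestModel(), torch.randn(1, 1, 4, 4))['conv1']
# -> ([1, 1, 4, 4], [1, 2, 2, 2]) for the Conv2d(1, 2, 3, 1) layer above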
...
@@ -357,20 +484,20 @@ class CompressorTestCase(TestCase):
         weight = torch.tensor([[1, 2], [3, 5]]).float()
         model.conv2.module.weight.data = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
-        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
-        assert model.conv2.module.zero_point == 0
+        assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
+        assert model.conv2.module.weight_zero_point == 0
         quantizer.quantize_input(input, model.conv2)
-        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
-        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
+        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
+        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
 
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
         model.conv2.module.weight = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
-        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
-        assert model.conv2.module.zero_point in (42, 43)
+        assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
+        assert model.conv2.module.weight_zero_point in (42, 43)
         quantizer.quantize_input(input, model.conv2)
-        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
-        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
+        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
+        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
 
         # test value of weight and bias after quantization
         weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
         weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
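Note: the weight_valid values above are exactly what one 8-bit quantize-dequantize round trip yields when the range is clamped to include 0 (so min_val = 0 and max_val = weight.max()). A quick check with the formulas written out plainly, not via NNI's API:

import torch

weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
scale = weight.max() / 255            # min_val is clamped to 0, so the range is [0, max]
zero_point = torch.tensor(0.)
q = torch.clamp(torch.round(weight / scale + zero_point), 0, 255)
dequantized = (q - zero_point) * scale
print(dequantized)  # tensor([[1.1242, 2.3421], [3.7707, 5.9723]])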
...
@@ -385,15 +512,15 @@ class CompressorTestCase(TestCase):
         # test ema
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
-        out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)
+        model.relu(x)
+        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
+        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))
         quantizer.step_with_optimizer()
 
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
-        out = model.relu(x)
-        assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
-        assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)
+        model.relu(x)
+        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
+        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))
 
     def test_torch_quantizer_export(self):
         config_list_qat = [{
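Note: the tracked_min_output/tracked_max_output values asserted in the EMA hunk are consistent with an exponential moving average whose smoothing factor is 0.99, with the first batch initializing the statistics directly. The 0.99 coefficient is inferred from the asserted numbers, not read from NNI's source:

import torch

def ema(tracked, current, alpha=0.99):
    return alpha * tracked + (1 - alpha) * current

x1 = torch.relu(torch.tensor([[-0.2, 0.], [0.1, 0.2]]))  # ReLU output: min 0.0, max 0.2
t_min, t_max = x1.min(), x1.max()                        # first batch: direct initialization
x2 = torch.relu(torch.tensor([[0.2, 0.4], [0.6, 0.8]]))  # min 0.2, max 0.8
t_min, t_max = ema(t_min, x2.min()), ema(t_max, x2.max())
print(t_min, t_max)  # tensor(0.0020) tensor(0.2060), the asserted values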
...
@@ -424,11 +551,14 @@ class CompressorTestCase(TestCase):
         }]
         config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
         quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
+        dummy = torch.randn(1, 1, 28, 28)
         for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
             model = TorchModel()
             model.relu = torch.nn.ReLU()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-            quantizer = quantize_algorithm(model, config, optimizer)
+            if quantize_algorithm == torch_quantizer.QAT_Quantizer:
+                quantizer = quantize_algorithm(model, config, optimizer, dummy)
+            else:
+                quantizer = quantize_algorithm(model, config, optimizer)
             quantizer.compress()
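Note: only QAT_Quantizer takes the extra dummy input for now (per the comment earlier in this file), so the tests branch on the quantizer class. A generic version of that dispatch, sketched with a hypothetical helper name:

# Hypothetical helper illustrating the dispatch pattern used in these tests.
def build_quantizer(quantize_algorithm, model, config, optimizer,
                    dummy_input=None, needs_dummy_input=()):
    if quantize_algorithm in needs_dummy_input:
        return quantize_algorithm(model, config, optimizer, dummy_input=dummy_input)
    return quantize_algorithm(model, config, optimizer)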
...
@@ -461,6 +591,10 @@ class CompressorTestCase(TestCase):
             model = TorchModel().eval()
             model.relu = torch.nn.ReLU()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-            quantizer = quantize_algorithm(model, configure_list, optimizer)
+            if quantize_algorithm == torch_quantizer.QAT_Quantizer:
+                dummy = torch.randn(1, 1, 28, 28)
+                quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
+            else:
+                quantizer = quantize_algorithm(model, configure_list, optimizer)
             quantizer.compress()
             if calibration_config is not None:
...